3#if defined(GPU) && __has_include(<magma_v2.h>)
6 #include <magma_operators.h>
9template<
typename ScalarType>
10static inline void magma_axpy_wrapper(magma_int_t n, ScalarType alpha, ScalarType
const* dx,
11 magma_int_t incx, ScalarType* dy, magma_int_t incy,
12 magma_queue_t queue) {
13 if constexpr(std::is_same_v<ScalarType, Complex_t<float>>) {
14 magma_caxpy(n, alpha,
reinterpret_cast<magmaFloatComplex_const_ptr
>(dx), incx,
15 reinterpret_cast<magmaFloatComplex_const_ptr
>(dy), incy, queue);
17 else if constexpr(std::is_same_v<ScalarType, Complex_t<double>>) {
18 magma_zaxpy(n, alpha,
reinterpret_cast<magmaDoubleComplex_const_ptr
>(dx), incx,
19 reinterpret_cast<magmaDoubleComplex_const_ptr
>(dy), incy, queue);
22 static_assert([] {
return false; }(),
23 "Error: magma_axpy_wrapper for the input scalar type is NOT implemented.");
47template<
typename ScalarType>
48static inline ScalarType magma_dotc_wrapper(magma_int_t n, ScalarType
const* dx, magma_int_t incx,
49 ScalarType
const* dy, magma_int_t incy,
50 magma_queue_t queue) {
51 if constexpr(std::is_same_v<ScalarType, float>) {
52 return magma_sdot(n, dx, incx, dy, incy, queue);
54 else if constexpr(std::is_same_v<ScalarType, double>) {
55 return magma_ddot(n, dx, incx, dy, incy, queue);
57 else if constexpr(std::is_same_v<ScalarType, Complex_t<float>>) {
58 return magma_cdotc(n,
reinterpret_cast<magmaFloatComplex_const_ptr
>(dx), incx,
59 reinterpret_cast<magmaFloatComplex_const_ptr
>(dy), incy, queue);
61 else if constexpr(std::is_same_v<ScalarType, Complex_t<double>>) {
62 return magma_zdotc(n,
reinterpret_cast<magmaDoubleComplex_const_ptr
>(dx), incx,
63 reinterpret_cast<magmaDoubleComplex_const_ptr
>(dy), incy, queue);
66 static_assert([] {
return false; }(),
67 "Error: magma_dotc_wrapper for the input scalar type is NOT implemented.");
74template<
typename ScalarType>
75static inline void magma_hemm_wrapper(magma_side_t side, magma_uplo_t uplo, magma_int_t m,
76 magma_int_t n, ScalarType alpha, ScalarType
const* dA,
77 magma_int_t ldda, ScalarType
const* dB, magma_int_t lddb,
78 ScalarType beta, ScalarType* dC, magma_int_t lddc,
79 magma_queue_t queue) {
80 if constexpr(std::is_same_v<ScalarType, float>) {
81 return magmablas_ssymm(side, uplo, m, n, alpha, dA, ldda, dB, lddb, beta, dC, lddc, queue);
83 else if constexpr(std::is_same_v<ScalarType, double>) {
84 return magmablas_dsymm(side, uplo, m, n, alpha, dA, ldda, dB, lddb, beta, dC, lddc, queue);
86 else if constexpr(std::is_same_v<ScalarType, Complex_t<float>>) {
87 return magma_chemm(side, uplo, m, n, alpha,
88 reinterpret_cast<magmaFloatComplex_const_ptr
>(dA), ldda,
89 reinterpret_cast<magmaFloatComplex_const_ptr
>(dB), lddb, beta,
90 reinterpret_cast<magmaFloatComplex_ptr
>(dC), lddc, queue);
92 else if constexpr(std::is_same_v<ScalarType, Complex_t<double>>) {
93 return magma_zhemm(side, uplo, m, n, alpha,
94 reinterpret_cast<magmaDoubleComplex_const_ptr
>(dA), ldda,
95 reinterpret_cast<magmaDoubleComplex_const_ptr
>(dB), lddb, beta,
96 reinterpret_cast<magmaDoubleComplex_ptr
>(dC), lddc, queue);
99 static_assert([] {
return false; }(),
100 "Error: magma_hemm_wrapper for the input scalar type is NOT implemented.");
106template<
typename ScalarType>
107static inline magma_int_t magma_heevd_gpu_wrapper(
108 magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, ScalarType* dA, magma_int_t ldda,
109 typename Eigen::NumTraits<ScalarType>::Real* w, ScalarType* wA, magma_int_t ldwa,
110 ScalarType* work, magma_int_t lwork,
typename Eigen::NumTraits<ScalarType>::Real* rwork,
111 magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) {
112 if constexpr(std::is_same_v<ScalarType, float>) {
113 return magma_sheevd_gpu(jobz, uplo, n, dA, ldda, w, wA, ldwa, work, lwork, rwork, lrwork,
114 iwork, liwork, info);
116 else if constexpr(std::is_same_v<ScalarType, double>) {
117 return magma_dheevd_gpu(jobz, uplo, n, dA, ldda, w, wA, ldwa, work, lwork, rwork, lrwork,
118 iwork, liwork, info);
120 else if constexpr(std::is_same_v<ScalarType, Complex_t<float>>) {
121 return magma_cheevd_gpu(jobz, uplo, n,
reinterpret_cast<magmaFloatComplex_ptr
>(dA), ldda, w,
122 reinterpret_cast<magmaFloatComplex_ptr
>(wA), ldwa,
123 reinterpret_cast<magmaFloatComplex_ptr
>(work), lwork, rwork, lrwork,
124 iwork, liwork, info);
126 else if constexpr(std::is_same_v<ScalarType, Complex_t<double>>) {
127 return magma_zheevd_gpu(jobz, uplo, n,
reinterpret_cast<magmaDoubleComplex_ptr
>(dA), ldda,
128 w,
reinterpret_cast<magmaDoubleComplex_ptr
>(wA), ldwa,
129 reinterpret_cast<magmaDoubleComplex_ptr
>(work), lwork, rwork,
130 lrwork, iwork, liwork, info);
134 [] {
return false; }(),
135 "Error: magma_heevd_gpu_wrapper for the input scalar type is NOT implemented.");