// Stxxl 1.4.0
00001 /*************************************************************************** 00002 * include/stxxl/bits/containers/matrix_low_level.h 00003 * 00004 * Part of the STXXL. See http://stxxl.sourceforge.net 00005 * 00006 * Copyright (C) 2010-2011 Raoul Steffen <R-Steffen@gmx.de> 00007 * 00008 * Distributed under the Boost Software License, Version 1.0. 00009 * (See accompanying file LICENSE_1_0.txt or copy at 00010 * http://www.boost.org/LICENSE_1_0.txt) 00011 **************************************************************************/ 00012 00013 #ifndef STXXL_MATRIX_LOW_LEVEL_HEADER 00014 #define STXXL_MATRIX_LOW_LEVEL_HEADER 00015 00016 #ifndef STXXL_BLAS 00017 #define STXXL_BLAS 0 00018 #endif 00019 00020 #include <complex> 00021 00022 #include <stxxl/bits/common/types.h> 00023 #include <stxxl/bits/parallel.h> 00024 00025 __STXXL_BEGIN_NAMESPACE 00026 00027 //forward declaration 00028 template <typename ValueType, unsigned BlockSideLength> 00029 struct matrix_operations; 00030 00031 //generic declaration 00032 template <unsigned BlockSideLength, bool transposed> 00033 struct switch_major_index; 00034 00035 //row-major specialization 00036 template <unsigned BlockSideLength> 00037 struct switch_major_index<BlockSideLength, false> 00038 { 00039 inline switch_major_index(const int_type row, const int_type col) : i(row * BlockSideLength + col) {} 00040 inline operator int_type & () { return i; } 00041 private: 00042 int_type i; 00043 }; 00044 00045 //column-major specialization 00046 template <unsigned BlockSideLength> 00047 struct switch_major_index<BlockSideLength, true> 00048 { 00049 inline switch_major_index(const int_type row, const int_type col) : i(row + col * BlockSideLength) {} 00050 inline operator int_type & () { return i; } 00051 private: 00052 int_type i; 00053 }; 00054 00055 //! 
\brief c = a <op> b; for arbitrary entries 00056 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, bool b_transposed, class Op> 00057 struct low_level_matrix_binary_ass_op 00058 { 00059 low_level_matrix_binary_ass_op(ValueType * c, const ValueType * a, const ValueType * b, Op op = Op()) 00060 { 00061 if (a) 00062 if (b) 00063 #if STXXL_PARALLEL 00064 #pragma omp parallel for 00065 #endif 00066 for (int_type row = 0; row < int_type(BlockSideLength); ++row) 00067 for (int_type col = 0; col < int_type(BlockSideLength); ++col) 00068 op(c[switch_major_index<BlockSideLength, false>(row, col)], 00069 a[switch_major_index<BlockSideLength, a_transposed>(row, col)], 00070 b[switch_major_index<BlockSideLength, b_transposed>(row, col)]); 00071 else 00072 #if STXXL_PARALLEL 00073 #pragma omp parallel for 00074 #endif 00075 for (int_type row = 0; row < int_type(BlockSideLength); ++row) 00076 for (int_type col = 0; col < int_type(BlockSideLength); ++col) 00077 op(c[switch_major_index<BlockSideLength, false>(row, col)], 00078 a[switch_major_index<BlockSideLength, a_transposed>(row, col)], 0); 00079 else 00080 { 00081 assert(b /* do not add nothing to nothing */); 00082 #if STXXL_PARALLEL 00083 #pragma omp parallel for 00084 #endif 00085 for (int_type row = 0; row < int_type(BlockSideLength); ++row) 00086 for (int_type col = 0; col < int_type(BlockSideLength); ++col) 00087 op(c[switch_major_index<BlockSideLength, false>(row, col)], 00088 0, b[switch_major_index<BlockSideLength, b_transposed>(row, col)]); 00089 } 00090 } 00091 }; 00092 00093 //! 
\brief c <op>= a; for arbitrary entries 00094 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, class Op> 00095 struct low_level_matrix_unary_ass_op 00096 { 00097 low_level_matrix_unary_ass_op(ValueType * c, const ValueType * a, Op op = Op()) 00098 { 00099 if (a) 00100 #if STXXL_PARALLEL 00101 #pragma omp parallel for 00102 #endif 00103 for (int_type row = 0; row < int_type(BlockSideLength); ++row) 00104 for (int_type col = 0; col < int_type(BlockSideLength); ++col) 00105 op(c[switch_major_index<BlockSideLength, false>(row, col)], 00106 a[switch_major_index<BlockSideLength, a_transposed>(row, col)]); 00107 } 00108 }; 00109 00110 //! \brief c = <op>a; for arbitrary entries 00111 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, class Op> 00112 struct low_level_matrix_unary_op 00113 { 00114 low_level_matrix_unary_op(ValueType * c, const ValueType * a, Op op = Op()) 00115 { 00116 assert(a); 00117 #if STXXL_PARALLEL 00118 #pragma omp parallel for 00119 #endif 00120 for (int_type row = 0; row < int_type(BlockSideLength); ++row) 00121 for (int_type col = 0; col < int_type(BlockSideLength); ++col) 00122 c[switch_major_index<BlockSideLength, false>(row, col)] = 00123 op(a[switch_major_index<BlockSideLength, a_transposed>(row, col)]); 00124 } 00125 }; 00126 00127 //! \brief multiplies matrices A and B, adds result to C, for arbitrary entries 00128 //! 
param pointer to blocks of A,B,C; elements in blocks have to be in row-major 00129 /* designated usage as: 00130 * void 00131 * low_level_matrix_multiply_and_add(const double * a, bool a_in_col_major, 00132 const double * b, bool b_in_col_major, 00133 double * c, const bool c_in_col_major) */ 00134 template <typename ValueType, unsigned BlockSideLength> 00135 struct low_level_matrix_multiply_and_add 00136 { 00137 low_level_matrix_multiply_and_add(const ValueType * a, bool a_in_col_major, 00138 const ValueType * b, bool b_in_col_major, 00139 ValueType * c, const bool c_in_col_major) 00140 { 00141 if (c_in_col_major) 00142 { 00143 std::swap(a,b); 00144 bool a_cm = ! b_in_col_major; 00145 b_in_col_major = ! a_in_col_major; 00146 a_in_col_major = a_cm; 00147 } 00148 if (! a_in_col_major) 00149 { 00150 if (! b_in_col_major) 00151 { // => both row-major 00152 #if STXXL_PARALLEL 00153 #pragma omp parallel for 00154 #endif 00155 for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables 00156 for (unsigned_type k = 0; k < BlockSideLength; ++k) 00157 for (unsigned_type j = 0; j < BlockSideLength; ++j) 00158 c[i * BlockSideLength + j] += a[i * BlockSideLength + k] * b[k * BlockSideLength + j]; 00159 } 00160 else 00161 { // => a row-major, b col-major 00162 #if STXXL_PARALLEL 00163 #pragma omp parallel for 00164 #endif 00165 for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables 00166 for (unsigned_type j = 0; j < BlockSideLength; ++j) 00167 for (unsigned_type k = 0; k < BlockSideLength; ++k) 00168 c[i * BlockSideLength + j] += a[i * BlockSideLength + k] * b[k + j * BlockSideLength]; 00169 } 00170 } 00171 else 00172 { 00173 if (! 
b_in_col_major) 00174 { // => a col-major, b row-major 00175 #if STXXL_PARALLEL 00176 #pragma omp parallel for 00177 #endif 00178 for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables 00179 for (unsigned_type k = 0; k < BlockSideLength; ++k) 00180 for (unsigned_type j = 0; j < BlockSideLength; ++j) 00181 c[i * BlockSideLength + j] += a[i + k * BlockSideLength] * b[k * BlockSideLength + j]; 00182 } 00183 else 00184 { // => both col-major 00185 #if STXXL_PARALLEL 00186 #pragma omp parallel for 00187 #endif 00188 for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables 00189 for (unsigned_type k = 0; k < BlockSideLength; ++k) 00190 for (unsigned_type j = 0; j < BlockSideLength; ++j) 00191 c[i * BlockSideLength + j] += a[i + k * BlockSideLength] * b[k + j * BlockSideLength]; 00192 } 00193 } 00194 } 00195 }; 00196 00197 #if STXXL_BLAS 00198 typedef int_type blas_int; 00199 typedef std::complex<double> blas_double_complex; 00200 typedef std::complex<float> blas_single_complex; 00201 00202 // --- vector add (used as matrix-add) ----------------- 00203 00204 extern "C" void daxpy_(const blas_int *n, const double *alpha, const double *x, const blas_int *incx, double *y, const blas_int *incy); 00205 extern "C" void saxpy_(const blas_int *n, const float *alpha, const float *x, const blas_int *incx, float *y, const blas_int *incy); 00206 extern "C" void zaxpy_(const blas_int *n, const blas_double_complex *alpha, const blas_double_complex *x, const blas_int *incx, blas_double_complex *y, const blas_int *incy); 00207 extern "C" void caxpy_(const blas_int *n, const blas_single_complex *alpha, const blas_single_complex *x, const blas_int *incx, blas_single_complex *y, const blas_int *incy); 00208 extern "C" void dcopy_(const blas_int *n, const double *x, const blas_int *incx, double *y, const blas_int *incy); 00209 extern "C" void scopy_(const blas_int *n, const float *x, 
const blas_int *incx, float *y, const blas_int *incy); 00210 extern "C" void zcopy_(const blas_int *n, const blas_double_complex *x, const blas_int *incx, blas_double_complex *y, const blas_int *incy); 00211 extern "C" void ccopy_(const blas_int *n, const blas_single_complex *x, const blas_int *incx, blas_single_complex *y, const blas_int *incy); 00212 00213 //! \brief c = a + b; for double entries 00214 template <unsigned BlockSideLength> 00215 struct low_level_matrix_binary_ass_op<double, BlockSideLength, false, false, typename matrix_operations<double, BlockSideLength>::addition> 00216 { 00217 low_level_matrix_binary_ass_op(double * c, const double * a, const double * b, typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition()) 00218 { 00219 if (a) 00220 if (b) 00221 { 00222 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00223 (c, a); 00224 low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00225 (c, b); 00226 } 00227 else 00228 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00229 (c, a); 00230 else 00231 { 00232 assert(b /* do not add nothing to nothing */); 00233 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00234 (c, b); 00235 } 00236 } 00237 }; 00238 //! 
\brief c = a - b; for double entries 00239 template <unsigned BlockSideLength> 00240 struct low_level_matrix_binary_ass_op<double, BlockSideLength, false, false, typename matrix_operations<double, BlockSideLength>::subtraction> 00241 { 00242 low_level_matrix_binary_ass_op(double * c, const double * a, const double * b, 00243 typename matrix_operations<double, BlockSideLength>::subtraction = typename matrix_operations<double, BlockSideLength>::subtraction()) 00244 { 00245 if (a) 00246 if (b) 00247 { 00248 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00249 (c, a); 00250 low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction> 00251 (c, b); 00252 } 00253 else 00254 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00255 (c, a); 00256 else 00257 { 00258 assert(b /* do not add nothing to nothing */); 00259 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction> 00260 (c, b); 00261 } 00262 } 00263 }; 00264 //! \brief c += a; for double entries 00265 template <unsigned BlockSideLength> 00266 struct low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00267 { 00268 low_level_matrix_unary_ass_op(double * c, const double * a, 00269 typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition()) 00270 { 00271 const blas_int size = BlockSideLength * BlockSideLength; 00272 const blas_int int_one = 1; 00273 const double one = 1.0; 00274 if (a) 00275 daxpy_(&size, &one, a, &int_one, c, &int_one); 00276 } 00277 }; 00278 //! 
\brief c -= a; for double entries 00279 template <unsigned BlockSideLength> 00280 struct low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction> 00281 { 00282 low_level_matrix_unary_ass_op(double * c, const double * a, 00283 typename matrix_operations<double, BlockSideLength>::subtraction = typename matrix_operations<double, BlockSideLength>::subtraction()) 00284 { 00285 const blas_int size = BlockSideLength * BlockSideLength; 00286 const blas_int int_one = 1; 00287 const double minusone = -1.0; 00288 if (a) 00289 daxpy_(&size, &minusone, a, &int_one, c, &int_one); 00290 } 00291 }; 00292 //! \brief c = a; for double entries 00293 template <unsigned BlockSideLength> 00294 struct low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition> 00295 { 00296 low_level_matrix_unary_op(double * c, const double * a, 00297 typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition()) 00298 { 00299 const blas_int size = BlockSideLength * BlockSideLength; 00300 const blas_int int_one = 1; 00301 dcopy_(&size, a, &int_one, c, &int_one); 00302 } 00303 }; 00304 00305 //! 
\brief c = a + b; for float entries 00306 template <unsigned BlockSideLength> 00307 struct low_level_matrix_binary_ass_op<float, BlockSideLength, false, false, typename matrix_operations<float, BlockSideLength>::addition> 00308 { 00309 low_level_matrix_binary_ass_op(float * c, const float * a, const float * b, typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition()) 00310 { 00311 if (a) 00312 if (b) 00313 { 00314 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00315 (c, a); 00316 low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00317 (c, b); 00318 } 00319 else 00320 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00321 (c, a); 00322 else 00323 { 00324 assert(b /* do not add nothing to nothing */); 00325 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00326 (c, b); 00327 } 00328 } 00329 }; 00330 //! 
\brief c = a - b; for float entries 00331 template <unsigned BlockSideLength> 00332 struct low_level_matrix_binary_ass_op<float, BlockSideLength, false, false, typename matrix_operations<float, BlockSideLength>::subtraction> 00333 { 00334 low_level_matrix_binary_ass_op(float * c, const float * a, const float * b, 00335 typename matrix_operations<float, BlockSideLength>::subtraction = typename matrix_operations<float, BlockSideLength>::subtraction()) 00336 { 00337 if (a) 00338 if (b) 00339 { 00340 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00341 (c, a); 00342 low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction> 00343 (c, b); 00344 } 00345 else 00346 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00347 (c, a); 00348 else 00349 { 00350 assert(b /* do not add nothing to nothing */); 00351 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction> 00352 (c, b); 00353 } 00354 } 00355 }; 00356 //! \brief c += a; for float entries 00357 template <unsigned BlockSideLength> 00358 struct low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00359 { 00360 low_level_matrix_unary_ass_op(float * c, const float * a, 00361 typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition()) 00362 { 00363 const blas_int size = BlockSideLength * BlockSideLength; 00364 const blas_int int_one = 1; 00365 const float one = 1.0; 00366 if (a) 00367 saxpy_(&size, &one, a, &int_one, c, &int_one); 00368 } 00369 }; 00370 //! 
\brief c -= a; for float entries 00371 template <unsigned BlockSideLength> 00372 struct low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction> 00373 { 00374 low_level_matrix_unary_ass_op(float * c, const float * a, 00375 typename matrix_operations<float, BlockSideLength>::subtraction = typename matrix_operations<float, BlockSideLength>::subtraction()) 00376 { 00377 const blas_int size = BlockSideLength * BlockSideLength; 00378 const blas_int int_one = 1; 00379 const float minusone = -1.0; 00380 if (a) 00381 saxpy_(&size, &minusone, a, &int_one, c, &int_one); 00382 } 00383 }; 00384 //! \brief c = a; for float entries 00385 template <unsigned BlockSideLength> 00386 struct low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition> 00387 { 00388 low_level_matrix_unary_op(float * c, const float * a, 00389 typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition()) 00390 { 00391 const blas_int size = BlockSideLength * BlockSideLength; 00392 const blas_int int_one = 1; 00393 scopy_(&size, a, &int_one, c, &int_one); 00394 } 00395 }; 00396 00397 //! 
\brief c = a + b; for blas_double_complex entries 00398 template <unsigned BlockSideLength> 00399 struct low_level_matrix_binary_ass_op<blas_double_complex, BlockSideLength, false, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00400 { 00401 low_level_matrix_binary_ass_op(blas_double_complex * c, const blas_double_complex * a, const blas_double_complex * b, typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition()) 00402 { 00403 if (a) 00404 if (b) 00405 { 00406 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00407 (c, a); 00408 low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00409 (c, b); 00410 } 00411 else 00412 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00413 (c, a); 00414 else 00415 { 00416 assert(b /* do not add nothing to nothing */); 00417 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00418 (c, b); 00419 } 00420 } 00421 }; 00422 //! 
\brief c = a - b; for blas_double_complex entries 00423 template <unsigned BlockSideLength> 00424 struct low_level_matrix_binary_ass_op<blas_double_complex, BlockSideLength, false, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction> 00425 { 00426 low_level_matrix_binary_ass_op(blas_double_complex * c, const blas_double_complex * a, const blas_double_complex * b, 00427 typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction()) 00428 { 00429 if (a) 00430 if (b) 00431 { 00432 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00433 (c, a); 00434 low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction> 00435 (c, b); 00436 } 00437 else 00438 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00439 (c, a); 00440 else 00441 { 00442 assert(b /* do not add nothing to nothing */); 00443 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction> 00444 (c, b); 00445 } 00446 } 00447 }; 00448 //! 
\brief c += a; for blas_double_complex entries 00449 template <unsigned BlockSideLength> 00450 struct low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00451 { 00452 low_level_matrix_unary_ass_op(blas_double_complex * c, const blas_double_complex * a, 00453 typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition()) 00454 { 00455 const blas_int size = BlockSideLength * BlockSideLength; 00456 const blas_int int_one = 1; 00457 const blas_double_complex one = 1.0; 00458 if (a) 00459 zaxpy_(&size, &one, a, &int_one, c, &int_one); 00460 } 00461 }; 00462 //! \brief c -= a; for blas_double_complex entries 00463 template <unsigned BlockSideLength> 00464 struct low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction> 00465 { 00466 low_level_matrix_unary_ass_op(blas_double_complex * c, const blas_double_complex * a, 00467 typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction()) 00468 { 00469 const blas_int size = BlockSideLength * BlockSideLength; 00470 const blas_int int_one = 1; 00471 const blas_double_complex minusone = -1.0; 00472 if (a) 00473 zaxpy_(&size, &minusone, a, &int_one, c, &int_one); 00474 } 00475 }; 00476 //! 
\brief c = a; for blas_double_complex entries 00477 template <unsigned BlockSideLength> 00478 struct low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition> 00479 { 00480 low_level_matrix_unary_op(blas_double_complex * c, const blas_double_complex * a, 00481 typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition()) 00482 { 00483 const blas_int size = BlockSideLength * BlockSideLength; 00484 const blas_int int_one = 1; 00485 zcopy_(&size, a, &int_one, c, &int_one); 00486 } 00487 }; 00488 00489 //! \brief c = a + b; for blas_single_complex entries 00490 template <unsigned BlockSideLength> 00491 struct low_level_matrix_binary_ass_op<blas_single_complex, BlockSideLength, false, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00492 { 00493 low_level_matrix_binary_ass_op(blas_single_complex * c, const blas_single_complex * a, const blas_single_complex * b, typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition()) 00494 { 00495 if (a) 00496 if (b) 00497 { 00498 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00499 (c, a); 00500 low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00501 (c, b); 00502 } 00503 else 00504 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00505 (c, a); 00506 else 00507 { 00508 assert(b /* do not add nothing to nothing */); 00509 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename 
matrix_operations<blas_single_complex, BlockSideLength>::addition> 00510 (c, b); 00511 } 00512 } 00513 }; 00514 //! \brief c = a - b; for blas_single_complex entries 00515 template <unsigned BlockSideLength> 00516 struct low_level_matrix_binary_ass_op<blas_single_complex, BlockSideLength, false, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction> 00517 { 00518 low_level_matrix_binary_ass_op(blas_single_complex * c, const blas_single_complex * a, const blas_single_complex * b, 00519 typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction()) 00520 { 00521 if (a) 00522 if (b) 00523 { 00524 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00525 (c, a); 00526 low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction> 00527 (c, b); 00528 } 00529 else 00530 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00531 (c, a); 00532 else 00533 { 00534 assert(b /* do not add nothing to nothing */); 00535 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction> 00536 (c, b); 00537 } 00538 } 00539 }; 00540 //! 
\brief c += a; for blas_single_complex entries 00541 template <unsigned BlockSideLength> 00542 struct low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00543 { 00544 low_level_matrix_unary_ass_op(blas_single_complex * c, const blas_single_complex * a, 00545 typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition()) 00546 { 00547 const blas_int size = BlockSideLength * BlockSideLength; 00548 const blas_int int_one = 1; 00549 const blas_single_complex one = 1.0; 00550 if (a) 00551 caxpy_(&size, &one, a, &int_one, c, &int_one); 00552 } 00553 }; 00554 //! \brief c -= a; for blas_single_complex entries 00555 template <unsigned BlockSideLength> 00556 struct low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction> 00557 { 00558 low_level_matrix_unary_ass_op(blas_single_complex * c, const blas_single_complex * a, 00559 typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction()) 00560 { 00561 const blas_int size = BlockSideLength * BlockSideLength; 00562 const blas_int int_one = 1; 00563 const blas_single_complex minusone = -1.0; 00564 if (a) 00565 caxpy_(&size, &minusone, a, &int_one, c, &int_one); 00566 } 00567 }; 00568 //! 
\brief c = a; for blas_single_complex entries 00569 template <unsigned BlockSideLength> 00570 struct low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition> 00571 { 00572 low_level_matrix_unary_op(blas_single_complex * c, const blas_single_complex * a, 00573 typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition()) 00574 { 00575 const blas_int size = BlockSideLength * BlockSideLength; 00576 const blas_int int_one = 1; 00577 ccopy_(&size, a, &int_one, c, &int_one); 00578 } 00579 }; 00580 00581 // --- matrix-matrix multiplication --------------- 00582 00583 extern "C" void dgemm_(const char *transa, const char *transb, 00584 const blas_int *m, const blas_int *n, const blas_int *k, 00585 const double *alpha, const double *a, const blas_int *lda, 00586 const double *b, const blas_int *ldb, 00587 const double *beta, double *c, const blas_int *ldc); 00588 00589 extern "C" void sgemm_(const char *transa, const char *transb, 00590 const blas_int *m, const blas_int *n, const blas_int *k, 00591 const float *alpha, const float *a, const blas_int *lda, 00592 const float *b, const blas_int *ldb, 00593 const float *beta, float *c, const blas_int *ldc); 00594 00595 extern "C" void zgemm_(const char *transa, const char *transb, 00596 const blas_int *m, const blas_int *n, const blas_int *k, 00597 const blas_double_complex *alpha, const blas_double_complex *a, const blas_int *lda, 00598 const blas_double_complex *b, const blas_int *ldb, 00599 const blas_double_complex *beta, blas_double_complex *c, const blas_int *ldc); 00600 00601 extern "C" void cgemm_(const char *transa, const char *transb, 00602 const blas_int *m, const blas_int *n, const blas_int *k, 00603 const blas_single_complex *alpha, const blas_single_complex *a, const blas_int *lda, 00604 const blas_single_complex *b, const blas_int 
*ldb, 00605 const blas_single_complex *beta, blas_single_complex *c, const blas_int *ldc); 00606 00607 template <typename ValueType> 00608 void gemm_(const char *transa, const char *transb, 00609 const blas_int *m, const blas_int *n, const blas_int *k, 00610 const ValueType *alpha, const ValueType *a, const blas_int *lda, 00611 const ValueType *b, const blas_int *ldb, 00612 const ValueType *beta, ValueType *c, const blas_int *ldc); 00613 00614 //! \brief calculates c = alpha * a * b + beta * c 00615 //! \tparam ValueType type of elements 00616 //! \param n height of a and c 00617 //! \param l width of a and height of b 00618 //! \param m width of b and c 00619 //! \param a_in_col_major if a is stored in column-major rather than row-major 00620 //! \param b_in_col_major if b is stored in column-major rather than row-major 00621 //! \param c_in_col_major if c is stored in column-major rather than row-major 00622 template <typename ValueType> 00623 void gemm_wrapper(const blas_int n, const blas_int l, const blas_int m, 00624 const ValueType alpha, const bool a_in_col_major, const ValueType *a, 00625 const bool b_in_col_major, const ValueType *b, 00626 const ValueType beta, const bool c_in_col_major, ValueType *c) 00627 { 00628 const blas_int& stride_in_a = a_in_col_major ? n : l; 00629 const blas_int& stride_in_b = b_in_col_major ? l : m; 00630 const blas_int& stride_in_c = c_in_col_major ? n : m; 00631 const char transa = a_in_col_major xor c_in_col_major ? 'T' : 'N'; 00632 const char transb = b_in_col_major xor c_in_col_major ? 'T' : 'N'; 00633 if (c_in_col_major) 00634 // blas expects matrices in column-major unless specified via transa rsp. 
transb 00635 gemm_(&transa, &transb, &n, &m, &l, &alpha, a, &stride_in_a, b, &stride_in_b, &beta, c, &stride_in_c); 00636 else 00637 // blas expects matrices in column-major, so we calculate c^T = alpha * b^T * a^T + beta * c^T 00638 gemm_(&transb, &transa, &m, &n, &l, &alpha, b, &stride_in_b, a, &stride_in_a, &beta, c, &stride_in_c); 00639 } 00640 00641 template <> 00642 void gemm_(const char *transa, const char *transb, 00643 const blas_int *m, const blas_int *n, const blas_int *k, 00644 const double *alpha, const double *a, const blas_int *lda, 00645 const double *b, const blas_int *ldb, 00646 const double *beta, double *c, const blas_int *ldc) 00647 { dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } 00648 00649 template <> 00650 void gemm_(const char *transa, const char *transb, 00651 const blas_int *m, const blas_int *n, const blas_int *k, 00652 const float *alpha, const float *a, const blas_int *lda, 00653 const float *b, const blas_int *ldb, 00654 const float *beta, float *c, const blas_int *ldc) 00655 { sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } 00656 00657 template <> 00658 void gemm_(const char *transa, const char *transb, 00659 const blas_int *m, const blas_int *n, const blas_int *k, 00660 const blas_double_complex *alpha, const blas_double_complex *a, const blas_int *lda, 00661 const blas_double_complex *b, const blas_int *ldb, 00662 const blas_double_complex *beta, blas_double_complex *c, const blas_int *ldc) 00663 { zgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } 00664 00665 template <> 00666 void gemm_(const char *transa, const char *transb, 00667 const blas_int *m, const blas_int *n, const blas_int *k, 00668 const blas_single_complex *alpha, const blas_single_complex *a, const blas_int *lda, 00669 const blas_single_complex *b, const blas_int *ldb, 00670 const blas_single_complex *beta, blas_single_complex *c, const blas_int *ldc) 00671 { cgemm_(transa, transb, m, n, k, 
alpha, a, lda, b, ldb, beta, c, ldc); } 00672 00673 //! \brief multiplies matrices A and B, adds result to C, for double entries 00674 template <unsigned BlockSideLength> 00675 struct low_level_matrix_multiply_and_add<double, BlockSideLength> 00676 { 00677 low_level_matrix_multiply_and_add(const double * a, bool a_in_col_major, 00678 const double * b, bool b_in_col_major, 00679 double * c, const bool c_in_col_major) 00680 { 00681 gemm_wrapper<double>(BlockSideLength, BlockSideLength, BlockSideLength, 00682 1.0, a_in_col_major, a, 00683 b_in_col_major, b, 00684 1.0, c_in_col_major, c); 00685 } 00686 }; 00687 00688 //! \brief multiplies matrices A and B, adds result to C, for float entries 00689 template <unsigned BlockSideLength> 00690 struct low_level_matrix_multiply_and_add<float, BlockSideLength> 00691 { 00692 low_level_matrix_multiply_and_add(const float * a, bool a_in_col_major, 00693 const float * b, bool b_in_col_major, 00694 float * c, const bool c_in_col_major) 00695 { 00696 gemm_wrapper<float>(BlockSideLength, BlockSideLength, BlockSideLength, 00697 1.0, a_in_col_major, a, 00698 b_in_col_major, b, 00699 1.0, c_in_col_major, c); 00700 } 00701 }; 00702 00703 //! \brief multiplies matrices A and B, adds result to C, for complex<float> entries 00704 template <unsigned BlockSideLength> 00705 struct low_level_matrix_multiply_and_add<blas_single_complex, BlockSideLength> 00706 { 00707 low_level_matrix_multiply_and_add(const blas_single_complex * a, bool a_in_col_major, 00708 const blas_single_complex * b, bool b_in_col_major, 00709 blas_single_complex * c, const bool c_in_col_major) 00710 { 00711 gemm_wrapper<blas_single_complex>(BlockSideLength, BlockSideLength, BlockSideLength, 00712 1.0, a_in_col_major, a, 00713 b_in_col_major, b, 00714 1.0, c_in_col_major, c); 00715 } 00716 }; 00717 00718 //! 
\brief multiplies matrices A and B, adds result to C, for complex<double> entries 00719 template <unsigned BlockSideLength> 00720 struct low_level_matrix_multiply_and_add<blas_double_complex, BlockSideLength> 00721 { 00722 low_level_matrix_multiply_and_add(const blas_double_complex * a, bool a_in_col_major, 00723 const blas_double_complex * b, bool b_in_col_major, 00724 blas_double_complex * c, const bool c_in_col_major) 00725 { 00726 gemm_wrapper<blas_double_complex>(BlockSideLength, BlockSideLength, BlockSideLength, 00727 1.0, a_in_col_major, a, 00728 b_in_col_major, b, 00729 1.0, c_in_col_major, c); 00730 } 00731 }; 00732 #endif 00733 00734 __STXXL_END_NAMESPACE 00735 00736 #endif /* STXXL_MATRIX_LOW_LEVEL_HEADER */