Stxxl  1.4.0
include/stxxl/bits/containers/matrix_low_level.h
Go to the documentation of this file.
00001 /***************************************************************************
00002  *  include/stxxl/bits/containers/matrix_low_level.h
00003  *
00004  *  Part of the STXXL. See http://stxxl.sourceforge.net
00005  *
00006  *  Copyright (C) 2010-2011 Raoul Steffen <R-Steffen@gmx.de>
00007  *
00008  *  Distributed under the Boost Software License, Version 1.0.
00009  *  (See accompanying file LICENSE_1_0.txt or copy at
00010  *  http://www.boost.org/LICENSE_1_0.txt)
00011  **************************************************************************/
00012 
00013 #ifndef STXXL_MATRIX_LOW_LEVEL_HEADER
00014 #define STXXL_MATRIX_LOW_LEVEL_HEADER
00015 
00016 #ifndef STXXL_BLAS
00017 #define STXXL_BLAS 0
00018 #endif
00019 
00020 #include <complex>
00021 
00022 #include <stxxl/bits/common/types.h>
00023 #include <stxxl/bits/parallel.h>
00024 
00025 __STXXL_BEGIN_NAMESPACE
00026 
// Forward declaration only; the definition is not part of this header section.
// Its nested functor types (::addition, ::subtraction) select the BLAS
// specializations further down.
template <typename ValueType, unsigned BlockSideLength>
struct matrix_operations;

//! \brief Linear offset of element (row, col) inside a dense
//! BlockSideLength x BlockSideLength block.
//! Only two specializations exist: column_major == false gives the row-major
//! offset, column_major == true gives the column-major offset.
template <unsigned BlockSideLength, bool column_major>
struct switch_major_index;
00034 
00035 //row-major specialization
00036 template <unsigned BlockSideLength>
00037 struct switch_major_index<BlockSideLength, false>
00038 {
00039     inline switch_major_index(const int_type row, const int_type col) : i(row * BlockSideLength + col) {}
00040     inline operator int_type & () { return i; }
00041 private:
00042     int_type i;
00043 };
00044 
00045 //column-major specialization
00046 template <unsigned BlockSideLength>
00047 struct switch_major_index<BlockSideLength, true>
00048 {
00049     inline switch_major_index(const int_type row, const int_type col) : i(row + col * BlockSideLength) {}
00050     inline operator int_type & () { return i; }
00051 private:
00052     int_type i;
00053 };
00054 
00055 //! \brief c = a <op> b; for arbitrary entries
00056 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, bool b_transposed, class Op>
00057 struct low_level_matrix_binary_ass_op
00058 {
00059     low_level_matrix_binary_ass_op(ValueType * c, const ValueType * a, const ValueType * b, Op op = Op())
00060     {
00061         if (a)
00062             if (b)
00063                 #if STXXL_PARALLEL
00064                 #pragma omp parallel for
00065                 #endif
00066                 for (int_type row = 0; row < int_type(BlockSideLength); ++row)
00067                     for (int_type col = 0; col < int_type(BlockSideLength); ++col)
00068                         op(c[switch_major_index<BlockSideLength, false>(row, col)],
00069                                 a[switch_major_index<BlockSideLength, a_transposed>(row, col)],
00070                                 b[switch_major_index<BlockSideLength, b_transposed>(row, col)]);
00071             else
00072                 #if STXXL_PARALLEL
00073                 #pragma omp parallel for
00074                 #endif
00075                 for (int_type row = 0; row < int_type(BlockSideLength); ++row)
00076                     for (int_type col = 0; col < int_type(BlockSideLength); ++col)
00077                         op(c[switch_major_index<BlockSideLength, false>(row, col)],
00078                                 a[switch_major_index<BlockSideLength, a_transposed>(row, col)], 0);
00079         else
00080         {
00081             assert(b /* do not add nothing to nothing */);
00082             #if STXXL_PARALLEL
00083             #pragma omp parallel for
00084             #endif
00085             for (int_type row = 0; row < int_type(BlockSideLength); ++row)
00086                 for (int_type col = 0; col < int_type(BlockSideLength); ++col)
00087                     op(c[switch_major_index<BlockSideLength, false>(row, col)],
00088                                 0, b[switch_major_index<BlockSideLength, b_transposed>(row, col)]);
00089         }
00090     }
00091 };
00092 
00093 //! \brief c <op>= a; for arbitrary entries
00094 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, class Op>
00095 struct low_level_matrix_unary_ass_op
00096 {
00097     low_level_matrix_unary_ass_op(ValueType * c, const ValueType * a, Op op = Op())
00098     {
00099         if (a)
00100             #if STXXL_PARALLEL
00101             #pragma omp parallel for
00102             #endif
00103             for (int_type row = 0; row < int_type(BlockSideLength); ++row)
00104                 for (int_type col = 0; col < int_type(BlockSideLength); ++col)
00105                     op(c[switch_major_index<BlockSideLength, false>(row, col)],
00106                             a[switch_major_index<BlockSideLength, a_transposed>(row, col)]);
00107     }
00108 };
00109 
00110 //! \brief c = <op>a; for arbitrary entries
00111 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, class Op>
00112 struct low_level_matrix_unary_op
00113 {
00114     low_level_matrix_unary_op(ValueType * c, const ValueType * a, Op op = Op())
00115     {
00116         assert(a);
00117         #if STXXL_PARALLEL
00118         #pragma omp parallel for
00119         #endif
00120         for (int_type row = 0; row < int_type(BlockSideLength); ++row)
00121             for (int_type col = 0; col < int_type(BlockSideLength); ++col)
00122                 c[switch_major_index<BlockSideLength, false>(row, col)] =
00123                         op(a[switch_major_index<BlockSideLength, a_transposed>(row, col)]);
00124     }
00125 };
00126 
00127 //! \brief multiplies matrices A and B, adds result to C, for arbitrary entries
00128 //! param pointer to blocks of A,B,C; elements in blocks have to be in row-major
00129 /* designated usage as:
00130  * void
00131  * low_level_matrix_multiply_and_add(const double * a, bool a_in_col_major,
00132                                      const double * b, bool b_in_col_major,
00133                                      double * c, const bool c_in_col_major)  */
00134 template <typename ValueType, unsigned BlockSideLength>
00135 struct low_level_matrix_multiply_and_add
00136 {
00137     low_level_matrix_multiply_and_add(const ValueType * a, bool a_in_col_major,
00138                                       const ValueType * b, bool b_in_col_major,
00139                                       ValueType * c, const bool c_in_col_major)
00140     {
00141         if (c_in_col_major)
00142         {
00143             std::swap(a,b);
00144             bool a_cm = ! b_in_col_major;
00145             b_in_col_major = ! a_in_col_major;
00146             a_in_col_major = a_cm;
00147         }
00148         if (! a_in_col_major)
00149         {
00150             if (! b_in_col_major)
00151             {   // => both row-major
00152                 #if STXXL_PARALLEL
00153                 #pragma omp parallel for
00154                 #endif
00155                 for (int_type i = 0; i < int_type(BlockSideLength); ++i)    //OpenMP does not like unsigned iteration variables
00156                   for (unsigned_type k = 0; k < BlockSideLength; ++k)
00157                       for (unsigned_type j = 0; j < BlockSideLength; ++j)
00158                           c[i * BlockSideLength + j] += a[i * BlockSideLength + k] * b[k * BlockSideLength + j];
00159             }
00160             else
00161             {   // => a row-major, b col-major
00162                 #if STXXL_PARALLEL
00163                 #pragma omp parallel for
00164                 #endif
00165                 for (int_type i = 0; i < int_type(BlockSideLength); ++i)    //OpenMP does not like unsigned iteration variables
00166                     for (unsigned_type j = 0; j < BlockSideLength; ++j)
00167                         for (unsigned_type k = 0; k < BlockSideLength; ++k)
00168                           c[i * BlockSideLength + j] += a[i * BlockSideLength + k] * b[k + j * BlockSideLength];
00169             }
00170         }
00171         else
00172         {
00173             if (! b_in_col_major)
00174             {   // => a col-major, b row-major
00175                 #if STXXL_PARALLEL
00176                 #pragma omp parallel for
00177                 #endif
00178                 for (int_type i = 0; i < int_type(BlockSideLength); ++i)    //OpenMP does not like unsigned iteration variables
00179                   for (unsigned_type k = 0; k < BlockSideLength; ++k)
00180                       for (unsigned_type j = 0; j < BlockSideLength; ++j)
00181                           c[i * BlockSideLength + j] += a[i + k * BlockSideLength] * b[k * BlockSideLength + j];
00182             }
00183             else
00184             {   // => both col-major
00185                 #if STXXL_PARALLEL
00186                 #pragma omp parallel for
00187                 #endif
00188                 for (int_type i = 0; i < int_type(BlockSideLength); ++i)    //OpenMP does not like unsigned iteration variables
00189                   for (unsigned_type k = 0; k < BlockSideLength; ++k)
00190                       for (unsigned_type j = 0; j < BlockSideLength; ++j)
00191                           c[i * BlockSideLength + j] += a[i + k * BlockSideLength] * b[k + j * BlockSideLength];
00192             }
00193         }
00194     }
00195 };
00196 
00197 #if STXXL_BLAS
00198 typedef int_type blas_int;
00199 typedef std::complex<double> blas_double_complex;
00200 typedef std::complex<float> blas_single_complex;
00201 
00202 // --- vector add (used as matrix-add) -----------------
00203 
00204 extern "C" void daxpy_(const blas_int *n, const double              *alpha, const double              *x, const blas_int *incx, double              *y, const blas_int *incy);
00205 extern "C" void saxpy_(const blas_int *n, const float               *alpha, const float               *x, const blas_int *incx, float               *y, const blas_int *incy);
00206 extern "C" void zaxpy_(const blas_int *n, const blas_double_complex *alpha, const blas_double_complex *x, const blas_int *incx, blas_double_complex *y, const blas_int *incy);
00207 extern "C" void caxpy_(const blas_int *n, const blas_single_complex *alpha, const blas_single_complex *x, const blas_int *incx, blas_single_complex *y, const blas_int *incy);
00208 extern "C" void dcopy_(const blas_int *n, const double              *x, const blas_int *incx, double              *y, const blas_int *incy);
00209 extern "C" void scopy_(const blas_int *n, const float               *x, const blas_int *incx, float               *y, const blas_int *incy);
00210 extern "C" void zcopy_(const blas_int *n, const blas_double_complex *x, const blas_int *incx, blas_double_complex *y, const blas_int *incy);
00211 extern "C" void ccopy_(const blas_int *n, const blas_single_complex *x, const blas_int *incx, blas_single_complex *y, const blas_int *incy);
00212 
00213 //! \brief c = a + b; for double entries
00214 template <unsigned BlockSideLength>
00215 struct low_level_matrix_binary_ass_op<double, BlockSideLength, false, false, typename matrix_operations<double, BlockSideLength>::addition>
00216 {
00217     low_level_matrix_binary_ass_op(double * c, const double * a, const double * b, typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition())
00218     {
00219         if (a)
00220             if (b)
00221             {
00222                 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00223                         (c, a);
00224                 low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00225                         (c, b);
00226             }
00227             else
00228                 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00229                         (c, a);
00230         else
00231         {
00232             assert(b /* do not add nothing to nothing */);
00233             low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00234                     (c, b);
00235         }
00236     }
00237 };
00238 //! \brief c = a - b; for double entries
00239 template <unsigned BlockSideLength>
00240 struct low_level_matrix_binary_ass_op<double, BlockSideLength, false, false, typename matrix_operations<double, BlockSideLength>::subtraction>
00241 {
00242     low_level_matrix_binary_ass_op(double * c, const double * a, const double * b,
00243             typename matrix_operations<double, BlockSideLength>::subtraction = typename matrix_operations<double, BlockSideLength>::subtraction())
00244     {
00245         if (a)
00246             if (b)
00247             {
00248                 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00249                         (c, a);
00250                 low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction>
00251                         (c, b);
00252             }
00253             else
00254                 low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00255                         (c, a);
00256         else
00257         {
00258             assert(b /* do not add nothing to nothing */);
00259             low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction>
00260                     (c, b);
00261         }
00262     }
00263 };
00264 //! \brief c += a; for double entries
00265 template <unsigned BlockSideLength>
00266 struct low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00267 {
00268     low_level_matrix_unary_ass_op(double * c, const double * a,
00269             typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition())
00270     {
00271         const blas_int size = BlockSideLength * BlockSideLength;
00272         const blas_int int_one = 1;
00273         const double one = 1.0;
00274         if (a)
00275             daxpy_(&size, &one, a, &int_one, c, &int_one);
00276     }
00277 };
00278 //! \brief c -= a; for double entries
00279 template <unsigned BlockSideLength>
00280 struct low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction>
00281 {
00282     low_level_matrix_unary_ass_op(double * c, const double * a,
00283             typename matrix_operations<double, BlockSideLength>::subtraction = typename matrix_operations<double, BlockSideLength>::subtraction())
00284     {
00285         const blas_int size = BlockSideLength * BlockSideLength;
00286         const blas_int int_one = 1;
00287         const double minusone = -1.0;
00288         if (a)
00289             daxpy_(&size, &minusone, a, &int_one, c, &int_one);
00290     }
00291 };
00292 //! \brief c = a; for double entries
00293 template <unsigned BlockSideLength>
00294 struct low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
00295 {
00296     low_level_matrix_unary_op(double * c, const double * a,
00297             typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition())
00298     {
00299         const blas_int size = BlockSideLength * BlockSideLength;
00300         const blas_int int_one = 1;
00301         dcopy_(&size, a, &int_one, c, &int_one);
00302     }
00303 };
00304 
00305 //! \brief c = a + b; for float entries
00306 template <unsigned BlockSideLength>
00307 struct low_level_matrix_binary_ass_op<float, BlockSideLength, false, false, typename matrix_operations<float, BlockSideLength>::addition>
00308 {
00309     low_level_matrix_binary_ass_op(float * c, const float * a, const float * b, typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition())
00310     {
00311         if (a)
00312             if (b)
00313             {
00314                 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00315                         (c, a);
00316                 low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00317                         (c, b);
00318             }
00319             else
00320                 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00321                         (c, a);
00322         else
00323         {
00324             assert(b /* do not add nothing to nothing */);
00325             low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00326                     (c, b);
00327         }
00328     }
00329 };
00330 //! \brief c = a - b; for float entries
00331 template <unsigned BlockSideLength>
00332 struct low_level_matrix_binary_ass_op<float, BlockSideLength, false, false, typename matrix_operations<float, BlockSideLength>::subtraction>
00333 {
00334     low_level_matrix_binary_ass_op(float * c, const float * a, const float * b,
00335             typename matrix_operations<float, BlockSideLength>::subtraction = typename matrix_operations<float, BlockSideLength>::subtraction())
00336     {
00337         if (a)
00338             if (b)
00339             {
00340                 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00341                         (c, a);
00342                 low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction>
00343                         (c, b);
00344             }
00345             else
00346                 low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00347                         (c, a);
00348         else
00349         {
00350             assert(b /* do not add nothing to nothing */);
00351             low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction>
00352                     (c, b);
00353         }
00354     }
00355 };
00356 //! \brief c += a; for float entries
00357 template <unsigned BlockSideLength>
00358 struct low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00359 {
00360     low_level_matrix_unary_ass_op(float * c, const float * a,
00361             typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition())
00362     {
00363         const blas_int size = BlockSideLength * BlockSideLength;
00364         const blas_int int_one = 1;
00365         const float one = 1.0;
00366         if (a)
00367             saxpy_(&size, &one, a, &int_one, c, &int_one);
00368     }
00369 };
00370 //! \brief c -= a; for float entries
00371 template <unsigned BlockSideLength>
00372 struct low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction>
00373 {
00374     low_level_matrix_unary_ass_op(float * c, const float * a,
00375             typename matrix_operations<float, BlockSideLength>::subtraction = typename matrix_operations<float, BlockSideLength>::subtraction())
00376     {
00377         const blas_int size = BlockSideLength * BlockSideLength;
00378         const blas_int int_one = 1;
00379         const float minusone = -1.0;
00380         if (a)
00381             saxpy_(&size, &minusone, a, &int_one, c, &int_one);
00382     }
00383 };
00384 //! \brief c = a; for float entries
00385 template <unsigned BlockSideLength>
00386 struct low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
00387 {
00388     low_level_matrix_unary_op(float * c, const float * a,
00389             typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition())
00390     {
00391         const blas_int size = BlockSideLength * BlockSideLength;
00392         const blas_int int_one = 1;
00393         scopy_(&size, a, &int_one, c, &int_one);
00394     }
00395 };
00396 
00397 //! \brief c = a + b; for blas_double_complex entries
00398 template <unsigned BlockSideLength>
00399 struct low_level_matrix_binary_ass_op<blas_double_complex, BlockSideLength, false, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00400 {
00401     low_level_matrix_binary_ass_op(blas_double_complex * c, const blas_double_complex * a, const blas_double_complex * b, typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition())
00402     {
00403         if (a)
00404             if (b)
00405             {
00406                 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00407                         (c, a);
00408                 low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00409                         (c, b);
00410             }
00411             else
00412                 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00413                         (c, a);
00414         else
00415         {
00416             assert(b /* do not add nothing to nothing */);
00417             low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00418                     (c, b);
00419         }
00420     }
00421 };
00422 //! \brief c = a - b; for blas_double_complex entries
00423 template <unsigned BlockSideLength>
00424 struct low_level_matrix_binary_ass_op<blas_double_complex, BlockSideLength, false, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
00425 {
00426     low_level_matrix_binary_ass_op(blas_double_complex * c, const blas_double_complex * a, const blas_double_complex * b,
00427             typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction())
00428     {
00429         if (a)
00430             if (b)
00431             {
00432                 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00433                         (c, a);
00434                 low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
00435                         (c, b);
00436             }
00437             else
00438                 low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00439                         (c, a);
00440         else
00441         {
00442             assert(b /* do not add nothing to nothing */);
00443             low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
00444                     (c, b);
00445         }
00446     }
00447 };
00448 //! \brief c += a; for blas_double_complex entries
00449 template <unsigned BlockSideLength>
00450 struct low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00451 {
00452     low_level_matrix_unary_ass_op(blas_double_complex * c, const blas_double_complex * a,
00453             typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition())
00454     {
00455         const blas_int size = BlockSideLength * BlockSideLength;
00456         const blas_int int_one = 1;
00457         const blas_double_complex one = 1.0;
00458         if (a)
00459             zaxpy_(&size, &one, a, &int_one, c, &int_one);
00460     }
00461 };
00462 //! \brief c -= a; for blas_double_complex entries
00463 template <unsigned BlockSideLength>
00464 struct low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
00465 {
00466     low_level_matrix_unary_ass_op(blas_double_complex * c, const blas_double_complex * a,
00467             typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction())
00468     {
00469         const blas_int size = BlockSideLength * BlockSideLength;
00470         const blas_int int_one = 1;
00471         const blas_double_complex minusone = -1.0;
00472         if (a)
00473             zaxpy_(&size, &minusone, a, &int_one, c, &int_one);
00474     }
00475 };
00476 //! \brief c = a; for blas_double_complex entries
00477 template <unsigned BlockSideLength>
00478 struct low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
00479 {
00480     low_level_matrix_unary_op(blas_double_complex * c, const blas_double_complex * a,
00481             typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition())
00482     {
00483         const blas_int size = BlockSideLength * BlockSideLength;
00484         const blas_int int_one = 1;
00485         zcopy_(&size, a, &int_one, c, &int_one);
00486     }
00487 };
00488 
00489 //! \brief c = a + b; for blas_single_complex entries
00490 template <unsigned BlockSideLength>
00491 struct low_level_matrix_binary_ass_op<blas_single_complex, BlockSideLength, false, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00492 {
00493     low_level_matrix_binary_ass_op(blas_single_complex * c, const blas_single_complex * a, const blas_single_complex * b, typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition())
00494     {
00495         if (a)
00496             if (b)
00497             {
00498                 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00499                         (c, a);
00500                 low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00501                         (c, b);
00502             }
00503             else
00504                 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00505                         (c, a);
00506         else
00507         {
00508             assert(b /* do not add nothing to nothing */);
00509             low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00510                     (c, b);
00511         }
00512     }
00513 };
00514 //! \brief c = a - b; for blas_single_complex entries
00515 template <unsigned BlockSideLength>
00516 struct low_level_matrix_binary_ass_op<blas_single_complex, BlockSideLength, false, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
00517 {
00518     low_level_matrix_binary_ass_op(blas_single_complex * c, const blas_single_complex * a, const blas_single_complex * b,
00519             typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction())
00520     {
00521         if (a)
00522             if (b)
00523             {
00524                 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00525                         (c, a);
00526                 low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
00527                         (c, b);
00528             }
00529             else
00530                 low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00531                         (c, a);
00532         else
00533         {
00534             assert(b /* do not add nothing to nothing */);
00535             low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
00536                     (c, b);
00537         }
00538     }
00539 };
00540 //! \brief c += a; for blas_single_complex entries
00541 template <unsigned BlockSideLength>
00542 struct low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00543 {
00544     low_level_matrix_unary_ass_op(blas_single_complex * c, const blas_single_complex * a,
00545             typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition())
00546     {
00547         const blas_int size = BlockSideLength * BlockSideLength;
00548         const blas_int int_one = 1;
00549         const blas_single_complex one = 1.0;
00550         if (a)
00551             caxpy_(&size, &one, a, &int_one, c, &int_one);
00552     }
00553 };
00554 //! \brief c -= a; for blas_single_complex entries
00555 template <unsigned BlockSideLength>
00556 struct low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
00557 {
00558     low_level_matrix_unary_ass_op(blas_single_complex * c, const blas_single_complex * a,
00559             typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction())
00560     {
00561         const blas_int size = BlockSideLength * BlockSideLength;
00562         const blas_int int_one = 1;
00563         const blas_single_complex minusone = -1.0;
00564         if (a)
00565             caxpy_(&size, &minusone, a, &int_one, c, &int_one);
00566     }
00567 };
00568 //! \brief c = a; for blas_single_complex entries
00569 template <unsigned BlockSideLength>
00570 struct low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
00571 {
00572     low_level_matrix_unary_op(blas_single_complex * c, const blas_single_complex * a,
00573             typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition())
00574     {
00575         const blas_int size = BlockSideLength * BlockSideLength;
00576         const blas_int int_one = 1;
00577         ccopy_(&size, a, &int_one, c, &int_one);
00578     }
00579 };
00580 
00581 // --- matrix-matrix multiplication ---------------
00582 
00583 extern "C" void dgemm_(const char *transa, const char *transb,
00584         const blas_int *m, const blas_int *n, const blas_int *k,
00585         const double *alpha, const double *a, const blas_int *lda,
00586         const double *b, const blas_int *ldb,
00587         const double *beta, double *c, const blas_int *ldc);
00588 
00589 extern "C" void sgemm_(const char *transa, const char *transb,
00590         const blas_int *m, const blas_int *n, const blas_int *k,
00591         const float *alpha, const float *a, const blas_int *lda,
00592         const float *b, const blas_int *ldb,
00593         const float *beta, float *c, const blas_int *ldc);
00594 
00595 extern "C" void zgemm_(const char *transa, const char *transb,
00596         const blas_int *m, const blas_int *n, const blas_int *k,
00597         const blas_double_complex *alpha, const blas_double_complex *a, const blas_int *lda,
00598         const blas_double_complex *b, const blas_int *ldb,
00599         const blas_double_complex *beta, blas_double_complex *c, const blas_int *ldc);
00600 
00601 extern "C" void cgemm_(const char *transa, const char *transb,
00602         const blas_int *m, const blas_int *n, const blas_int *k,
00603         const blas_single_complex *alpha, const blas_single_complex *a, const blas_int *lda,
00604         const blas_single_complex *b, const blas_int *ldb,
00605         const blas_single_complex *beta, blas_single_complex *c, const blas_int *ldc);
00606 
00607 template <typename ValueType>
00608 void gemm_(const char *transa, const char *transb,
00609         const blas_int *m, const blas_int *n, const blas_int *k,
00610         const ValueType *alpha, const ValueType *a, const blas_int *lda,
00611         const ValueType *b, const blas_int *ldb,
00612         const ValueType *beta, ValueType *c, const blas_int *ldc);
00613 
00614 //! \brief calculates c = alpha * a * b + beta * c
00615 //! \tparam ValueType type of elements
00616 //! \param n height of a and c
00617 //! \param l width of a and height of b
00618 //! \param m width of b and c
00619 //! \param a_in_col_major if a is stored in column-major rather than row-major
00620 //! \param b_in_col_major if b is stored in column-major rather than row-major
00621 //! \param c_in_col_major if c is stored in column-major rather than row-major
00622 template <typename ValueType>
00623 void gemm_wrapper(const blas_int n, const blas_int l, const blas_int m,
00624         const ValueType alpha, const bool a_in_col_major, const ValueType *a,
00625                                const bool b_in_col_major, const ValueType *b,
00626         const ValueType beta,  const bool c_in_col_major,       ValueType *c)
00627 {
00628     const blas_int& stride_in_a = a_in_col_major ? n : l;
00629     const blas_int& stride_in_b = b_in_col_major ? l : m;
00630     const blas_int& stride_in_c = c_in_col_major ? n : m;
00631     const char transa = a_in_col_major xor c_in_col_major ? 'T' : 'N';
00632     const char transb = b_in_col_major xor c_in_col_major ? 'T' : 'N';
00633     if (c_in_col_major)
00634         // blas expects matrices in column-major unless specified via transa rsp. transb
00635         gemm_(&transa, &transb, &n, &m, &l, &alpha, a, &stride_in_a, b, &stride_in_b, &beta, c, &stride_in_c);
00636     else
00637         // blas expects matrices in column-major, so we calculate c^T = alpha * b^T * a^T + beta * c^T
00638         gemm_(&transb, &transa, &m, &n, &l, &alpha, b, &stride_in_b, a, &stride_in_a, &beta, c, &stride_in_c);
00639 }
00640 
00641 template <>
00642 void gemm_(const char *transa, const char *transb,
00643         const blas_int *m, const blas_int *n, const blas_int *k,
00644         const double *alpha, const double *a, const blas_int *lda,
00645         const double *b, const blas_int *ldb,
00646         const double *beta, double *c, const blas_int *ldc)
00647 { dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); }
00648 
00649 template <>
00650 void gemm_(const char *transa, const char *transb,
00651         const blas_int *m, const blas_int *n, const blas_int *k,
00652         const float *alpha, const float *a, const blas_int *lda,
00653         const float *b, const blas_int *ldb,
00654         const float *beta, float *c, const blas_int *ldc)
00655 { sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); }
00656 
00657 template <>
00658 void gemm_(const char *transa, const char *transb,
00659         const blas_int *m, const blas_int *n, const blas_int *k,
00660         const blas_double_complex *alpha, const blas_double_complex *a, const blas_int *lda,
00661         const blas_double_complex *b, const blas_int *ldb,
00662         const blas_double_complex *beta, blas_double_complex *c, const blas_int *ldc)
00663 { zgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); }
00664 
00665 template <>
00666 void gemm_(const char *transa, const char *transb,
00667         const blas_int *m, const blas_int *n, const blas_int *k,
00668         const blas_single_complex *alpha, const blas_single_complex *a, const blas_int *lda,
00669         const blas_single_complex *b, const blas_int *ldb,
00670         const blas_single_complex *beta, blas_single_complex *c, const blas_int *ldc)
00671 { cgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); }
00672 
00673 //! \brief multiplies matrices A and B, adds result to C, for double entries
00674 template <unsigned BlockSideLength>
00675 struct low_level_matrix_multiply_and_add<double, BlockSideLength>
00676 {
00677     low_level_matrix_multiply_and_add(const double * a, bool a_in_col_major,
00678                                       const double * b, bool b_in_col_major,
00679                                       double * c, const bool c_in_col_major)
00680     {
00681         gemm_wrapper<double>(BlockSideLength, BlockSideLength, BlockSideLength,
00682                 1.0, a_in_col_major, a,
00683                      b_in_col_major, b,
00684                 1.0, c_in_col_major, c);
00685     }
00686 };
00687 
00688 //! \brief multiplies matrices A and B, adds result to C, for float entries
00689 template <unsigned BlockSideLength>
00690 struct low_level_matrix_multiply_and_add<float, BlockSideLength>
00691 {
00692     low_level_matrix_multiply_and_add(const float * a, bool a_in_col_major,
00693                                       const float * b, bool b_in_col_major,
00694                                       float * c, const bool c_in_col_major)
00695     {
00696         gemm_wrapper<float>(BlockSideLength, BlockSideLength, BlockSideLength,
00697                 1.0, a_in_col_major, a,
00698                      b_in_col_major, b,
00699                 1.0, c_in_col_major, c);
00700     }
00701 };
00702 
00703 //! \brief multiplies matrices A and B, adds result to C, for complex<float> entries
00704 template <unsigned BlockSideLength>
00705 struct low_level_matrix_multiply_and_add<blas_single_complex, BlockSideLength>
00706 {
00707     low_level_matrix_multiply_and_add(const blas_single_complex * a, bool a_in_col_major,
00708                                       const blas_single_complex * b, bool b_in_col_major,
00709                                       blas_single_complex * c, const bool c_in_col_major)
00710     {
00711         gemm_wrapper<blas_single_complex>(BlockSideLength, BlockSideLength, BlockSideLength,
00712                 1.0, a_in_col_major, a,
00713                      b_in_col_major, b,
00714                 1.0, c_in_col_major, c);
00715     }
00716 };
00717 
00718 //! \brief multiplies matrices A and B, adds result to C, for complex<double> entries
00719 template <unsigned BlockSideLength>
00720 struct low_level_matrix_multiply_and_add<blas_double_complex, BlockSideLength>
00721 {
00722     low_level_matrix_multiply_and_add(const blas_double_complex * a, bool a_in_col_major,
00723                                       const blas_double_complex * b, bool b_in_col_major,
00724                                       blas_double_complex * c, const bool c_in_col_major)
00725     {
00726         gemm_wrapper<blas_double_complex>(BlockSideLength, BlockSideLength, BlockSideLength,
00727                 1.0, a_in_col_major, a,
00728                      b_in_col_major, b,
00729                 1.0, c_in_col_major, c);
00730     }
00731 };
00732 #endif
00733 
00734 __STXXL_END_NAMESPACE
00735 
00736 #endif /* STXXL_MATRIX_LOW_LEVEL_HEADER */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines