http://stxxl.sourceforge.net
<R-Steffen@gmx.de>
http://www.boost.org/LICENSE_1_0.txt
#ifndef STXXL_MATRIX_HEADER
#define STXXL_MATRIX_HEADER
#include <vector>
#include <stxxl/bits/containers/vector.h>
#include <stxxl/bits/common/shared_object.h>
#include <stxxl/bits/mng/block_scheduler.h>
#include <stxxl/bits/containers/matrix_arithmetic.h>
__STXXL_BEGIN_NAMESPACE
//! Forward declaration of the external-memory matrix; row_vector (vector times
//! matrix product) and the iterator classes below refer to it before its definition.
template <typename ValueType, unsigned BlockSideLength>
class matrix;
//! Dense column vector backed by an stxxl::vector, offering element-wise
//! arithmetic. Binary operators return a newly allocated vector; compound
//! assignment operators modify *this in place.
template <typename ValueType>
class column_vector : public vector<ValueType>
{
public:
    typedef vector<ValueType> vector_type;
    typedef typename vector_type::size_type size_type;

    using vector_type::size;

    //! Creates a column vector with n elements.
    column_vector(size_type n = 0)
        : vector_type(n)
    { }

    //! Element-wise sum. Requires size() == right.size().
    column_vector operator + (const column_vector & right) const
    {
        assert(size() == right.size());
        const size_type len = size();
        column_vector sum(len);
        for (size_type k = 0; k != len; ++k)
            sum[k] = (*this)[k] + right[k];
        return sum;
    }

    //! Element-wise difference. Requires size() == right.size().
    column_vector operator - (const column_vector & right) const
    {
        assert(size() == right.size());
        const size_type len = size();
        column_vector diff(len);
        for (size_type k = 0; k != len; ++k)
            diff[k] = (*this)[k] - right[k];
        return diff;
    }

    //! Returns a copy with every element multiplied by scalar.
    column_vector operator * (const ValueType scalar) const
    {
        const size_type len = size();
        column_vector scaled(len);
        for (size_type k = 0; k != len; ++k)
            scaled[k] = (*this)[k] * scalar;
        return scaled;
    }

    //! Adds right element-wise to *this. Requires size() == right.size().
    column_vector & operator += (const column_vector & right)
    {
        assert(size() == right.size());
        const size_type len = size();
        for (size_type k = 0; k != len; ++k)
            (*this)[k] += right[k];
        return *this;
    }

    //! Subtracts right element-wise from *this. Requires size() == right.size().
    column_vector & operator -= (const column_vector & right)
    {
        assert(size() == right.size());
        const size_type len = size();
        for (size_type k = 0; k != len; ++k)
            (*this)[k] -= right[k];
        return *this;
    }

    //! Multiplies every element of *this by scalar.
    column_vector & operator *= (const ValueType scalar)
    {
        const size_type len = size();
        for (size_type k = 0; k != len; ++k)
            (*this)[k] *= scalar;
        return *this;
    }

    //! Sets every element to zero.
    void set_zero()
    {
        typename vector_type::iterator pos = vector_type::begin();
        while (pos != vector_type::end())
        {
            *pos = 0;
            ++pos;
        }
    }
};
//! Dense row vector backed by an stxxl::vector, offering element-wise
//! arithmetic, a vector-times-matrix product and a dot product with a column
//! vector. Binary operators return a newly allocated result; compound
//! assignment operators modify *this in place.
template <typename ValueType>
class row_vector : public vector<ValueType>
{
public:
    typedef vector<ValueType> vector_type;
    typedef typename vector_type::size_type size_type;

    using vector_type::size;

    //! Creates a row vector with n elements.
    row_vector(size_type n = 0)
        : vector_type(n) {}

    //! Element-wise sum. Requires size() == right.size().
    row_vector operator + (const row_vector & right) const
    {
        assert(size() == right.size());
        row_vector res(size());
        for (size_type i = 0; i < size(); ++i)
            res[i] = (*this)[i] + right[i];
        return res;
    }

    //! Element-wise difference. Requires size() == right.size().
    row_vector operator - (const row_vector & right) const
    {
        assert(size() == right.size());
        row_vector res(size());
        for (size_type i = 0; i < size(); ++i)
            res[i] = (*this)[i] - right[i];
        return res;
    }

    //! Returns a copy with every element multiplied by scalar.
    row_vector operator * (const ValueType scalar) const
    {
        row_vector res(size());
        for (size_type i = 0; i < size(); ++i)
            res[i] = (*this)[i] * scalar;
        return res;
    }

    //! Vector-times-matrix product (*this) * right; delegates to
    //! matrix::multiply_from_left.
    template <unsigned BlockSideLength>
    row_vector operator * (const matrix<ValueType, BlockSideLength> & right) const
    { return right.multiply_from_left(*this); }

    //! Dot product with a column vector. Requires size() == right.size();
    //! without this check a shorter right would be read past its end.
    ValueType operator * (const column_vector<ValueType> & right) const
    {
        assert(size() == right.size());
        ValueType res = 0;
        for (size_type i = 0; i < size(); ++i)
            res += (*this)[i] * right[i];
        return res;
    }

    //! Adds right element-wise to *this. Requires size() == right.size().
    row_vector & operator += (const row_vector & right)
    {
        assert(size() == right.size());
        for (size_type i = 0; i < size(); ++i)
            (*this)[i] += right[i];
        return *this;
    }

    //! Subtracts right element-wise from *this. Requires size() == right.size().
    row_vector & operator -= (const row_vector & right)
    {
        assert(size() == right.size());
        for (size_type i = 0; i < size(); ++i)
            (*this)[i] -= right[i];
        return *this;
    }

    //! Multiplies every element of *this by scalar.
    row_vector & operator *= (const ValueType scalar)
    {
        for (size_type i = 0; i < size(); ++i)
            (*this)[i] *= scalar;
        return *this;
    }

    //! Sets every element to zero.
    void set_zero()
    {
        for (typename vector_type::iterator it = vector_type::begin(); it != vector_type::end(); ++it)
            *it = 0;
    }
};
//! Swappable block holding one BlockSideLength x BlockSideLength tile of a
//! matrix as BlockSideLength * BlockSideLength contiguous values.
template <typename ValueType, unsigned BlockSideLength>
class matrix_swappable_block : public swappable_block<ValueType, BlockSideLength * BlockSideLength>
{
public:
    typedef typename swappable_block<ValueType, BlockSideLength * BlockSideLength>::internal_block_type internal_block_type;

    using swappable_block<ValueType, BlockSideLength * BlockSideLength>::get_internal_block;

    //! Sets every element of the tile to zero (rows processed in parallel
    //! when STXXL_PARALLEL is enabled).
    void fill_default()
    {
        internal_block_type & block_data = get_internal_block();
        const int_type side = int_type(BlockSideLength);
#if STXXL_PARALLEL
#pragma omp parallel for
#endif
        for (int_type r = 0; r < side; ++r)
        {
            const int_type row_start = r * side;
            for (int_type c = 0; c < side; ++c)
                block_data[row_start + c] = 0;
        }
    }
};
template <typename ValueType, unsigned BlockSideLength>
class swappable_block_matrix : public shared_object
{
public:
typedef int_type size_type;
typedef int_type elem_size_type;
typedef block_scheduler< matrix_swappable_block<ValueType, BlockSideLength> > block_scheduler_type;
typedef typename block_scheduler_type::swappable_block_identifier_type swappable_block_identifier_type;
typedef std::vector<swappable_block_identifier_type> blocks_type;
typedef matrix_operations<ValueType, BlockSideLength> Ops;
block_scheduler_type & bs;
private:
swappable_block_matrix & operator = (const swappable_block_matrix & other);
protected:
size_type height,
width,
height_from_supermatrix,
width_from_supermatrix;
blocks_type blocks;
bool elements_in_blocks_transposed;
swappable_block_identifier_type & bl(const size_type row, const size_type col)
{ return blocks[row * width + col]; }
public:
swappable_block_matrix(block_scheduler_type & bs, const size_type height_in_blocks, const size_type width_in_blocks, const bool transposed = false)
: bs(bs),
height(height_in_blocks),
width(width_in_blocks),
height_from_supermatrix(0),
width_from_supermatrix(0),
blocks(height * width),
elements_in_blocks_transposed(transposed)
{
for (size_type row = 0; row < height; ++row)
for (size_type col = 0; col < width; ++col)
bl(row, col) = bs.allocate_swappable_block();
}
swappable_block_matrix(const swappable_block_matrix & supermatrix,
const size_type height_in_blocks, const size_type width_in_blocks,
const size_type from_row_in_blocks, const size_type from_col_in_blocks)
: bs(supermatrix.bs),
height(height_in_blocks),
width(width_in_blocks),
height_from_supermatrix(std::min(supermatrix.height - from_row_in_blocks, height)),
width_from_supermatrix(std::min(supermatrix.width - from_col_in_blocks, width)),
blocks(height * width),
elements_in_blocks_transposed(supermatrix.elements_in_blocks_transposed)
{
for (size_type row = 0; row < height_from_supermatrix; ++row)
{
for (size_type col = 0; col < width_from_supermatrix; ++col)
bl(row, col) = supermatrix.block(row + from_row_in_blocks, col + from_col_in_blocks);
for (size_type col = width_from_supermatrix; col < width; ++col)
bl(row, col) = bs.allocate_swappable_block();
}
for (size_type row = height_from_supermatrix; row < height; ++row)
for (size_type col = 0; col < width; ++col)
bl(row, col) = bs.allocate_swappable_block();
}
swappable_block_matrix(const swappable_block_matrix & ul, const swappable_block_matrix & ur,
const swappable_block_matrix & dl, const swappable_block_matrix & dr)
: bs(ul.bs),
height(ul.height + dl.height),
width(ul.width + ur.width),
height_from_supermatrix(height),
width_from_supermatrix(width),
blocks(height * width),
elements_in_blocks_transposed(ul.elements_in_blocks_transposed)
{
for (size_type row = 0; row < ul.height; ++row)
{
for (size_type col = 0; col < ul.width; ++col)
bl(row, col) = ul.block(row, col);
for (size_type col = ul.width; col < width; ++col)
bl(row, col) = ur.block(row, col - ul.width);
}
for (size_type row = ul.height; row < height; ++row)
{
for (size_type col = 0; col < ul.width; ++col)
bl(row, col) = dl.block(row - ul.height, col);
for (size_type col = ul.width; col < width; ++col)
bl(row, col) = dr.block(row - ul.height, col - ul.width);
}
}
swappable_block_matrix(const swappable_block_matrix & other)
: shared_object(other),
bs(other.bs),
height(other.height),
width(other.width),
height_from_supermatrix(0),
width_from_supermatrix(0),
blocks(height * width),
elements_in_blocks_transposed(false)
{
for (size_type row = 0; row < height; ++row)
for (size_type col = 0; col < width; ++col)
bl(row, col) = bs.allocate_swappable_block();
Ops::element_op(*this, other, typename Ops::addition());
}
~swappable_block_matrix()
{
for (size_type row = 0; row < height_from_supermatrix; ++row)
{
for (size_type col = width_from_supermatrix; col < width; ++col)
bs.free_swappable_block(bl(row, col));
}
for (size_type row = height_from_supermatrix; row < height; ++row)
for (size_type col = 0; col < width; ++col)
bs.free_swappable_block(bl(row, col));
}
static size_type block_index_from_elem(elem_size_type index)
{ return index / BlockSideLength; }
static int_type elem_index_in_block_from_elem(elem_size_type index)
{ return index % BlockSideLength; }
int_type elem_index_in_block_from_elem(elem_size_type row, elem_size_type col) const
{
return (is_transposed())
? row % BlockSideLength + col % BlockSideLength * BlockSideLength
: row % BlockSideLength * BlockSideLength + col % BlockSideLength;
}
swappable_block_identifier_type block(const size_type row, const size_type col) const
{ return blocks[row * width + col]; }
swappable_block_identifier_type operator () (const size_type row, const size_type col) const
{ return block(row, col); }
const size_type & get_height() const
{ return height; }
const size_type & get_width() const
{ return width; }
const bool & is_transposed() const
{ return elements_in_blocks_transposed; }
void transpose()
{
blocks_type bl(blocks.size());
for (size_type row = 1; row < height; ++row)
for (size_type col = 0; col < row; ++col)
bl[col * height + row] = bl(row,col);
bl.swap(blocks);
std::swap(height, width);
std::swap(height_from_supermatrix, width_from_supermatrix);
elements_in_blocks_transposed = ! elements_in_blocks_transposed;
}
void set_zero()
{
for (typename blocks_type::iterator it = blocks.begin(); it != blocks.end(); ++it)
bs.deinitialize(*it);
}
};
//! Iterator addressing one element of a matrix by (row, col). It keeps at
//! most one internal block acquired from the block scheduler (the block of the
//! current element) and releases it when moving to another block or on
//! destruction. An empty iterator (row == col == -1) represents end().
template <typename ValueType, unsigned BlockSideLength>
class matrix_iterator
{
protected:
    typedef matrix<ValueType, BlockSideLength> matrix_type;
    typedef typename matrix_type::swappable_block_matrix_type swappable_block_matrix_type;
    typedef typename matrix_type::block_scheduler_type block_scheduler_type;
    typedef typename block_scheduler_type::internal_block_type internal_block_type;
    typedef typename matrix_type::elem_size_type elem_size_type;
    typedef typename matrix_type::block_size_type block_size_type;

    template <typename VT, unsigned BSL> friend class matrix;
    template <typename VT, unsigned BSL> friend class const_matrix_iterator;

    matrix_type * m;
    elem_size_type current_row,
                   current_col;
    block_size_type current_block_row,
                    current_block_col;
    //! Acquired internal block for the current block, or 0 if none is held.
    internal_block_type * current_iblock;

    //! Acquires the current block from the scheduler unless already held.
    void acquire_current_iblock()
    {
        if (! current_iblock)
            current_iblock = & m->data->bs.acquire(m->data->block(current_block_row, current_block_col));
    }

    //! Releases the held block, if any. The `true` flag presumably marks the
    //! block as modified (the const iterator passes false).
    void release_current_iblock()
    {
        if (current_iblock)
        {
            m->data->bs.release(m->data->block(current_block_row, current_block_col), true);
            current_iblock = 0;
        }
    }

    //! Iterator positioned at element (start_row, start_col) of matrix.
    matrix_iterator(matrix_type & matrix, const elem_size_type start_row, const elem_size_type start_col)
        : m(&matrix),
          current_row(start_row),
          current_col(start_col),
          current_block_row(m->data->block_index_from_elem(start_row)),
          current_block_col(m->data->block_index_from_elem(start_col)),
          current_iblock(0) {}

    //! Empty (end) iterator of matrix.
    matrix_iterator(matrix_type & matrix)
        : m(&matrix),
          current_row(-1),
          current_col(-1),
          current_block_row(-1),
          current_block_col(-1),
          current_iblock(0) {}

    //! Releases the held block and turns this into the empty iterator.
    void set_empty()
    {
        release_current_iblock();
        current_row = -1;
        current_col = -1;
        current_block_row = -1;
        current_block_col = -1;
    }

public:
    matrix_iterator(const matrix_iterator & other)
        : m(other.m),
          current_row(other.current_row),
          current_col(other.current_col),
          current_block_row(other.current_block_row),
          current_block_col(other.current_block_col),
          current_iblock(0)
    {
        if (other.current_iblock)
            acquire_current_iblock();
    }

    //! Copies other's matrix and position. A held block is kept only when it
    //! is the same block of the same matrix; otherwise it is released before
    //! the matrix pointer changes. Block indices are copied directly so an
    //! empty source stays empty (block index -1).
    matrix_iterator & operator = (const matrix_iterator & other)
    {
        if (m != other.m
            || current_block_row != other.current_block_row
            || current_block_col != other.current_block_col)
        {
            release_current_iblock(); // must use the old m
            m = other.m;
            current_block_row = other.current_block_row;
            current_block_col = other.current_block_col;
        }
        current_row = other.current_row;
        current_col = other.current_col;
        if (other.current_iblock)
            acquire_current_iblock();
        return *this;
    }

    ~matrix_iterator()
    { release_current_iblock(); }

    //! Moves to row new_row; releases the held block if the block row changes.
    void set_row(const elem_size_type new_row)
    {
        const block_size_type new_block_row = m->data->block_index_from_elem(new_row);
        if (new_block_row != current_block_row)
        {
            release_current_iblock();
            current_block_row = new_block_row;
        }
        current_row = new_row;
    }

    //! Moves to column new_col; releases the held block if the block column changes.
    void set_col(const elem_size_type new_col)
    {
        const block_size_type new_block_col = m->data->block_index_from_elem(new_col);
        if (new_block_col != current_block_col)
        {
            release_current_iblock();
            current_block_col = new_block_col;
        }
        current_col = new_col;
    }

    //! Moves to (new_row, new_col); releases the held block if the block changes.
    void set_pos(const elem_size_type new_row, const elem_size_type new_col)
    {
        const block_size_type new_block_row = m->data->block_index_from_elem(new_row),
                              new_block_col = m->data->block_index_from_elem(new_col);
        if (new_block_col != current_block_col || new_block_row != current_block_row)
        {
            release_current_iblock();
            current_block_row = new_block_row;
            current_block_col = new_block_col;
        }
        current_row = new_row;
        current_col = new_col;
    }

    //! Moves to (new_pos.first, new_pos.second).
    void set_pos(const std::pair<elem_size_type, elem_size_type> new_pos)
    { set_pos(new_pos.first, new_pos.second); }

    const elem_size_type & get_row() const
    { return current_row; }

    const elem_size_type & get_col() const
    { return current_col; }

    //! Current position as (row, col).
    std::pair<elem_size_type, elem_size_type> get_pos() const
    { return std::make_pair(current_row, current_col); }

    //! True for the empty (end) iterator.
    bool empty() const
    { return current_row == -1 && current_col == -1; }

    operator bool () const
    { return ! empty(); }

    //! Equal when both refer to the same position of the same matrix.
    bool operator == (const matrix_iterator & other) const
    {
        return current_row == other.current_row && current_col == other.current_col && m == other.m;
    }

    //! Negation of operator==; without it `a != b` would compare the implicit
    //! bool conversions (emptiness) instead of positions.
    bool operator != (const matrix_iterator & other) const
    { return ! (*this == other); }

    //! Reference to the current element. Acquires the current block if needed;
    //! the reference stays valid only until the block is released (moving to
    //! another block, reassignment or destruction).
    ValueType & operator * ()
    {
        acquire_current_iblock();
        return (*current_iblock)[m->data->elem_index_in_block_from_elem(current_row, current_col)];
    }
};
//! Matrix iterator whose ++/-- walk the elements in row-major order; stepping
//! past either end yields the empty (end) iterator.
template <typename ValueType, unsigned BlockSideLength>
class matrix_row_major_iterator : public matrix_iterator<ValueType, BlockSideLength>
{
protected:
    typedef matrix_iterator<ValueType, BlockSideLength> matrix_iterator_type;
    typedef typename matrix_iterator_type::matrix_type matrix_type;
    typedef typename matrix_iterator_type::elem_size_type elem_size_type;

    template <typename VT, unsigned BSL> friend class matrix;

    using matrix_iterator_type::m;
    using matrix_iterator_type::set_empty;

    matrix_row_major_iterator(matrix_type & matrix, const elem_size_type start_row, const elem_size_type start_col)
        : matrix_iterator_type(matrix, start_row, start_col) {}

    matrix_row_major_iterator(matrix_type & matrix)
        : matrix_iterator_type(matrix) {}

public:
    //! Converts any matrix iterator to row-major traversal.
    matrix_row_major_iterator(const matrix_iterator_type & matrix_iterator)
        : matrix_iterator_type(matrix_iterator) {}

    //! Next column, or first column of the next row, or empty after the last element.
    matrix_row_major_iterator & operator ++ ()
    {
        if (get_col() + 1 < m->get_width())
            set_col(get_col() + 1);
        else if (get_row() + 1 < m->get_height())
            set_pos(get_row() + 1, 0);
        else
            set_empty();
        return *this;
    }

    //! Previous column, or last column of the previous row, or empty before the
    //! first element. (m is a pointer, hence m->get_width().)
    matrix_row_major_iterator & operator -- ()
    {
        if (get_col() - 1 >= 0)
            set_col(get_col() - 1);
        else if (get_row() - 1 >= 0)
            set_pos(get_row() - 1, m->get_width() - 1);
        else
            set_empty();
        return *this;
    }

    using matrix_iterator_type::get_row;
    using matrix_iterator_type::get_col;
    using matrix_iterator_type::set_col;
    using matrix_iterator_type::set_pos;
};
//! Matrix iterator whose ++/-- walk the elements column by column; stepping
//! past either end yields the empty (end) iterator.
template <typename ValueType, unsigned BlockSideLength>
class matrix_col_major_iterator : public matrix_iterator<ValueType, BlockSideLength>
{
protected:
    typedef matrix_iterator<ValueType, BlockSideLength> matrix_iterator_type;
    typedef typename matrix_iterator_type::matrix_type matrix_type;
    typedef typename matrix_iterator_type::elem_size_type elem_size_type;

    template <typename VT, unsigned BSL> friend class matrix;

    using matrix_iterator_type::m;
    using matrix_iterator_type::set_empty;

    matrix_col_major_iterator(matrix_type & matrix, const elem_size_type start_row, const elem_size_type start_col)
        : matrix_iterator_type(matrix, start_row, start_col)
    { }

    matrix_col_major_iterator(matrix_type & matrix)
        : matrix_iterator_type(matrix)
    { }

public:
    //! Converts any matrix iterator to column-major traversal.
    matrix_col_major_iterator(const matrix_iterator_type & matrix_iterator)
        : matrix_iterator_type(matrix_iterator)
    { }

    //! Advances down the current column, wrapping to the top of the next column;
    //! becomes empty after the last element.
    matrix_col_major_iterator & operator ++ ()
    {
        const elem_size_type next_row = get_row() + 1;
        if (next_row < m->get_height())
        {
            set_row(next_row);
            return *this;
        }
        const elem_size_type next_col = get_col() + 1;
        if (next_col < m->get_width())
            set_pos(0, next_col);
        else
            set_empty();
        return *this;
    }

    //! Steps up the current column, wrapping to the bottom of the previous
    //! column; becomes empty before the first element.
    matrix_col_major_iterator & operator -- ()
    {
        const elem_size_type prev_row = get_row() - 1;
        if (prev_row >= 0)
        {
            set_row(prev_row);
            return *this;
        }
        const elem_size_type prev_col = get_col() - 1;
        if (prev_col >= 0)
            set_pos(m->get_height() - 1, prev_col);
        else
            set_empty();
        return *this;
    }

    using matrix_iterator_type::get_row;
    using matrix_iterator_type::get_col;
    using matrix_iterator_type::set_row;
    using matrix_iterator_type::set_pos;
};
//! Read-only iterator addressing one element of a const matrix by (row, col).
//! It keeps at most one internal block acquired from the block scheduler and
//! releases it (without the modified flag) when moving to another block or on
//! destruction. An empty iterator (row == col == -1) represents end().
template <typename ValueType, unsigned BlockSideLength>
class const_matrix_iterator
{
protected:
    typedef matrix<ValueType, BlockSideLength> matrix_type;
    typedef typename matrix_type::swappable_block_matrix_type swappable_block_matrix_type;
    typedef typename matrix_type::block_scheduler_type block_scheduler_type;
    typedef typename block_scheduler_type::internal_block_type internal_block_type;
    typedef typename matrix_type::elem_size_type elem_size_type;
    typedef typename matrix_type::block_size_type block_size_type;

    template <typename VT, unsigned BSL> friend class matrix;

    const matrix_type * m;
    elem_size_type current_row,
                   current_col;
    block_size_type current_block_row,
                    current_block_col;
    //! Acquired internal block for the current block, or 0 if none is held.
    internal_block_type * current_iblock;

    //! Acquires the current block from the scheduler unless already held.
    void acquire_current_iblock()
    {
        if (! current_iblock)
            current_iblock = & m->data->bs.acquire(m->data->block(current_block_row, current_block_col));
    }

    //! Releases the held block, if any, passing false (presumably "not modified").
    void release_current_iblock()
    {
        if (current_iblock)
        {
            m->data->bs.release(m->data->block(current_block_row, current_block_col), false);
            current_iblock = 0;
        }
    }

    //! Iterator positioned at element (start_row, start_col) of matrix.
    const_matrix_iterator(const matrix_type & matrix, const elem_size_type start_row, const elem_size_type start_col)
        : m(&matrix),
          current_row(start_row),
          current_col(start_col),
          current_block_row(m->data->block_index_from_elem(start_row)),
          current_block_col(m->data->block_index_from_elem(start_col)),
          current_iblock(0) {}

    //! Empty (end) iterator of matrix.
    const_matrix_iterator(const matrix_type & matrix)
        : m(&matrix),
          current_row(-1),
          current_col(-1),
          current_block_row(-1),
          current_block_col(-1),
          current_iblock(0) {}

    //! Releases the held block and turns this into the empty iterator.
    void set_empty()
    {
        release_current_iblock();
        current_row = -1;
        current_col = -1;
        current_block_row = -1;
        current_block_col = -1;
    }

public:
    //! Read-only copy of a mutable iterator's position.
    const_matrix_iterator(const matrix_iterator<ValueType, BlockSideLength> & other)
        : m(other.m),
          current_row(other.current_row),
          current_col(other.current_col),
          current_block_row(other.current_block_row),
          current_block_col(other.current_block_col),
          current_iblock(0)
    {
        if (other.current_iblock)
            acquire_current_iblock();
    }

    const_matrix_iterator(const const_matrix_iterator & other)
        : m(other.m),
          current_row(other.current_row),
          current_col(other.current_col),
          current_block_row(other.current_block_row),
          current_block_col(other.current_block_col),
          current_iblock(0)
    {
        if (other.current_iblock)
            acquire_current_iblock();
    }

    //! Copies other's matrix and position. A held block is kept only when it
    //! is the same block of the same matrix; otherwise it is released before
    //! the matrix pointer changes. Block indices are copied directly so an
    //! empty source stays empty (block index -1).
    const_matrix_iterator & operator = (const const_matrix_iterator & other)
    {
        if (m != other.m
            || current_block_row != other.current_block_row
            || current_block_col != other.current_block_col)
        {
            release_current_iblock(); // must use the old m
            m = other.m;
            current_block_row = other.current_block_row;
            current_block_col = other.current_block_col;
        }
        current_row = other.current_row;
        current_col = other.current_col;
        if (other.current_iblock)
            acquire_current_iblock();
        return *this;
    }

    ~const_matrix_iterator()
    { release_current_iblock(); }

    //! Moves to row new_row; releases the held block if the block row changes.
    void set_row(const elem_size_type new_row)
    {
        const block_size_type new_block_row = m->data->block_index_from_elem(new_row);
        if (new_block_row != current_block_row)
        {
            release_current_iblock();
            current_block_row = new_block_row;
        }
        current_row = new_row;
    }

    //! Moves to column new_col; releases the held block if the block column changes.
    void set_col(const elem_size_type new_col)
    {
        const block_size_type new_block_col = m->data->block_index_from_elem(new_col);
        if (new_block_col != current_block_col)
        {
            release_current_iblock();
            current_block_col = new_block_col;
        }
        current_col = new_col;
    }

    //! Moves to (new_row, new_col); releases the held block if the block changes.
    void set_pos(const elem_size_type new_row, const elem_size_type new_col)
    {
        const block_size_type new_block_row = m->data->block_index_from_elem(new_row),
                              new_block_col = m->data->block_index_from_elem(new_col);
        if (new_block_col != current_block_col || new_block_row != current_block_row)
        {
            release_current_iblock();
            current_block_row = new_block_row;
            current_block_col = new_block_col;
        }
        current_row = new_row;
        current_col = new_col;
    }

    //! Moves to (new_pos.first, new_pos.second).
    void set_pos(const std::pair<elem_size_type, elem_size_type> new_pos)
    { set_pos(new_pos.first, new_pos.second); }

    const elem_size_type & get_row() const
    { return current_row; }

    const elem_size_type & get_col() const
    { return current_col; }

    //! Current position as (row, col).
    std::pair<elem_size_type, elem_size_type> get_pos() const
    { return std::make_pair(current_row, current_col); }

    //! True for the empty (end) iterator.
    bool empty() const
    { return current_row == -1 && current_col == -1; }

    operator bool () const
    { return ! empty(); }

    //! Equal when both refer to the same position of the same matrix.
    bool operator == (const const_matrix_iterator & other) const
    {
        return current_row == other.current_row && current_col == other.current_col && m == other.m;
    }

    //! Negation of operator==; without it `a != b` would compare the implicit
    //! bool conversions (emptiness) instead of positions.
    bool operator != (const const_matrix_iterator & other) const
    { return ! (*this == other); }

    //! Const reference to the current element. Acquires the current block if
    //! needed; the reference stays valid only until the block is released.
    const ValueType & operator * ()
    {
        acquire_current_iblock();
        return (*current_iblock)[m->data->elem_index_in_block_from_elem(current_row, current_col)];
    }
};
//! Read-only matrix iterator whose ++/-- walk the elements in row-major order;
//! stepping past either end yields the empty (end) iterator.
template <typename ValueType, unsigned BlockSideLength>
class const_matrix_row_major_iterator : public const_matrix_iterator<ValueType, BlockSideLength>
{
protected:
    typedef const_matrix_iterator<ValueType, BlockSideLength> const_matrix_iterator_type;
    typedef typename const_matrix_iterator_type::matrix_type matrix_type;
    typedef typename const_matrix_iterator_type::elem_size_type elem_size_type;

    template <typename VT, unsigned BSL> friend class matrix;

    using const_matrix_iterator_type::m;
    using const_matrix_iterator_type::set_empty;

    const_matrix_row_major_iterator(const matrix_type & matrix, const elem_size_type start_row, const elem_size_type start_col)
        : const_matrix_iterator_type(matrix, start_row, start_col) {}

    const_matrix_row_major_iterator(const matrix_type & matrix)
        : const_matrix_iterator_type(matrix) {}

public:
    const_matrix_row_major_iterator(const const_matrix_row_major_iterator & matrix_iterator)
        : const_matrix_iterator_type(matrix_iterator) {}

    //! Converts any const matrix iterator to row-major traversal.
    const_matrix_row_major_iterator(const const_matrix_iterator_type & matrix_iterator)
        : const_matrix_iterator_type(matrix_iterator) {}

    //! Next column, or first column of the next row, or empty after the last element.
    const_matrix_row_major_iterator & operator ++ ()
    {
        if (get_col() + 1 < m->get_width())
            set_col(get_col() + 1);
        else if (get_row() + 1 < m->get_height())
            set_pos(get_row() + 1, 0);
        else
            set_empty();
        return *this;
    }

    //! Previous column, or last column of the previous row, or empty before the
    //! first element. (m is a pointer, hence m->get_width().)
    const_matrix_row_major_iterator & operator -- ()
    {
        if (get_col() - 1 >= 0)
            set_col(get_col() - 1);
        else if (get_row() - 1 >= 0)
            set_pos(get_row() - 1, m->get_width() - 1);
        else
            set_empty();
        return *this;
    }

    using const_matrix_iterator_type::get_row;
    using const_matrix_iterator_type::get_col;
    using const_matrix_iterator_type::set_col;
    using const_matrix_iterator_type::set_pos;
};
//! Read-only matrix iterator whose ++/-- walk the elements column by column;
//! stepping past either end yields the empty (end) iterator.
template <typename ValueType, unsigned BlockSideLength>
class const_matrix_col_major_iterator : public const_matrix_iterator<ValueType, BlockSideLength>
{
protected:
    typedef const_matrix_iterator<ValueType, BlockSideLength> const_matrix_iterator_type;
    typedef typename const_matrix_iterator_type::matrix_type matrix_type;
    typedef typename const_matrix_iterator_type::elem_size_type elem_size_type;

    template <typename VT, unsigned BSL> friend class matrix;

    using const_matrix_iterator_type::m;
    using const_matrix_iterator_type::set_empty;

    const_matrix_col_major_iterator(const matrix_type & matrix, const elem_size_type start_row, const elem_size_type start_col)
        : const_matrix_iterator_type(matrix, start_row, start_col)
    { }

    const_matrix_col_major_iterator(const matrix_type & matrix)
        : const_matrix_iterator_type(matrix)
    { }

public:
    //! Read-only column-major view of a mutable iterator's position.
    const_matrix_col_major_iterator(const matrix_iterator<ValueType, BlockSideLength> & matrix_iterator)
        : const_matrix_iterator_type(matrix_iterator)
    { }

    //! Converts any const matrix iterator to column-major traversal.
    const_matrix_col_major_iterator(const const_matrix_iterator_type & matrix_iterator)
        : const_matrix_iterator_type(matrix_iterator)
    { }

    //! Advances down the current column, wrapping to the top of the next column;
    //! becomes empty after the last element.
    const_matrix_col_major_iterator & operator ++ ()
    {
        const elem_size_type next_row = get_row() + 1;
        if (next_row < m->get_height())
        {
            set_row(next_row);
            return *this;
        }
        const elem_size_type next_col = get_col() + 1;
        if (next_col < m->get_width())
            set_pos(0, next_col);
        else
            set_empty();
        return *this;
    }

    //! Steps up the current column, wrapping to the bottom of the previous
    //! column; becomes empty before the first element.
    const_matrix_col_major_iterator & operator -- ()
    {
        const elem_size_type prev_row = get_row() - 1;
        if (prev_row >= 0)
        {
            set_row(prev_row);
            return *this;
        }
        const elem_size_type prev_col = get_col() - 1;
        if (prev_col >= 0)
            set_pos(m->get_height() - 1, prev_col);
        else
            set_empty();
        return *this;
    }

    using const_matrix_iterator_type::get_row;
    using const_matrix_iterator_type::get_col;
    using const_matrix_iterator_type::set_row;
    using const_matrix_iterator_type::set_pos;
};
//! External-memory matrix of ValueType, stored as a grid of
//! BlockSideLength x BlockSideLength swappable blocks managed by a
//! block_scheduler. The grid is held through a shared_object_pointer; member
//! functions that modify the matrix call data.unify() first (presumably
//! detaching a grid that is shared with other matrix objects).
template <typename ValueType, unsigned BlockSideLength>
class matrix
{
protected:
    typedef matrix<ValueType, BlockSideLength> matrix_type;
    typedef swappable_block_matrix<ValueType, BlockSideLength> swappable_block_matrix_type;
    typedef shared_object_pointer<swappable_block_matrix_type> swappable_block_matrix_pointer_type;
    typedef typename swappable_block_matrix_type::block_scheduler_type block_scheduler_type;
    typedef typename swappable_block_matrix_type::size_type block_size_type;
    typedef typename swappable_block_matrix_type::elem_size_type elem_size_type;
    typedef matrix_operations<ValueType, BlockSideLength> Ops;
    typedef matrix_swappable_block<ValueType, BlockSideLength> swappable_block_type;

public:
    typedef matrix_iterator<ValueType, BlockSideLength> iterator;
    typedef const_matrix_iterator<ValueType, BlockSideLength> const_iterator;
    typedef matrix_row_major_iterator<ValueType, BlockSideLength> row_major_iterator;
    typedef matrix_col_major_iterator<ValueType, BlockSideLength> col_major_iterator;
    typedef const_matrix_row_major_iterator<ValueType, BlockSideLength> const_row_major_iterator;
    typedef const_matrix_col_major_iterator<ValueType, BlockSideLength> const_col_major_iterator;
    typedef column_vector<ValueType> column_vector_type;
    typedef row_vector<ValueType> row_vector_type;

protected:
    template <typename VT, unsigned BSL> friend class matrix_iterator;
    template <typename VT, unsigned BSL> friend class const_matrix_iterator;

    //! Dimensions in elements.
    elem_size_type height,
                   width;
    //! Block grid holding the elements (dimensions rounded up to whole blocks).
    swappable_block_matrix_pointer_type data;

public:
    //! Creates a height x width matrix whose blocks are allocated from bs.
    matrix(block_scheduler_type & bs, const elem_size_type height, const elem_size_type width)
        : height(height),
          width(width),
          data(new swappable_block_matrix_type
               (bs, div_ceil(height, BlockSideLength), div_ceil(width, BlockSideLength)))
    {}

    //! Creates a left.size() x right.size() matrix filled by
    //! Ops::recursive_matrix_from_vectors (presumably the outer product left * right).
    matrix(block_scheduler_type & bs, const column_vector_type & left, const row_vector_type & right)
        : height(left.size()),
          width(right.size()),
          data(new swappable_block_matrix_type
               (bs, div_ceil(height, BlockSideLength), div_ceil(width, BlockSideLength)))
    { Ops::recursive_matrix_from_vectors(*data, left, right); }

    ~matrix() {}

    //! Number of rows.
    const elem_size_type & get_height() const
    { return height; }

    //! Number of columns.
    const elem_size_type & get_width() const
    { return width; }

    //! Mutable iterator at element (0, 0).
    iterator begin()
    {
        data.unify();
        return iterator(*this, 0, 0);
    }

    //! Read-only iterator at element (0, 0).
    const_iterator begin() const
    { return const_iterator(*this, 0, 0); }

    //! Read-only iterator at element (0, 0).
    const_iterator cbegin() const
    { return const_iterator(*this, 0, 0); }

    //! Empty (end) mutable iterator.
    iterator end()
    {
        data.unify();
        return iterator(*this);
    }

    //! Empty (end) read-only iterator.
    const_iterator end() const
    { return const_iterator(*this); }

    //! Empty (end) read-only iterator.
    const_iterator cend() const
    { return const_iterator(*this); }

    //! Read-only iterator at element (row, col).
    const_iterator operator () (const elem_size_type row, const elem_size_type col) const
    { return const_iterator(*this, row, col); }

    //! Mutable iterator at element (row, col).
    iterator operator () (const elem_size_type row, const elem_size_type col)
    {
        data.unify();
        return iterator(*this, row, col);
    }

    //! Transposes the matrix in place.
    void transpose()
    {
        data.unify();
        data->transpose();
        std::swap(height, width);
    }

    //! Sets all elements to zero: deinitializes the blocks if the grid is not
    //! shared, otherwise replaces it by a freshly allocated grid.
    void set_zero()
    {
        if (data.unique())
            data->set_zero();
        else
            data = new swappable_block_matrix_type
                    (data->bs, div_ceil(height, BlockSideLength), div_ceil(width, BlockSideLength));
    }

    //! Element-wise sum; dimensions must match.
    matrix_type operator + (const matrix_type & right) const
    {
        assert(height == right.height && width == right.width);
        matrix_type res(data->bs, height, width);
        Ops::element_op(*res.data, *data, *right.data, typename Ops::addition());
        return res;
    }

    //! Element-wise difference; dimensions must match.
    matrix_type operator - (const matrix_type & right) const
    {
        assert(height == right.height && width == right.width);
        matrix_type res(data->bs, height, width);
        Ops::element_op(*res.data, *data, *right.data, typename Ops::subtraction());
        return res;
    }

    //! Matrix product with the default algorithms of multiply().
    matrix_type operator * (const matrix_type & right) const
    { return multiply(right); }

    //! Returns a copy with every element multiplied by scalar.
    matrix_type operator * (const ValueType scalar) const
    {
        matrix_type res(data->bs, height, width);
        Ops::element_op(*res.data, *data, typename Ops::scalar_multiplication(scalar));
        return res;
    }

    //! Adds right element-wise; dimensions must match.
    matrix_type & operator += (const matrix_type & right)
    {
        assert(height == right.height && width == right.width);
        data.unify();
        Ops::element_op(*data, *right.data, typename Ops::addition());
        return *this;
    }

    //! Subtracts right element-wise; dimensions must match.
    matrix_type & operator -= (const matrix_type & right)
    {
        assert(height == right.height && width == right.width);
        data.unify();
        Ops::element_op(*data, *right.data, typename Ops::subtraction());
        return *this;
    }

    //! Replaces *this by (*this) * right.
    matrix_type & operator *= (const matrix_type & right)
    { return *this = operator * (right); }

    //! Multiplies every element by scalar.
    matrix_type & operator *= (const ValueType scalar)
    {
        data.unify();
        Ops::element_op(*data, typename Ops::scalar_multiplication(scalar));
        return *this;
    }

    //! Matrix-times-column-vector product; right.size() must equal get_width().
    column_vector_type operator * (const column_vector_type & right) const
    {
        assert(elem_size_type(right.size()) == width);
        column_vector_type res(height);
        res.set_zero();
        Ops::recursive_matrix_col_vector_multiply_and_add(*data, right, res);
        return res;
    }

    //! Row-vector-times-matrix product left * (*this); left.size() must equal get_height().
    row_vector_type multiply_from_left(const row_vector_type & left) const
    {
        assert(elem_size_type(left.size()) == height);
        row_vector_type res(width);
        res.set_zero();
        Ops::recursive_matrix_row_vector_multiply_and_add(left, *data, res);
        return res;
    }

    //! Matrix product (*this) * right computed block-wise in external memory.
    //! multiplication_algorithm: 0 naive, 1 recursive, 2 Strassen-Winograd,
    //! 3 multi-level Strassen-Winograd, 4 strassen_winograd_multiply,
    //! 5 interleaved Strassen-Winograd, 6 block-grained multi-level Strassen-Winograd.
    //! scheduling_algorithm: 0 online LRU, 1 offline LFD, 2 offline LRU with
    //! prefetching. For values > 0 the multiplication is first run under the
    //! simulation algorithm (presumably so the offline scheduler can record the
    //! block accesses). The scheduler is left on online LRU afterwards.
    matrix_type multiply (const matrix_type & right, const int_type multiplication_algorithm = 1, const int_type scheduling_algorithm = 2) const
    {
        assert(width == right.height);
        assert(& data->bs == & right.data->bs);
        matrix_type res(data->bs, height, right.width);

        if (scheduling_algorithm > 0)
        {
            delete data->bs.switch_algorithm_to(new
                    block_scheduler_algorithm_simulation<swappable_block_type>(data->bs));
            dispatch_multiplication(right, res, multiplication_algorithm);
        }
        switch_scheduling_algorithm(scheduling_algorithm);
        dispatch_multiplication(right, res, multiplication_algorithm);
        delete data->bs.switch_algorithm_to(new
                block_scheduler_algorithm_online_lru<swappable_block_type>(data->bs));
        return res;
    }

    //! Matrix product computed in internal memory with BLAS (gemm): both
    //! operands are copied into dense arrays. scheduling_algorithm has the same
    //! meaning as in multiply(), including the simulation pass for values > 0.
    matrix_type multiply_internal (const matrix_type & right, const int_type scheduling_algorithm = 2) const
    {
        assert(width == right.height);
        assert(& data->bs == & right.data->bs);
        matrix_type res(data->bs, height, right.width);

        if (scheduling_algorithm > 0)
        {
            delete data->bs.switch_algorithm_to(new
                    block_scheduler_algorithm_simulation<swappable_block_type>(data->bs));
            multiply_internal(right, res);
        }
        switch_scheduling_algorithm(scheduling_algorithm);
        multiply_internal(right, res);
        delete data->bs.switch_algorithm_to(new
                block_scheduler_algorithm_online_lru<swappable_block_type>(data->bs));
        return res;
    }

protected:
    //! Installs the scheduling algorithm selected by scheduling_algorithm
    //! (0: online LRU, 1: offline LFD, 2: offline LRU with prefetching) and
    //! deletes the algorithm returned by switch_algorithm_to. Other values only
    //! report an error.
    void switch_scheduling_algorithm(const int_type scheduling_algorithm) const
    {
        switch (scheduling_algorithm)
        {
        case 0:
            delete data->bs.switch_algorithm_to(new
                    block_scheduler_algorithm_online_lru<swappable_block_type>(data->bs));
            break;
        case 1:
            delete data->bs.switch_algorithm_to(new
                    block_scheduler_algorithm_offline_lfd<swappable_block_type>(data->bs));
            break;
        case 2:
            delete data->bs.switch_algorithm_to(new
                    block_scheduler_algorithm_offline_lru_prefetching<swappable_block_type>(data->bs));
            break;
        default:
            STXXL_ERRMSG("invalid scheduling-algorithm number");
        }
    }

    //! Runs the block-level multiplication kernel selected by
    //! multiplication_algorithm (see multiply()) on *this and right into res.
    //! Invalid numbers only report an error.
    void dispatch_multiplication(const matrix_type & right, matrix_type & res, const int_type multiplication_algorithm) const
    {
        switch (multiplication_algorithm)
        {
        case 0:
            Ops::naive_multiply_and_add(*data, *right.data, *res.data);
            break;
        case 1:
            Ops::recursive_multiply_and_add(*data, *right.data, *res.data);
            break;
        case 2:
            Ops::strassen_winograd_multiply_and_add(*data, *right.data, *res.data);
            break;
        case 3:
            Ops::multi_level_strassen_winograd_multiply_and_add(*data, *right.data, *res.data);
            break;
        case 4:
            Ops::strassen_winograd_multiply(*data, *right.data, *res.data);
            break;
        case 5:
            Ops::strassen_winograd_multiply_and_add_interleaved(*data, *right.data, *res.data);
            break;
        case 6:
            Ops::multi_level_strassen_winograd_multiply_and_add_block_grained(*data, *right.data, *res.data);
            break;
        default:
            STXXL_ERRMSG("invalid multiplication-algorithm number");
            break;
        }
    }

    //! One pass of the internal-memory product: copies *this and right into
    //! dense row-major arrays, multiplies them with BLAS unless the scheduler is
    //! simulating, and writes the result into res.
    void multiply_internal(const matrix_type & right, matrix_type & res) const
    {
        // std::vector frees its buffer with the matching deallocation (raw
        // new[] storage must be freed with delete[], not delete) and
        // value-initialises C, so no indeterminate values reach res when the
        // BLAS call is skipped during simulation.
        std::vector<ValueType> A(height * width),
                               B(right.height * right.width),
                               C(res.height * res.width);

        typename std::vector<ValueType>::size_type pos = 0;
        for (const_row_major_iterator mit = cbegin(); mit != cend(); ++mit, ++pos)
            A[pos] = *mit;
        pos = 0;
        for (const_row_major_iterator mit = right.cbegin(); mit != right.cend(); ++mit, ++pos)
            B[pos] = *mit;

        if (! res.data->bs.is_simulating())
        {
#if STXXL_BLAS
            // &X[0] is only valid for non-empty vectors; with an empty operand
            // the product is all zeros, which C already holds.
            if (! A.empty() && ! B.empty() && ! C.empty())
                gemm_wrapper(height, width, res.width,
                             ValueType(1), false, &A[0],
                             false, &B[0],
                             ValueType(0), false, &C[0]);
#else
            // No in-memory multiplication kernel is available without BLAS.
            assert(false);
#endif
        }

        pos = 0;
        for (row_major_iterator mit = res.begin(); mit != res.end(); ++mit, ++pos)
            *mit = C[pos];
    }
};
__STXXL_END_NAMESPACE
#endif