#include "matrix_int.h"

// Class-wide alignment value handed to malloc_aligned() by every allocating
// constructor in this file. Starts at 1 and is changed via set_alignment().
// NOTE(review): presumably expressed in bytes — confirm against the
// malloc_aligned definition.
int IntegerMatrix::alignment=1;

// Replaces the class-wide alignment value used by subsequent allocations.
// Only the stored value changes; no existing matrix is touched here.
void IntegerMatrix::set_alignment(int n) {
  IntegerMatrix::alignment = n;
}

// Builds an empty matrix: every dimension is zero and no storage is owned.
IntegerMatrix::IntegerMatrix() {
  rows=cols=size=0;
  rows2=cols2=0;
  mat=NULL;
  // The destructor unconditionally calls free(u_mat). Leaving u_mat
  // uninitialized made destroying an empty matrix free an indeterminate
  // pointer; free(NULL) is a defined no-op.
  u_mat=NULL;
}

// Square matrix of order rc. The stored dimensions (rows2, cols2) equal the
// logical ones, so there is no padding. No element is written here: the
// contents are whatever malloc_aligned() leaves in the buffer.
IntegerMatrix::IntegerMatrix(uint32_t rc) {
  // Logical and stored row counts.
  rows = rc;
  rows2 = rc;

  // Logical and stored column counts.
  cols = rc;
  cols2 = rc;

  // Element count of the backing buffer.
  size = rows2 * cols2;
  malloc_aligned(u_mat,mat,INT,size*sizeof(INT),alignment);
}

// r x c matrix. The stored dimensions (rows2, cols2) equal the logical ones,
// so there is no padding. No element is written here: the contents are
// whatever malloc_aligned() leaves in the buffer.
IntegerMatrix::IntegerMatrix(uint32_t r, uint32_t c) {
  // Logical and stored row counts.
  rows = r;
  rows2 = r;

  // Logical and stored column counts.
  cols = c;
  cols2 = c;

  // Element count of the backing buffer.
  size = rows2 * cols2;
  malloc_aligned(u_mat,mat,INT,size*sizeof(INT),alignment);
}

// r x c logical matrix stored inside an r2 x c2 buffer (row stride c2).
//
// r, c   : logical dimensions, used as loop bounds by fill(), checksum(),
//          display(), operator== and the mul_basic* products.
// r2, c2 : stored (padded) dimensions; the buffer holds r2*c2 elements and
//          the mul_block_* products iterate over this padded area.
//
// No element is written here: the contents are whatever malloc_aligned()
// leaves in the buffer.
IntegerMatrix::IntegerMatrix(uint32_t r, uint32_t c, uint32_t r2, uint32_t c2) {
  // Elements are addressed as mat[i*cols2+j] with i<rows and j<cols, so the
  // stored area must cover the logical one or those accesses leave the buffer.
  assert(r2>=r);
  assert(c2>=c);
  rows=r;
  cols=c;
  rows2=r2;
  cols2=c2;
  size=rows2*cols2;
  malloc_aligned(u_mat,mat,INT,size*sizeof(INT),alignment);
}



// Copy constructor: duplicates m's dimensions and its whole backing buffer
// (all `size` elements, padding included) into freshly allocated storage.
IntegerMatrix::IntegerMatrix(IntegerMatrix& m) {
  // This object is under construction, so it owns nothing yet. The former
  // `if (mat!=NULL) delete [] mat;` read an uninitialized pointer and could
  // hand garbage to delete[]; the storage is released with free() elsewhere.
  rows=m.rows;
  cols=m.cols;
  rows2=m.rows2;
  cols2=m.cols2;
  size=m.size;
  if (m.mat==NULL) {
    // The source owns no buffer (e.g. it was default-constructed): mirror
    // that state instead of passing a null source pointer to memcpy.
    mat=NULL;
    u_mat=NULL;
    return;
  }
  malloc_aligned(u_mat,mat,INT,size*sizeof(INT),alignment);
  memcpy(mat,m.mat,size*sizeof(INT));
}

// Frees u_mat, the pointer filled in by malloc_aligned() in the allocating
// constructors. mat itself is never passed to free().
IntegerMatrix::~IntegerMatrix() {
  free(this->u_mat);
}

// Sets every logical element (rows x cols) to n.
//
// The whole buffer of `size` elements is zeroed first, so any storage outside
// the logical area (the padding present when rows2/cols2 exceed rows/cols)
// ends up holding 0 rather than n. This is why the buffer is not simply
// overwritten with n from start to end.
void IntegerMatrix::fill(INT n) {
  memset(mat, 0, size*sizeof(INT));

  for (uint32_t r = 0; r < rows; ++r) {
    for (uint32_t c = 0; c < cols; ++c) {
      // Row stride is the stored width cols2, not the logical width cols.
      mat[r*cols2+c] = n;
    }
  }
}

// Returns the sum of all logical elements (rows x cols), accumulated in a
// double in row-major order. Storage outside the logical area is not read.
double IntegerMatrix::checksum() {
  double total = 0.0;

  for (uint32_t r = 0; r < rows; ++r) {
    for (uint32_t c = 0; c < cols; ++c) {
      total += static_cast<double>(mat[cols2*r+c]);
    }
  }

  return total;
}



// Textbook i-j-k product: stores a*b into this matrix over the logical
// a.rows x b.cols area. Each destination element is cleared and then
// accumulated in place. Asserts that a's column count matches b's row count;
// the dimensions of this matrix are not checked.
void IntegerMatrix::mul_basic(IntegerMatrix& a, IntegerMatrix& b) {
  assert(a.cols==b.rows);

  for (uint32_t r = 0; r < a.rows; ++r) {
    for (uint32_t c = 0; c < b.cols; ++c) {
      // Work directly on the destination element, as the accumulation must
      // land in the matrix itself.
      INT& cell = mat[cols2*r+c];
      cell = 0;
      for (uint32_t k = 0; k < a.cols; ++k) {
        cell += a.mat[r*a.cols2+k]*b.mat[k*b.cols2+c];
      }
    }
  }
}


// Product a*b with the j and k loops interchanged (i-k-j order), so the
// innermost loop walks one row of b and one row of the result contiguously.
// Stores the product into this matrix over the logical a.rows x b.cols area.
// Asserts that a's column count matches b's row count; the dimensions of this
// matrix are not checked.
void IntegerMatrix::mul_basic_inv_jk(IntegerMatrix& a, IntegerMatrix& b) {
  assert(a.cols==b.rows);
  for (uint32_t i=0;i<a.rows;i++) {
    // Clear the destination row before accumulating. Previously the result
    // was added on top of whatever the buffer already held (uninitialized
    // after construction), unlike mul_basic and the mul_block_* variants,
    // which overwrite their destination.
    for (uint32_t j=0;j<b.cols;j++) {
      mat[cols2*i+j]=0;
    }
    for (uint32_t k=0;k<a.cols;k++) {
      // a(i,k) does not depend on j: read it once per k.
      const INT aik=a.mat[i*a.cols2+k];
      for (uint32_t j=0;j<b.cols;j++) {
        mat[cols2*i+j] += aik*b.mat[k*b.cols2+j];
      }
    }
  }
}

// Shared kernel behind mul_block_32 / mul_block_48 / mul_block_64:
// computes c = a * b tile by tile, with NB x NB tiles and a 2x2 register
// block (four running sums) inside each tile.
//
// NB is the tile edge; T is the element type, deduced from the pointers.
// All three matrices are row-major over their stored (padded) area:
//   c : c_rows x c_stride
//   a : a_rows x a_stride
//   b : b_rows x b_stride
//
// The loops cover the whole stored area of a and b, not just the logical
// one, hence the preconditions asserted below. Because the k loop also runs
// over a's padding columns and b's padding rows, that padding must hold
// zeros for the logical part of the result to be correct (fill() zeroes it).
template <uint32_t NB, typename T>
static void mul_block_kernel(T* c, uint32_t c_rows, uint32_t c_stride,
                             const T* a, uint32_t a_rows, uint32_t a_stride,
                             const T* b, uint32_t b_rows, uint32_t b_stride) {
  // The 2x2 inner block steps i and j by two.
  assert(NB%2==0);
  // Whole tiles only: a partial tile would run past the stored area.
  assert(a_rows%NB==0);
  assert(a_stride%NB==0);
  assert(b_stride%NB==0);
  // k ranges over a_stride and indexes rows of b.
  assert(b_rows>=a_stride);
  // The result must be able to hold a_rows x b_stride elements.
  assert(c_rows>=a_rows);
  assert(c_stride>=b_stride);

  for (uint32_t ii=0; ii<a_rows; ii+=NB) {
    for (uint32_t jj=0; jj<b_stride; jj+=NB) {
      // Clear the destination tile before accumulating into it.
      for (uint32_t i=ii; i<ii+NB; i++)
        for (uint32_t j=jj; j<jj+NB; j++)
          c[i*c_stride+j]=0;

      // Accumulate the contribution of each NB-wide slice of the k range.
      for (uint32_t kk=0; kk<a_stride; kk+=NB)
        for (uint32_t i=ii; i<ii+NB; i+=2)
          for (uint32_t j=jj; j<jj+NB; j+=2) {
            // Keep the 2x2 block of partial sums in locals across the k loop.
            T s00=c[i*c_stride+j];
            T s01=c[i*c_stride+j+1];
            T s10=c[(i+1)*c_stride+j];
            T s11=c[(i+1)*c_stride+j+1];
            for (uint32_t k=kk; k<kk+NB; k++) {
              s00+=a[i*a_stride+k]     * b[k*b_stride+j];
              s01+=a[i*a_stride+k]     * b[k*b_stride+j+1];
              s10+=a[(i+1)*a_stride+k] * b[k*b_stride+j];
              s11+=a[(i+1)*a_stride+k] * b[k*b_stride+j+1];
            }
            c[i*c_stride+j]=s00;
            c[i*c_stride+j+1]=s01;
            c[(i+1)*c_stride+j]=s10;
            c[(i+1)*c_stride+j+1]=s11;
          }
    }
  }
}

// Blocked product a*b with 32x32 tiles, stored into this matrix.
// Requires a.rows2, a.cols2 and b.cols2 to be multiples of 32 and the padding
// of a and b to be zero. The result is indexed with this matrix's own row
// stride (cols2), consistently with mul_basic, display and operator==; the
// previous code used b.cols2, which is only correct when both are equal.
void IntegerMatrix::mul_block_32(IntegerMatrix& a, IntegerMatrix& b) {
  assert(a.cols==b.rows);
  mul_block_kernel<32>(mat, rows2, cols2,
                       a.mat, a.rows2, a.cols2,
                       b.mat, b.rows2, b.cols2);
}

// Blocked product a*b with 48x48 tiles, stored into this matrix.
// Requires a.rows2, a.cols2 and b.cols2 to be multiples of 48 and the padding
// of a and b to be zero. The result is indexed with this matrix's own row
// stride (cols2).
void IntegerMatrix::mul_block_48(IntegerMatrix& a, IntegerMatrix& b) {
  assert(a.cols==b.rows);
  mul_block_kernel<48>(mat, rows2, cols2,
                       a.mat, a.rows2, a.cols2,
                       b.mat, b.rows2, b.cols2);
}


// Blocked product a*b with 64x64 tiles, stored into this matrix.
// Requires a.rows2, a.cols2 and b.cols2 to be multiples of 64 and the padding
// of a and b to be zero. The result is indexed with this matrix's own row
// stride (cols2).
void IntegerMatrix::mul_block_64(IntegerMatrix& a, IntegerMatrix& b) {
  assert(a.cols==b.rows);
  mul_block_kernel<64>(mat, rows2, cols2,
                       a.mat, a.rows2, a.cols2,
                       b.mat, b.rows2, b.cols2);
}


// Writes the one-line summary produced by display_info(), then the logical
// rows x cols elements, one matrix row per output line. Each element is
// printed with a field width of 6 and followed by a single space.
void IntegerMatrix::display(ostream& out) {
  // The header used to be a verbatim copy of display_info(); call it instead
  // so the two cannot drift apart.
  display_info(out);
  for (uint32_t i=0;i<rows;++i) {
    for (uint32_t j=0;j<cols;++j) {
      // width() applies to the next insertion only, so set it per element.
      out.width(6);
      out << mat[i*cols2+j] << " ";
    }
    out << endl;
  }
}


// Writes a one-line summary to out: the value of the mat pointer (inserted
// with the stream in hex mode, then switched back to dec), the class-wide
// alignment, the logical dimensions and the stored dimensions. The line is
// terminated with endl.
void IntegerMatrix::display_info(ostream& out) {
  out << hex << mat << dec;
  out << " (alignment=" << alignment << ") ";
  out << "rows=" << rows << ", cols=" << cols
      << " (rows2=" << rows2 << ", cols2=" << cols2 << ")";
  out << endl;
}



// Stream insertion for IntegerMatrix: forwards to m.display(out) and returns
// the same stream so that insertions can be chained.
ostream& operator<<(ostream& out, IntegerMatrix& m) {
  m.display(out);
  return out;
}

// Element-wise equality over the logical area.
//
// Returns false when the logical dimensions differ (nothing is printed in
// that case). Otherwise the elements are compared in row-major order, each
// matrix being indexed with its own row stride; on the first mismatch its
// position and both values are written to cout and false is returned.
// Returns true when every logical element matches.
bool operator==(IntegerMatrix& m, IntegerMatrix& n) {
  if (m.rows!=n.rows || m.cols!=n.cols) {
    return false;
  }

  for (uint32_t r = 0; r < m.rows; ++r) {
    for (uint32_t c = 0; c < m.cols; ++c) {
      if (m.mat[r*m.cols2+c] != n.mat[r*n.cols2+c]) {
        cout << "error, i=" << r << ", j=" << c
             << ", val1=" << m.mat[r*m.cols2+c]
             << ", val2=" << n.mat[r*n.cols2+c] << endl;
        return false;
      }
    }
  }

  return true;
}