#ifndef INTEGER_MATRIX_H
#define INTEGER_MATRIX_H

#include "common.h"
#include <cassert>
#include <stdint.h>
#include <iostream>
#include <algorithm>
using namespace std;

// Element type stored in the matrix. Kept as a macro because other
// translation units may rely on the textual substitution.
#define INT int

/**
 * Matrix of INT values with several multiplication strategies.
 *
 * The class keeps two sets of dimensions: the requested ones (rows, cols)
 * and "corrected" ones (rows2, cols2); size is rows2 * cols2. Two pointers
 * are kept to the storage: an aligned one (mat) and an unaligned one (u_mat).
 *
 * NOTE(review): the copy constructor takes a non-const reference and no
 * copy assignment operator is declared. If the destructor releases u_mat
 * (not visible from this header), the compiler-generated assignment would
 * be unsafe -- confirm against the implementation file.
 */
class IntegerMatrix {
 public:
  // Constructors are defined in the implementation file.
  IntegerMatrix();
  // Presumably builds a square rc x rc matrix -- TODO confirm.
  IntegerMatrix(uint32_t rc);
  // Presumably builds an r x c matrix -- TODO confirm.
  IntegerMatrix(uint32_t r, uint32_t c);
  // Presumably r, c are the requested dimensions and r2, c2 the corrected
  // ones -- TODO confirm.
  IntegerMatrix(uint32_t r, uint32_t c, uint32_t r2, uint32_t c2);
  IntegerMatrix(IntegerMatrix& m);
  ~IntegerMatrix();

  // Read-only accessors; const-qualified so they can be used through a
  // const reference or pointer.

  /** @return number of rows */
  uint32_t get_rows() const { return rows; }
  /** @return number of columns */
  uint32_t get_cols() const { return cols; }
  /** @return corrected number of rows */
  uint32_t get_rows2() const { return rows2; }
  /** @return corrected number of columns */
  uint32_t get_cols2() const { return cols2; }
  /** @return rows2 * cols2 */
  uint32_t get_size() const { return size; }

  /** @return the aligned storage pointer (writable) */
  INT *get_content() { return mat; }
  /** @return the aligned storage pointer (read-only view for const objects) */
  const INT *get_content() const { return mat; }

  /**
   * Checksum of the matrix, returned as a double. The exact formula lives
   * in the implementation file.
   */
  double checksum();

  /**
   * Fill the matrix using the value n.
   * NOTE(review): presumably every element is set to n -- confirm.
   */
  void fill(INT n);

  // The mul_* methods below all take two operand matrices a and b.
  // Presumably the product a * b is stored into this matrix -- TODO confirm.

  /**
   * basic implementation - no optimization
   */
  void mul_basic(IntegerMatrix& a, IntegerMatrix& b);

  /**
   * basic implementation
   * with j/k loop inversion
   */
  void mul_basic_inv_jk(IntegerMatrix& a, IntegerMatrix& b);

  /**
   * multiply with blocks
   * (block dimension presumably 32, per the method name)
   */
  void mul_block_32(IntegerMatrix& a, IntegerMatrix& b);

  /**
   * multiply with blocks
   * (block dimension presumably 48, per the method name)
   */
  void mul_block_48(IntegerMatrix& a, IntegerMatrix& b);

  /**
   * multiply with blocks
   * (block dimension presumably 64, per the method name)
   */
  void mul_block_64(IntegerMatrix& a, IntegerMatrix& b);

  /**
   * Class-wide alignment setting.
   * NOTE(review): presumably stores n into the static `alignment` member
   * and only affects matrices created afterwards -- confirm.
   */
  static void set_alignment(int n);

  /** Equality comparison of two matrices (defined out of line). */
  friend bool operator==(IntegerMatrix& m, IntegerMatrix& n);

  // Output helpers writing to the given stream. std::ostream is spelled
  // out so the declarations do not depend on a using-directive.
  void display(std::ostream& out);
  void display_info(std::ostream& out);

  /** Stream insertion operator (defined out of line). */
  friend std::ostream& operator<<(std::ostream& out, IntegerMatrix& m);

 protected:
  uint32_t rows;  // number of rows
  uint32_t cols;  // number of columns
  uint32_t rows2; // corrected number of rows
  uint32_t cols2; // corrected number of columns
  uint32_t size;  // = rows2 * cols2
  INT *mat; // aligned
  INT *u_mat; // unaligned

  // Class-wide alignment value; see set_alignment().
  static int alignment;

};

#endif