#include <neuralnet/matrix.h>

#include "test.h"
#include "test_util.h"

#include <assert.h>
#include <stdlib.h>

// static void PrintMatrix(const nnMatrix* matrix) {
//   assert(matrix);

//   for (int i = 0; i < matrix->rows; ++i) {
//     for (int j = 0; j < matrix->cols; ++j) {
//       printf("%f ", nnMatrixAt(matrix, i, j));
//     }
//     printf("\n");
//   }
// }

// nnMatrixMake(1, 1) must report a 1x1 shape.
// The matrix is released with nnMatrixDel at the end, as every other test in
// this file does with the matrices it creates.
TEST_CASE(nnMatrixMake_1x1) {
  nnMatrix A = nnMatrixMake(1, 1);
  TEST_EQUAL(A.rows, 1);
  TEST_EQUAL(A.cols, 1);

  nnMatrixDel(&A);
}

// nnMatrixMake(3, 1) must report 3 rows and 1 column.
// The matrix is released with nnMatrixDel at the end, matching the cleanup
// done by the rest of the tests in this file.
TEST_CASE(nnMatrixMake_3x1) {
  nnMatrix A = nnMatrixMake(3, 1);
  TEST_EQUAL(A.rows, 3);
  TEST_EQUAL(A.cols, 1);

  nnMatrixDel(&A);
}

// nnMatrixInit copies the given values into the matrix's `values` storage in
// order; a 3x1 matrix initialised with {1, 2, 3} holds exactly those values.
// The matrix is released with nnMatrixDel at the end, matching the cleanup
// done by the rest of the tests in this file.
TEST_CASE(nnMatrixInit_3x1) {
  nnMatrix A = nnMatrixMake(3, 1);
  nnMatrixInit(&A, (R[]) { 1, 2, 3 });
  TEST_EQUAL(A.values[0], 1);
  TEST_EQUAL(A.values[1], 2);
  TEST_EQUAL(A.values[2], 3);

  nnMatrixDel(&A);
}

// Copy column 1 of a 3x2 matrix into column 0 of a 3x1 matrix and check that
// the destination ends up holding the source column's values {2, 4, 6}.
TEST_CASE(nnMatrixCopyCol_test) {
  nnMatrix src = nnMatrixMake(3, 2);
  nnMatrix dst = nnMatrixMake(3, 1);

  nnMatrixInit(&src, (R[]) {
    1, 2,
    3, 4,
    5, 6,
  });

  // Arguments: source, destination, source column, destination column.
  nnMatrixCopyCol(&src, &dst, 1, 0);

  const int expected_col[3] = { 2, 4, 6 };
  for (int row = 0; row < 3; ++row) {
    TEST_EQUAL(nnMatrixAt(&dst, row, 0), expected_col[row]);
  }

  nnMatrixDel(&src);
  nnMatrixDel(&dst);
}

// Multiply two 3x3 matrices (product = lhs * rhs) and compare every entry of
// the result against a hand-computed reference.
TEST_CASE(nnMatrixMul_square_3x3) {
  nnMatrix lhs = nnMatrixMake(3, 3);
  nnMatrix rhs = nnMatrixMake(3, 3);
  nnMatrix product = nnMatrixMake(3, 3);

  nnMatrixInit(&lhs, (const R[]){
    1, 2, 3,
    4, 5, 6,
    7, 8, 9,
  });
  nnMatrixInit(&rhs, (const R[]){
    2, 4, 3,
    6, 8, 5,
    1, 7, 9,
  });
  nnMatrixMul(&lhs, &rhs, &product);

  const R reference[3][3] = {
    { 17,  41,  40 },
    { 44,  98,  91 },
    { 71, 155, 142 },
  };

  // Visit the output in row-major order using a single flat index.
  const int total = product.rows * product.cols;
  for (int k = 0; k < total; ++k) {
    const int row = k / product.cols;
    const int col = k % product.cols;
    TEST_TRUE(double_eq(nnMatrixAt(&product, row, col), reference[row][col], EPS));
  }

  nnMatrixDel(&lhs);
  nnMatrixDel(&rhs);
  nnMatrixDel(&product);
}

// Multiply a 2x3 matrix by a 3x1 column; the 2x1 product is
// { 1*2 + 2*6 + 3*1, 4*2 + 5*6 + 6*1 } = { 17, 44 }.
TEST_CASE(nnMatrixMul_non_square_2x3_3x1) {
  nnMatrix lhs = nnMatrixMake(2, 3);
  nnMatrix column = nnMatrixMake(3, 1);
  nnMatrix product = nnMatrixMake(2, 1);

  nnMatrixInit(&lhs, (const R[]){
    1, 2, 3,
    4, 5, 6,
  });
  nnMatrixInit(&column, (const R[]){ 2, 6, 1 });
  nnMatrixMul(&lhs, &column, &product);

  const R reference[2][1] = { { 17 }, { 44 } };

  for (int row = 0; row < product.rows; ++row) {
    for (int col = 0; col < product.cols; ++col) {
      const R actual = nnMatrixAt(&product, row, col);
      TEST_TRUE(double_eq(actual, reference[row][col], EPS));
    }
  }

  nnMatrixDel(&lhs);
  nnMatrixDel(&column);
  nnMatrixDel(&product);
}

// nnMatrixMulAdd(A, B, scale, O) computes O = A + B * scale element-wise;
// with scale = 2 each output entry is a + 2b.
TEST_CASE(nnMatrixMulAdd_test) {
  const R scale = 2;
  nnMatrix base = nnMatrixMake(2, 3);
  nnMatrix delta = nnMatrixMake(2, 3);
  nnMatrix out = nnMatrixMake(2, 3);

  nnMatrixInit(&base, (const R[]){
    1, 2, 3,
    4, 5, 6,
  });
  nnMatrixInit(&delta, (const R[]){
    2, 3, 1,
    7, 4, 3
  });
  nnMatrixMulAdd(&base, &delta, scale, &out);  // out = base + delta * scale

  const R reference[2][3] = {
    {  5,  8,  5 },
    { 18, 13, 12 },
  };

  const int total = out.rows * out.cols;
  for (int k = 0; k < total; ++k) {
    const int row = k / out.cols;
    const int col = k % out.cols;
    TEST_TRUE(double_eq(nnMatrixAt(&out, row, col), reference[row][col], EPS));
  }

  nnMatrixDel(&base);
  nnMatrixDel(&delta);
  nnMatrixDel(&out);
}

// nnMatrixMulSub(A, B, scale, O) computes O = A - B * scale element-wise;
// with scale = 2 each output entry is a - 2b.
TEST_CASE(nnMatrixMulSub_test) {
  const R scale = 2;
  nnMatrix base = nnMatrixMake(2, 3);
  nnMatrix delta = nnMatrixMake(2, 3);
  nnMatrix out = nnMatrixMake(2, 3);

  nnMatrixInit(&base, (const R[]){
    1, 2, 3,
    4, 5, 6,
  });
  nnMatrixInit(&delta, (const R[]){
    2, 3, 1,
    7, 4, 3
  });
  nnMatrixMulSub(&base, &delta, scale, &out);  // out = base - delta * scale

  const R reference[2][3] = {
    {  -3, -4, 1 },
    { -10, -3, 0 },
  };

  for (int row = 0; row < out.rows; ++row) {
    const R* reference_row = reference[row];
    for (int col = 0; col < out.cols; ++col) {
      TEST_TRUE(double_eq(nnMatrixAt(&out, row, col), reference_row[col], EPS));
    }
  }

  nnMatrixDel(&base);
  nnMatrixDel(&delta);
  nnMatrixDel(&out);
}

// nnMatrixMulPairs multiplies two equally-shaped matrices entry by entry
// (element-wise product), e.g. out(1,0) = 4 * 7 = 28.
TEST_CASE(nnMatrixMulPairs_2x3) {
  nnMatrix lhs = nnMatrixMake(2, 3);
  nnMatrix rhs = nnMatrixMake(2, 3);
  nnMatrix out = nnMatrixMake(2, 3);

  nnMatrixInit(&lhs, (const R[]){
    1, 2, 3,
    4, 5, 6,
  });
  nnMatrixInit(&rhs, (const R[]){
    2, 3, 1,
    7, 4, 3
  });
  nnMatrixMulPairs(&lhs, &rhs, &out);

  const R reference[2][3] = {
    {  2,  6,  3 },
    { 28, 20, 18 },
  };

  const int total = out.rows * out.cols;
  for (int k = 0; k < total; ++k) {
    const int row = k / out.cols;
    const int col = k % out.cols;
    TEST_TRUE(double_eq(nnMatrixAt(&out, row, col), reference[row][col], EPS));
  }

  nnMatrixDel(&lhs);
  nnMatrixDel(&rhs);
  nnMatrixDel(&out);
}

TEST_CASE(nnMatrixAdd_square_2x2) {
  nnMatrix A = nnMatrixMake(2, 2);
  nnMatrix B = nnMatrixMake(2, 2);
  nnMatrix C = nnMatrixMake(2, 2);

  nnMatrixInit(&A, (R[]) {
    1, 2,
    3, 4,
  });
  nnMatrixInit(&B, (R[]) {
    2, 1,
    5, 3,
  });

  nnMatrixAdd(&A, &B, &C);

  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), 3, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), 3, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), 8, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), 7, EPS));

  nnMatrixDel(&A);
  nnMatrixDel(&B);
  nnMatrixDel(&C);
}

TEST_CASE(nnMatrixSub_square_2x2) {
  nnMatrix A = nnMatrixMake(2, 2);
  nnMatrix B = nnMatrixMake(2, 2);
  nnMatrix C = nnMatrixMake(2, 2);

  nnMatrixInit(&A, (R[]) {
    1, 2,
    3, 4,
  });
  nnMatrixInit(&B, (R[]) {
    2, 1,
    5, 3,
  });

  nnMatrixSub(&A, &B, &C);

  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), -1, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), +1, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), -2, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), +1, EPS));

  nnMatrixDel(&A);
  nnMatrixDel(&B);
  nnMatrixDel(&C);
}

TEST_CASE(nnMatrixAddRow_test) {
  nnMatrix A = nnMatrixMake(2, 3);
  nnMatrix B = nnMatrixMake(1, 3);
  nnMatrix C = nnMatrixMake(2, 3);

  nnMatrixInit(&A, (R[]) {
    1, 2, 3,
    4, 5, 6,
  });
  nnMatrixInit(&B, (R[]) {
    2, 1, 3,
  });

  nnMatrixAddRow(&A, &B, &C);

  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), 3, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), 3, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 2), 6, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), 6, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), 6, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 2), 9, EPS));

  nnMatrixDel(&A);
  nnMatrixDel(&B);
  nnMatrixDel(&C);
}

TEST_CASE(nnMatrixTranspose_square_2x2) {
  nnMatrix A = nnMatrixMake(2, 2);
  nnMatrix B = nnMatrixMake(2, 2);

  nnMatrixInit(&A, (R[]) {
    1, 2,
    3, 4
  });

  nnMatrixTranspose(&A, &B);
  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 1, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 3, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 0), 2, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 1), 4, EPS));

  nnMatrixDel(&A);
  nnMatrixDel(&B);
}

TEST_CASE(nnMatrixTranspose_non_square_2x1) {
  nnMatrix A = nnMatrixMake(2, 1);
  nnMatrix B = nnMatrixMake(1, 2);

  nnMatrixInit(&A, (R[]) {
    1,
    3,
  });

  nnMatrixTranspose(&A, &B);
  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 1, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 3, EPS));

  nnMatrixDel(&A);
  nnMatrixDel(&B);
}

TEST_CASE(nnMatrixGt_test) {
  nnMatrix A = nnMatrixMake(2, 3);
  nnMatrix B = nnMatrixMake(2, 3);

  nnMatrixInit(&A, (R[]) {
    -3, 2, 0,
    4, -1, 5
  });

  nnMatrixGt(&A, 0, &B);
  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 0, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 1, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 2), 0, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 0), 1, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 1), 0, EPS));
  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 2), 1, EPS));

  nnMatrixDel(&A);
  nnMatrixDel(&B);
}