From 411f66a2540fa17c736116d865e0ceb0cfe5623b Mon Sep 17 00:00:00 2001
From: jeanne <jeanne@localhost.localdomain>
Date: Wed, 11 May 2022 09:54:38 -0700
Subject: Initial commit.

---
 src/lib/test/train_linear_perceptron_test.c | 62 +++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 src/lib/test/train_linear_perceptron_test.c

(limited to 'src/lib/test/train_linear_perceptron_test.c')

diff --git a/src/lib/test/train_linear_perceptron_test.c b/src/lib/test/train_linear_perceptron_test.c
new file mode 100644
index 0000000..2b1336d
--- /dev/null
+++ b/src/lib/test/train_linear_perceptron_test.c
@@ -0,0 +1,62 @@
+#include <neuralnet/train.h>
+
+#include <neuralnet/matrix.h>
+#include <neuralnet/neuralnet.h>
+#include "activation.h"
+#include "neuralnet_impl.h"
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+
+TEST_CASE(neuralnet_train_linear_perceptron_test) {
+  const int num_layers = 1;
+  const int layer_sizes[] = { 1, 1 };
+  const nnActivation layer_activations[] = { nnIdentity };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+
+  // Train.
+
+  // Try to learn the Y=X line.
+  #define N 2
+  const R inputs[N]  = { 0., 1. };
+  const R targets[N] = { 0., 1. };
+
+  nnMatrix inputs_matrix  = nnMatrixMake(N, 1);
+  nnMatrix targets_matrix = nnMatrixMake(N, 1);
+  nnMatrixInit(&inputs_matrix, inputs);
+  nnMatrixInit(&targets_matrix, targets);
+
+  nnTrainingParams params = {
+    .learning_rate = 0.7,
+    .max_iterations = 10,
+    .seed = 0,
+    .weight_init = nnWeightInit01,
+    .debug = false,
+  };
+
+  nnTrain(net, &inputs_matrix, &targets_matrix, &params);
+
+  const R weight = nnMatrixAt(&net->weights[0], 0, 0);
+  const R expected_weight = 1.0;
+  printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
+  TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
+
+  // Test.
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
+
+  const R test_input[] = { 2.3 };
+  R test_output[1];
+  nnQueryArray(net, query, test_input, test_output);
+
+  const R expected_output = test_input[0];
+  printf("Output: %f, Expected: %f\n", test_output[0], expected_output);
+  TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS));
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
-- 
cgit v1.2.3