From 653e98e029a0d0f110b0ac599e50406060bb0f87 Mon Sep 17 00:00:00 2001
From: 3gg <3gg@shellblade.net>
Date: Sat, 16 Dec 2023 10:21:16 -0800
Subject: Decouple activations from linear layer.

---
 src/bin/mnist/src/main.c | 195 ++++++++++++++++++++++++++---------------------
 1 file changed, 108 insertions(+), 87 deletions(-)

(limited to 'src/bin')

diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c
index 9aa3ce5..53e0197 100644
--- a/src/bin/mnist/src/main.c
+++ b/src/bin/mnist/src/main.c
@@ -29,32 +29,35 @@ static const double LABEL_UPPER_BOUND = 0.99;
 // Epsilon used to compare R values.
 static const double EPS = 1e-10;
 
-#define min(a,b) ((a) < (b) ? (a) : (b))
+#define min(a, b) ((a) < (b) ? (a) : (b))
 
 typedef struct ImageSet {
-  nnMatrix images;  // Images flattened into row vectors of the matrix.
-  nnMatrix labels;  // One-hot-encoded labels.
-  int count;        // Number of images and labels.
-  int rows;         // Rows in an image.
-  int cols;         // Columns in an image.
+  nnMatrix images; // Images flattened into row vectors of the matrix.
+  nnMatrix labels; // One-hot-encoded labels.
+  int      count;  // Number of images and labels.
+  int      rows;   // Rows in an image.
+  int      cols;   // Columns in an image.
 } ImageSet;
 
 static void usage(const char* argv0) {
-  fprintf(stderr, "Usage: %s <path to mnist files directory> [num images]\n", argv0);
+  fprintf(
+      stderr, "Usage: %s <path to mnist files directory> [num images]\n",
+      argv0);
   fprintf(stderr, "\n");
-  fprintf(stderr, "  Use -1 for [num images] to use all the images in the data set\n");
+  fprintf(
+      stderr,
+      "  Use -1 for [num images] to use all the images in the data set\n");
 }
 
-static bool R_eq(R a, R b) {
-  return fabs(a-b) <= EPS;
-}
+static bool R_eq(R a, R b) { return fabs(a - b) <= EPS; }
 
-static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) {
+static void PrintImage(
+    const nnMatrix* images, int rows, int cols, int image_index) {
   assert(images);
   assert((0 <= image_index) && (image_index < images->rows));
 
   // Top line.
-  for (int j = 0; j < cols/2; ++j) {
+  for (int j = 0; j < cols / 2; ++j) {
     printf(" -");
   }
   printf("\n");
@@ -68,8 +71,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind
         printf("#");
       } else if (*value > 0.5) {
         printf("*");
-      }
-      else if (*value > PIXEL_LOWER_BOUND) {
+      } else if (*value > PIXEL_LOWER_BOUND) {
         printf(":");
       } else if (*value == 0.0) {
         // Values should not be exactly 0, otherwise they cancel out weights
@@ -84,7 +86,7 @@ static void PrintImage(const nnMatrix* images, int rows, int cols, int image_ind
   }
 
   // Bottom line.
-  for (int j = 0; j < cols/2; ++j) {
+  for (int j = 0; j < cols / 2; ++j) {
     printf(" -");
   }
   printf("\n");
@@ -96,7 +98,7 @@ static void PrintLabel(const nnMatrix* labels, int label_index) {
 
   // Compute the label from the one-hot encoding.
   const R* value = nnMatrixRow(labels, label_index);
-  int label = -1;
+  int      label = -1;
   for (int i = 0; i < 10; ++i) {
     if (R_eq(*value++, LABEL_UPPER_BOUND)) {
       label = i;
@@ -113,13 +115,12 @@ static void PrintLabel(const nnMatrix* labels, int label_index) {
   printf(")\n");
 }
 
-static R lerp(R a, R b, R t) {
-  return a + t*(b-a);
-}
+static R lerp(R a, R b, R t) { return a + t * (b - a); }
 
 /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0].
 static R FormatPixel(uint8_t pixel) {
-  const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND;
+  const R value =
+      (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND;
   assert(value >= PIXEL_LOWER_BOUND);
   assert(value <= 1.0);
   return value;
@@ -152,7 +153,8 @@ static void ImageToMatrix(
   }
 }
 
-static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) {
+static bool ReadImages(
+    gzFile images_file, int max_num_images, ImageSet* image_set) {
   assert(images_file != Z_NULL);
   assert(image_set);
 
@@ -161,36 +163,41 @@ static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_s
   uint8_t* pixels = 0;
 
   int32_t magic, total_images, rows, cols;
-  if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) {
+  if ((gzread(images_file, (char*)&magic, sizeof(int32_t)) !=
+       sizeof(int32_t)) ||
+      (gzread(images_file, (char*)&total_images, sizeof(int32_t)) !=
+       sizeof(int32_t)) ||
+      (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) ||
+      (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t))) {
     fprintf(stderr, "Failed to read header\n");
     goto cleanup;
   }
 
-  magic = ReverseEndian32(magic);
+  magic        = ReverseEndian32(magic);
   total_images = ReverseEndian32(total_images);
-  rows = ReverseEndian32(rows);
-  cols = ReverseEndian32(cols);
+  rows         = ReverseEndian32(rows);
+  cols         = ReverseEndian32(cols);
 
   if (magic != IMAGE_FILE_MAGIC) {
-    fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n",
-      magic, IMAGE_FILE_MAGIC);
+    fprintf(
+        stderr, "Magic number mismatch. Got %x, expected: %x\n", magic,
+        IMAGE_FILE_MAGIC);
     goto cleanup;
   }
 
-  printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n",
-    magic, total_images, rows, cols);
+  printf(
+      "Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", magic,
+      total_images, rows, cols);
 
-  total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images;
+  total_images =
+      max_num_images >= 0 ? min(total_images, max_num_images) : total_images;
 
   // Images are flattened into single row vectors.
   const int num_pixels = rows * cols;
-  image_set->images = nnMatrixMake(total_images, num_pixels);
-  image_set->count = total_images;
-  image_set->rows = rows;
-  image_set->cols = cols;
+  image_set->images    = nnMatrixMake(total_images, num_pixels);
+  image_set->count     = total_images;
+  image_set->rows      = rows;
+  image_set->cols      = cols;
 
   pixels = calloc(1, num_pixels);
   if (!pixels) {
@@ -219,30 +226,31 @@ cleanup:
   return success;
 }
 
-static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) {
+static void OneHotEncode(
+    const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) {
   assert(labels_bytes);
   assert(labels);
   assert(labels->rows == num_labels);
   assert(labels->cols == 10);
 
   static const R one_hot[10][10] = {
-    { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 },
-    { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 },
-    { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 },
-    { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 },
-    { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 },
-    { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 },
-    { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 },
+      {1, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+      {0, 1, 0, 0, 0, 0, 0, 0, 0, 0},
+      {0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
+      {0, 0, 0, 1, 0, 0, 0, 0, 0, 0},
+      {0, 0, 0, 0, 1, 0, 0, 0, 0, 0},
+      {0, 0, 0, 0, 0, 1, 0, 0, 0, 0},
+      {0, 0, 0, 0, 0, 0, 1, 0, 0, 0},
+      {0, 0, 0, 0, 0, 0, 0, 1, 0, 0},
+      {0, 0, 0, 0, 0, 0, 0, 0, 1, 0},
+      {0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
   };
 
   R* value = labels->values;
 
   for (int i = 0; i < num_labels; ++i) {
-    const uint8_t label = labels_bytes[i];
-    const R* one_hot_value = one_hot[label];
+    const uint8_t label         = labels_bytes[i];
+    const R*      one_hot_value = one_hot[label];
 
     for (int j = 0; j < 10; ++j) {
       *value++ = FormatLabel(*one_hot_value++);
@@ -255,13 +263,13 @@ static int OneHotDecode(const nnMatrix* label_matrix) {
   assert(label_matrix->cols == 10);
   assert(label_matrix->rows == 1);
 
-  R max_value = 0;
-  int pos_max = 0;
+  R   max_value = 0;
+  int pos_max   = 0;
   for (int i = 0; i < 10; ++i) {
     const R value = nnMatrixAt(label_matrix, 0, i);
     if (value > max_value) {
       max_value = value;
-      pos_max = i;
+      pos_max   = i;
     }
   }
   assert(pos_max >= 0);
@@ -269,7 +277,8 @@ static int OneHotDecode(const nnMatrix* label_matrix) {
   return pos_max;
 }
 
-static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) {
+static bool ReadLabels(
+    gzFile labels_file, int max_num_labels, ImageSet* image_set) {
   assert(labels_file != Z_NULL);
   assert(image_set != 0);
 
@@ -278,24 +287,28 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s
   uint8_t* labels = 0;
 
   int32_t magic, total_labels;
-  if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) ||
-       (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) {
+  if ((gzread(labels_file, (char*)&magic, sizeof(int32_t)) !=
+       sizeof(int32_t)) ||
+      (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) !=
+       sizeof(int32_t))) {
     fprintf(stderr, "Failed to read header\n");
     goto cleanup;
   }
 
-  magic = ReverseEndian32(magic);
+  magic        = ReverseEndian32(magic);
   total_labels = ReverseEndian32(total_labels);
 
   if (magic != LABEL_FILE_MAGIC) {
-    fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n",
-      magic, LABEL_FILE_MAGIC);
+    fprintf(
+        stderr, "Magic number mismatch. Got %x, expected: %x\n", magic,
+        LABEL_FILE_MAGIC);
     goto cleanup;
   }
 
   printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels);
 
-  total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels;
+  total_labels =
+      max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels;
 
   assert(image_set->count == total_labels);
 
@@ -308,7 +321,8 @@ static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_s
     goto cleanup;
   }
 
-  if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) {
+  if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) !=
+      total_labels) {
     fprintf(stderr, "Failed to read labels\n");
     goto cleanup;
   }
@@ -335,17 +349,17 @@ int main(int argc, const char** argv) {
 
   bool success = false;
 
-  gzFile train_images_file = Z_NULL;
-  gzFile train_labels_file = Z_NULL;
-  gzFile test_images_file  = Z_NULL;
-  gzFile test_labels_file  = Z_NULL;
-  ImageSet train_set = { 0 };
-  ImageSet test_set  = { 0 };
-  nnNeuralNetwork* net = 0;
-  nnQueryObject* query = 0;
+  gzFile           train_images_file = Z_NULL;
+  gzFile           train_labels_file = Z_NULL;
+  gzFile           test_images_file  = Z_NULL;
+  gzFile           test_labels_file  = Z_NULL;
+  ImageSet         train_set         = {0};
+  ImageSet         test_set          = {0};
+  nnNeuralNetwork* net               = 0;
+  nnQueryObject*   query             = 0;
 
   const char* mnist_files_dir = argv[1];
-  const int max_num_images = argc > 2 ? atoi(argv[2]) : -1;
+  const int   max_num_images  = argc > 2 ? atoi(argv[2]) : -1;
 
   char train_labels_path[PATH_MAX];
   char train_images_path[PATH_MAX];
@@ -353,12 +367,12 @@ int main(int argc, const char** argv) {
   char test_images_path[PATH_MAX];
   strlcpy(train_labels_path, mnist_files_dir, PATH_MAX);
   strlcpy(train_images_path, mnist_files_dir, PATH_MAX);
-  strlcpy(test_labels_path,  mnist_files_dir, PATH_MAX);
-  strlcpy(test_images_path,  mnist_files_dir, PATH_MAX);
+  strlcpy(test_labels_path, mnist_files_dir, PATH_MAX);
+  strlcpy(test_images_path, mnist_files_dir, PATH_MAX);
   strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX);
   strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX);
-  strlcat(test_labels_path,  "/t10k-labels-idx1-ubyte.gz",  PATH_MAX);
-  strlcat(test_images_path,  "/t10k-images-idx3-ubyte.gz",  PATH_MAX);
+  strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX);
+  strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX);
 
   train_images_file = gzopen(train_images_path, "r");
   if (train_images_file == Z_NULL) {
@@ -406,11 +420,18 @@ int main(int argc, const char** argv) {
   }
 
   // Network definition.
-  const int image_size_pixels = train_set.rows * train_set.cols;
-  const int num_layers = 2;
-  const int layer_sizes[3] = { image_size_pixels, 100, 10 };
-  const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid };
-  if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) {
+  const int     image_size_pixels = train_set.rows * train_set.cols;
+  const int     num_layers        = 4;
+  const int     hidden_size       = 100;
+  const nnLayer layers[4]         = {
+      {.type   = nnLinear,
+       .linear = {.input_size = image_size_pixels, .output_size = hidden_size}},
+      {.type = nnSigmoid},
+      {.type   = nnLinear,
+       .linear = {.input_size = hidden_size, .output_size = 10}},
+      {.type = nnSigmoid}
+  };
+  if (!(net = nnMakeNet(layers, num_layers, image_size_pixels))) {
     fprintf(stderr, "Failed to create neural network\n");
     goto cleanup;
   }
@@ -418,17 +439,17 @@ int main(int argc, const char** argv) {
   // Train.
   printf("Training with up to %d images from the data set\n\n", max_num_images);
   const nnTrainingParams training_params = {
-    .learning_rate = 0.1,
-    .max_iterations = TRAIN_ITERATIONS,
-    .seed = 0,
-    .weight_init = nnWeightInitNormal,
-    .debug = true,
+      .learning_rate  = 0.1,
+      .max_iterations = TRAIN_ITERATIONS,
+      .seed           = 0,
+      .weight_init    = nnWeightInitNormal,
+      .debug          = true,
   };
   nnTrain(net, &train_set.images, &train_set.labels, &training_params);
 
   // Test.
   int hits = 0;
-  query = nnMakeQueryObject(net, /*num_inputs=*/1);
+  query    = nnMakeQueryObject(net, /*num_inputs=*/1);
   for (int i = 0; i < test_set.count; ++i) {
     const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1);
     const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1);
@@ -444,7 +465,7 @@ int main(int argc, const char** argv) {
   }
   const R hit_ratio = (R)hits / (R)test_set.count;
   printf("Test images: %d\n", test_set.count);
-  printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100);
+  printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio * 100);
 
   success = true;
 
-- 
cgit v1.2.3