#include "mempool.h"

#include "test.h"

#include <string.h>

// Capacity of the pool under test. The tests below expect exactly this many
// allocations to succeed before the pool reports exhaustion.
#define NUM_BLOCKS 10

// Defines `test_pool`, the pool type used by every test in this file, with
// int-typed blocks and NUM_BLOCKS capacity (expansion lives in mempool.h).
DEF_MEMPOOL(test_pool, int, NUM_BLOCKS);

// Counts the blocks that mempool_foreach visits in `pool`. The tests expect
// this to be 0 for a fresh or fully freed pool, i.e. the allocated blocks.
static int count(test_pool* pool) {
  int visited = 0;
  mempool_foreach(pool, n, { visited = visited + 1; });
  return visited;
}

// Adds up the int values stored in every block that mempool_foreach visits
// in `pool`.
static int sum(test_pool* pool) {
  int total = 0;
  mempool_foreach(pool, n, { total = total + *n; });
  return total;
}

// Smoke test: initializing a stack-allocated pool with mempool_make.
// There are no assertions; the test only exercises the call.
TEST_CASE(mempool_create) {
  test_pool fresh;
  mempool_make(&fresh);
}

// Allocate every one of the NUM_BLOCKS blocks; each allocation must return a
// non-null block.
TEST_CASE(mempool_allocate_until_full) {
  test_pool pool;
  mempool_make(&pool);

  for (int remaining = NUM_BLOCKS; remaining > 0; --remaining) {
    const int* block = mempool_alloc(&pool);
    TEST_TRUE(block != 0);
  }
}

// Fill the pool, then free every block. Each freed handle must be reset to 0
// by mempool_free, and afterwards traversal must find no blocks.
TEST_CASE(mempool_fill_then_free) {
  test_pool pool;
  mempool_make(&pool);

  int* blocks[NUM_BLOCKS] = {0};

  // Phase 1: allocate every block, remembering each handle.
  int i = 0;
  while (i < NUM_BLOCKS) {
    blocks[i] = mempool_alloc(&pool);
    TEST_TRUE(blocks[i] != 0);
    ++i;
  }

  // Phase 2: release the blocks in allocation order.
  for (i = 0; i < NUM_BLOCKS; ++i) {
    mempool_free(&pool, &blocks[i]);
    TEST_EQUAL(blocks[i], 0); // mempool_free clears the caller's pointer.
  }

  TEST_EQUAL(count(&pool), 0);
}

// Request more blocks than the pool holds. Once the pool is exhausted,
// further allocations must fail gracefully by returning 0.
TEST_CASE(mempool_allocate_beyond_max_size) {
  test_pool pool;
  mempool_make(&pool);

  // Exhaust the pool: every one of these requests must succeed.
  for (int filled = 0; filled < NUM_BLOCKS; ++filled) {
    TEST_TRUE(mempool_alloc(&pool) != 0);
  }

  // The pool is full now, so each extra request must yield 0.
  for (int extra = 0; extra < NUM_BLOCKS; ++extra) {
    TEST_EQUAL(mempool_alloc(&pool), 0);
  }
}

// Free blocks should always remain zeroed out.
// This checks the invariant right after the pool is made: the storage of
// every block must be byte-for-byte equal to an int zero.
TEST_CASE(mempool_zero_free_blocks_after_creation) {
  test_pool pool;
  mempool_make(&pool);

  const int zero = 0;
  for (int i = 0; i < NUM_BLOCKS; ++i) {
    // Reads the pool's block storage directly. This assumes `pool.blocks`
    // holds NUM_BLOCKS contiguous ints (layout comes from DEF_MEMPOOL in
    // mempool.h) -- TODO confirm.
    const int* block = (const int*)(pool.blocks) + i;
    // Bytewise comparison; the size is tied to the object being compared
    // rather than repeating the type name.
    TEST_EQUAL(memcmp(block, &zero, sizeof zero), 0);
  }
}

// Free blocks should always remain zeroed out.
// Here the invariant is checked after a dirtied block is handed back.
TEST_CASE(mempool_zero_free_block_after_free) {
  test_pool pool;
  mempool_make(&pool);

  int* val = mempool_alloc(&pool);
  TEST_TRUE(val != 0);

  // Keep a second handle to the block, because mempool_free resets `val`
  // to 0. The test then inspects the block's contents through this handle.
  int* const old_val = val;
  *old_val = 177; // Arbitrary non-zero marker.

  mempool_free(&pool, &val);
  TEST_EQUAL(*old_val, 0);
}

// Traversing a pool that was just made must visit no blocks at all.
TEST_CASE(mempool_traverse_empty) {
  test_pool empty_pool_under_test;
  mempool_make(&empty_pool_under_test);
  test_pool* pool = &empty_pool_under_test;

  TEST_EQUAL(count(pool), 0);
}

// Half-fill the pool with the values 1..N, then check via their sum
// (1 + 2 + ... + N == N * (N + 1) / 2) that traversal sees those values.
TEST_CASE(mempool_traverse_partially_full) {
  const int N = NUM_BLOCKS / 2;

  test_pool pool;
  mempool_make(&pool);

  for (int value = 1; value <= N; ++value) {
    int* val = mempool_alloc(&pool);
    TEST_TRUE(val != 0);
    *val = value;
  }

  TEST_EQUAL(sum(&pool), N * (N + 1) / 2);
}

// Fill the whole pool with 1..NUM_BLOCKS. Traversal must see every value,
// so their sum equals NUM_BLOCKS * (NUM_BLOCKS + 1) / 2.
TEST_CASE(mempool_traverse_full) {
  test_pool pool;
  mempool_make(&pool);

  for (int value = 1; value <= NUM_BLOCKS; ++value) {
    int* val = mempool_alloc(&pool);
    TEST_TRUE(val != 0);
    *val = value;
  }

  TEST_EQUAL(sum(&pool), NUM_BLOCKS * (NUM_BLOCKS + 1) / 2);
}

// Check the block <-> index mapping in both directions on a filled pool:
//  - mempool_get_block_index() of the i-th allocation must be i, and
//  - mempool_get_block() at index i must return the block tagged with i.
TEST_CASE(mempool_get_block) {
  test_pool pool;
  mempool_make(&pool);

  // Block -> index: allocate, tag each block with its order, look it up.
  for (int i = 0; i < NUM_BLOCKS; i++) {
    int* block = mempool_alloc(&pool);
    TEST_TRUE(block != 0);
    *block = i;
    TEST_EQUAL(mempool_get_block_index(&pool, block), (size_t)i);
  }

  // Index -> block: each index must lead back to the matching tag.
  for (int i = 0; i < NUM_BLOCKS; i++) {
    TEST_EQUAL(*mempool_get_block(&pool, i), i);
  }
}

// The test cases above are declared via TEST_CASE from test.h. Running them
// is presumably handled by that header's machinery rather than by main
// (not visible here -- TODO confirm), so main only reports success.
// Uses the prototype form `(void)`; an empty parameter list is an
// obsolescent non-prototype declarator in C.
int main(void) { return 0; }