#include "mem.h"

#include <cassert.h>

#include <stdlib.h>
#include <string.h>

// Initializes 'mem' as an allocator of 'num_blocks' blocks of
// 'block_size_bytes' bytes each.
//
// If 'chunks' and 'blocks' are both null, the chunk and block arrays are
// allocated on the heap and owned by 'mem' (mem_del_ frees them). Otherwise
// both must be caller-provided: 'chunks' with at least 'num_blocks' entries and
// 'blocks' with at least 'num_blocks * block_size_bytes' bytes. They are zeroed
// here but never freed by the allocator.
//
// Returns false if a heap allocation fails. In that case nothing is leaked and
// 'mem->chunks'/'mem->blocks' are left null, so mem_del_ remains safe to call.
bool mem_make_(
    Memory* mem, Chunk* chunks, void* blocks, size_t num_blocks,
    size_t block_size_bytes) {
  assert(mem);
  assert((chunks && blocks) || (!chunks && !blocks));
  assert(num_blocks >= 1);
  // Pointers are converted to chunk indices by dividing by the block size.
  assert(block_size_bytes >= 1);

  mem->block_size_bytes = block_size_bytes;
  mem->num_blocks       = num_blocks;
  mem->next_free_chunk  = 0;
  mem->trap             = true;

  // Allocate chunks and blocks if necessary and zero them out.
  if (!chunks) {
    chunks       = calloc(num_blocks, sizeof(Chunk));
    blocks       = calloc(num_blocks, block_size_bytes);
    mem->dynamic = true;
    if (!chunks || !blocks) {
      // Release whichever allocation did succeed (free(0) is a no-op) so a
      // partial failure does not leak.
      free(chunks);
      free(blocks);
      mem->chunks = 0;
      mem->blocks = 0;
      return false;
    }
  } else {
    memset(blocks, 0, num_blocks * block_size_bytes);
    memset(chunks, 0, num_blocks * sizeof(Chunk));
    mem->dynamic = false;
  }
  mem->chunks = chunks;
  mem->blocks = blocks;

  // Initialize the head as one large free chunk.
  Chunk* head      = &mem->chunks[0];
  head->num_blocks = num_blocks;

  return true;
}

// Releases the chunk and block arrays if mem_make_ allocated them on the heap.
// Caller-provided arrays (mem->dynamic == false) are left untouched. The
// pointers are reset to 0 afterwards, so a repeated call frees nothing twice.
void mem_del_(Memory* mem) {
  assert(mem);
  if (mem->dynamic) {
    // free(0) is a no-op, so no null checks are needed.
    free(mem->chunks);
    free(mem->blocks);
    mem->chunks = 0;
    mem->blocks = 0;
  }
}

// Resets 'mem' to its freshly-made state. All block bytes and chunk headers are
// zeroed, which marks every chunk as free. The whole capacity then becomes a
// single free chunk starting at index 0, and the search hint returns to 0.
void mem_clear_(Memory* mem) {
  assert(mem);

  const size_t total_blocks = mem->num_blocks;

  memset(mem->blocks, 0, total_blocks * mem->block_size_bytes);
  memset(mem->chunks, 0, total_blocks * sizeof(Chunk));
  mem->next_free_chunk = 0;

  // One free chunk spanning every block.
  mem->chunks[0].num_blocks = total_blocks;
}

// Allocates a chunk of 'num_blocks' contiguous blocks and returns a pointer to
// its first block.
//
// The search is next-fit. It starts at the 'next_free_chunk' hint and follows
// the 'next' links; the last chunk's 'next' is 0, the head, so the walk wraps
// around until it returns to its starting chunk. A free chunk larger than the
// request is split, and the remainder becomes a new free chunk right after it.
//
// If no free chunk is large enough, FAIL is called when traps are enabled.
// Otherwise (or if FAIL returns) the function returns 0.
void* mem_alloc_(Memory* mem, size_t num_blocks) {
  assert(mem);
  assert(num_blocks >= 1);
  assert(mem->next_free_chunk < mem->num_blocks);

  // The walk below only terminates if 'start' is the head of a live chunk.
  // If it is not, the walk cycles 0 -> ... -> 0 forever whenever no chunk fits.
  // Headers of chunks absorbed by a merge in mem_free_ are zeroed
  // (num_blocks == 0), so treat such a hint as stale and start from chunk 0,
  // which is always the head of the list.
  size_t start = mem->next_free_chunk;
  if (mem->chunks[start].num_blocks == 0) {
    start = 0;
  }

  // Search for the first free chunk that can accommodate num_blocks.
  size_t chunk_idx = start;
  bool   found     = false;
  do {
    Chunk* chunk = &mem->chunks[chunk_idx];
    if (!chunk->used) {
      if (chunk->num_blocks > num_blocks) {
        // Carve out a smaller chunk when the found chunk is larger than
        // requested.
        // [prev] <--> [chunk] <--> [new next] <--> [next]
        const size_t new_next_idx = chunk_idx + num_blocks;
        Chunk*       new_next     = &mem->chunks[new_next_idx];
        if (chunk->next) {
          mem->chunks[chunk->next].prev = new_next_idx;
        }
        new_next->prev = chunk_idx;
        new_next->next = chunk->next;
        chunk->next    = new_next_idx;

        new_next->num_blocks = chunk->num_blocks - num_blocks;
        chunk->num_blocks    = num_blocks;

        chunk->used = true;
        found       = true;
        break;
      } else if (chunk->num_blocks == num_blocks) {
        // Exact fit: no split needed.
        chunk->used = true;
        found       = true;
        break;
      }
    }
    // Last chunk points back to 0, which is always the start of some chunk, so
    // the walk wraps around to 'start'.
    chunk_idx = chunk->next;
  } while (chunk_idx != start);

  if (!found) {
    if (mem->trap) {
      FAIL("Memory allocation failed, increase the allocator's capacity or "
           "avoid fragmentation.");
    }
    return 0; // Large-enough free chunk not found.
  }

  // Resume the next search right after the chunk just handed out.
  mem->next_free_chunk = mem->chunks[chunk_idx].next;
  return &mem->blocks[chunk_idx * mem->block_size_bytes];
}

// Frees the chunk whose first block '*chunk_ptr' points to (a pointer returned
// by mem_alloc_), then sets '*chunk_ptr' to 0.
//
// The freed bytes are zeroed. The chunk is merged with any free neighbours so
// that no two free chunks are ever left adjacent.
void mem_free_(Memory* mem, void** chunk_ptr) {
  assert(mem);
  assert(chunk_ptr);
  // This function nulls the caller's pointer, so a null here usually means the
  // same pointer variable is being freed twice.
  assert(*chunk_ptr);

  // 'blocks' is a byte array: convert the byte offset into a chunk index.
  const size_t byte_offset = (size_t)((uint8_t*)*chunk_ptr - mem->blocks);
  assert(byte_offset % mem->block_size_bytes == 0);
  const size_t chunk_idx = byte_offset / mem->block_size_bytes;
  assert(chunk_idx < mem->num_blocks);
  Chunk* chunk = &mem->chunks[chunk_idx];

  // Disallow double-frees.
  assert(chunk->used);

  // Zero out the chunk so that we don't get stray values the next time it is
  // allocated. The chunk starts at byte chunk_idx * block_size_bytes (not at
  // byte chunk_idx), matching the address mem_alloc_ returns.
  memset(
      &mem->blocks[chunk_idx * mem->block_size_bytes], 0,
      chunk->num_blocks * mem->block_size_bytes);

  // Free the chunk. If it is contiguous with other free chunks, then merge.
  // We only need to look at the chunk's immediate neighbours because no two
  // free chunks are left contiguous after merging.
  //
  // Every header absorbed by a merge is zeroed. If the search hint
  // 'next_free_chunk' named such a header, it is moved to the surviving chunk
  // so that mem_alloc_ always starts its search at the head of a live chunk.
  chunk->used = false;
  if (chunk->next) {
    const size_t next_idx = chunk->next;
    Chunk*       next     = &mem->chunks[next_idx];
    if (!next->used) {
      // Pre:  [chunk] <--> [next] <--> [next next]
      // Post: [  chunk + next   ] <--> [next next]
      chunk->num_blocks += next->num_blocks;
      chunk->next = next->next;
      if (next->next) {
        mem->chunks[next->next].prev = chunk_idx;
      }
      next->prev = next->next = next->num_blocks = 0;
      if (mem->next_free_chunk == next_idx) {
        mem->next_free_chunk = chunk_idx;
      }
    }
  }
  if (chunk->prev) {
    const size_t prev_idx = chunk->prev;
    Chunk*       prev     = &mem->chunks[prev_idx];
    if (!prev->used) {
      // Pre:  [prev] <--> [chunk] <--> [next]
      // Post: [  prev + chunk   ] <--> [next]
      prev->num_blocks += chunk->num_blocks;
      prev->next = chunk->next;
      if (chunk->next) {
        mem->chunks[chunk->next].prev = prev_idx;
      }
      chunk->prev = chunk->next = chunk->num_blocks = 0;
      if (mem->next_free_chunk == chunk_idx) {
        mem->next_free_chunk = prev_idx;
      }
    }
  }

  *chunk_ptr = 0;
}

// The handle is the chunk's index. We don't call it an index in the public API
// because from the user's perspective, two chunks allocated back-to-back need
// not be +1 away (the offset depends on how large the first chunk is).
void* mem_get_chunk_(const Memory* mem, size_t chunk_handle) {
  assert(mem);
  assert(chunk_handle < mem->num_blocks);
  assert(mem->chunks[chunk_handle].used);
  return &mem->blocks[chunk_handle * mem->block_size_bytes];
}

// The given chunk pointer is a pointer to the blocks array.
size_t mem_get_chunk_handle_(const Memory* mem, const void* chunk) {
  assert(mem);
  const size_t block_byte_index = (const uint8_t*)chunk - mem->blocks;
  assert(block_byte_index % mem->block_size_bytes == 0);
  return block_byte_index / mem->block_size_bytes;
}

// Returns the total size in bytes of the allocator's block storage (block count
// times block size), regardless of how much of it is currently allocated.
size_t mem_capacity_(const Memory* mem) {
  assert(mem);
  const size_t bytes_per_block = mem->block_size_bytes;
  return bytes_per_block * mem->num_blocks;
}

// Controls what mem_alloc_ does when no free chunk is large enough. With traps
// enabled it calls FAIL; with traps disabled it only returns 0.
void mem_enable_traps_(Memory* mem, bool enable) {
  assert(mem);
  const bool trap_on_failure = enable;
  mem->trap                  = trap_on_failure;
}