#include <cstddef>
#include <cstdint>

#include "akida/program_memory_info.h"
#include "engine/dma.h"
#include "infra/int_ops.h"

#include "akida/program_info.h"
#include "dma_desc_format.h"

namespace akida {

// Returns the number of bytes needed to hold one input for the given program.
//
// The element count is the product of the three input dimensions reported by
// program_info.input_dims(). Dense inputs use 1 byte per element; sparse
// inputs use dma::kSparseEventByteSize bytes per element.
size_t input_memory_required(const ProgramInfo& program_info) {
  const auto* dimensions = program_info.input_dims();
  // Widen the first operand before multiplying so the whole product is
  // evaluated in size_t. The dimension values may be stored in a type narrower
  // than size_t, in which case multiplying first and widening afterwards could
  // wrap around for large shapes.
  const size_t max_num_elements =
      static_cast<size_t>(dimensions[0]) * dimensions[1] * dimensions[2];
  if (program_info.input_is_dense()) {
    // Dense: one byte per element, so the size is just the element count.
    return max_num_elements;
  }
  // Sparse: every element is encoded as one event.
  return max_num_elements * dma::kSparseEventByteSize;
}

// Returns the number of bytes needed to hold one output for the given program.
//
// The result is (bytes per item * number of items) plus the DMA output header
// (dma::kOutputHeaderByteSize), rounded up with align_up to dma::kAlignment.
// The number of items is the product of the three output dimensions.
size_t output_memory_required(const ProgramInfo& program_info) {
  const auto out_dims = program_info.output_dims();
  const auto item_count = out_dims[0] * out_dims[1] * out_dims[2];

  // Pick the per-item footprint from the output format.
  uint32_t bytes_per_item;
  if (!program_info.output_is_dense()) {
    // Every non-dense format is sparse: kSparseEventByteSize per item.
    bytes_per_item = dma::kSparseEventByteSize;
  } else if (program_info.activation_enabled()) {
    // Dense activations take a single byte per item.
    bytes_per_item = 1;
  } else {
    // Dense potentials take 32 bits per item.
    bytes_per_item = sizeof(uint32_t);
  }

  // NOTE(review): the payload size is computed in the operands' own integer
  // type before being returned as size_t -- confirm this cannot wrap for the
  // largest supported output shape.
  const auto payload_bytes = bytes_per_item * item_count;

  // The payload is preceded by a DMA header and the total must be aligned.
  return align_up(payload_bytes + dma::kOutputHeaderByteSize, dma::kAlignment);
}

// Returns the size in bytes of the input DMA descriptor for the given program.
//
// Dense inputs go through the HRC DMA, sparse inputs through the AE (event)
// DMA; the two use different descriptor sizes.
size_t input_descriptor_memory_required(const ProgramInfo& program_info) {
  if (program_info.input_is_dense()) {
    return dma::hrc::DESC_BYTE_SIZE;
  }
  return dma::event::DESC_BYTE_SIZE;
}

// Returns the number of bytes needed to store all program (config DMA)
// descriptors: the regular ones plus the extra ones, each taking
// dma::config::DESC_BYTE_SIZE bytes.
size_t program_descriptors_memory_required(const ProgramInfo& program_info) {
  const auto descriptor_count =
      program_info.number_of_program_descriptors_required() +
      program_info.number_of_extra_program_descriptors_required();
  return descriptor_count * dma::config::DESC_BYTE_SIZE;
}

// Returns the number of bytes needed for the program data, as reported by
// ProgramInfo::program_data_required_memory().
size_t program_data_memory_required(const ProgramInfo& program_info) {
  const size_t program_data_bytes = program_info.program_data_required_memory();
  return program_data_bytes;
}

// Returns the number of extra bytes the program needs on top of its
// descriptors and data.
//
// This always includes the FNP2 memory reported by the program. Programs with
// more than one pass additionally need room for:
//   - one word (dma::w32) for the dummy descriptor output,
//   - one extra input descriptor generated by the HW (HRC descriptor for dense
//     inputs, event descriptor for sparse inputs),
//   - one word (dma::w32) for the output of that extra descriptor.
size_t extra_program_memory_required(const ProgramInfo& program_info) {
  const size_t fnp2_bytes = program_info.fnp2_required_memory();
  if (program_info.number_of_passes() <= 1) {
    // Single pass: nothing beyond the FNP2 memory is required.
    return fnp2_bytes;
  }

  const size_t extra_descriptor_bytes = program_info.input_is_dense()
                                            ? dma::hrc::DESC_BYTE_SIZE
                                            : dma::event::DESC_BYTE_SIZE;
  // Two output words: one for the dummy descriptor, one for the extra
  // HW-generated input descriptor.
  const size_t output_words_bytes = sizeof(dma::w32) + sizeof(dma::w32);

  return fnp2_bytes + extra_descriptor_bytes + output_words_bytes;
}

}  // namespace akida
