#include <cstddef>
#include <cstdint>

#include "akida/program_memory_info.h"
#include "engine/dma.h"
#include "infra/int_ops.h"

#include "dma_desc_format.h"
#include "program_play.h"

namespace akida {

// Returns the number of bytes needed to hold one input of `program`, or 0 when
// `program` is nullptr.
size_t input_memory_required(const uint8_t* program) {
  if (program == nullptr) {
    return 0;
  }

  const auto* dimensions = program::input_dims(program);
  // Widen to size_t before multiplying: otherwise the product of the three
  // dimensions is computed in the dimension element type and could wrap before
  // being stored in a size_t.
  const size_t max_num_elements =
      static_cast<size_t>(dimensions[0]) * dimensions[1] * dimensions[2];
  // With dense inputs each element uses 1 byte, so the size is just the number
  // of elements. With sparse inputs each element uses dma::kSparseEventByteSize
  // bytes.
  return program::input_is_dense(program)
             ? max_num_elements
             : max_num_elements * dma::kSparseEventByteSize;
}

size_t output_memory_required(const uint8_t* program) {
  if (program == nullptr) {
    return 0;
  }

  const auto dimensions = program::output_dims(program);
  const auto nb_items = dimensions[0] * dimensions[1] * dimensions[2];

  const auto format = program::output_format(program);

  uint32_t item_size;
  if (format == dma::OutputFormat::DenseActivations) {
    // dense activations use 1 byte per item
    item_size = 1;
  } else if (format == dma::OutputFormat::DensePotentials) {
    // dense potentials use 32 bits per item
    item_size = sizeof(uint32_t);
  } else {
    // other format are sparse, they use kSparseEventByteSize per item
    item_size = dma::kSparseEventByteSize;
  }

  // Output requires a DMA header, and must be aligned
  return align_up(item_size * nb_items + dma::kOutputHeaderByteSize,
                  dma::kAlignment);
}

// Returns the number of bytes reserved for the input DMA descriptor of
// `program`, or 0 when `program` is nullptr.
size_t input_descriptor_memory_required(const uint8_t* program) {
  // Same nullptr contract as input_memory_required / output_memory_required:
  // without a program there is nothing to allocate, and the program:: getters
  // must not be handed a null pointer.
  if (program == nullptr) {
    return 0;
  }
  // Dense inputs use the HRC DMA, sparse inputs use the AE (event) DMA; the
  // size returned is that DMA's DESC_BYTE_SIZE.
  return program::input_is_dense(program) ? dma::hrc::DESC_BYTE_SIZE
                                          : dma::event::DESC_BYTE_SIZE;
}

// Returns the number of bytes needed for the program's DMA descriptors: the
// standard config DMA descriptors, plus the multi-pass descriptors when the
// program is multi-pass.
size_t program_descriptors_memory_required(const uint8_t* program) {
  // FIXME: The "standard" config DMA descriptors are currently always
  // allocated, even when multi-pass does not use them. Once multi-pass
  // descriptors use the base descriptor address of the config DMA, this should
  // be removed.
  const size_t standard_desc_bytes =
      dma::kMinNbDescriptors * dma::config::DESC_BYTE_SIZE;

  const size_t multi_pass_desc_bytes =
      program::is_multi_pass(program)
          ? program::multi_pass_descriptors_required_memory(program)
          : 0;

  return standard_desc_bytes + multi_pass_desc_bytes;
}

// Returns the number of bytes needed for the program data. This forwards to
// program::program_data_required_memory; no header or padding is added here.
size_t program_data_memory_required(const uint8_t* program) {
  const size_t data_bytes = program::program_data_required_memory(program);
  return data_bytes;
}

size_t extra_program_memory_required(const uint8_t* program) {
  size_t result = program::fnp2_tracks_byte_size(program);
  if (program::is_multi_pass(program)) {
    // extra data are only required for multi pass: 1 word for dummy descriptor
    // output, 1 extra input descriptor generated by HW and 1 word for this
    // descriptor output There is an extra word used for dummy descriptors
    result += sizeof(dma::w32) +
              (program::input_is_dense(program) ? dma::hrc::DESC_BYTE_SIZE
                                                : dma::event::DESC_BYTE_SIZE) +
              sizeof(dma::w32);
  }
  return result;
}

}  // namespace akida
