"""Utilities for the inventory environments, such as an order pipeline with fixed or stochastic lead times"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/21_envs_inventory/00_inventory_utils.ipynb.

# %% auto 0
__all__ = ['OrderPipeline']

# %% ../../../nbs/21_envs_inventory/00_inventory_utils.ipynb 3
import logging


from abc import ABC, abstractmethod
from typing import Union, Tuple, List, Literal

from ..base import BaseEnvironment
from ...utils import Parameter, MDPInfo
from ...dataloaders.base import BaseDataLoader
from ...loss_functions import pinball_loss
from ...utils import set_param, Parameter


import gymnasium as gym

import numpy as np
import time

# %% ../../../nbs/21_envs_inventory/00_inventory_utils.ipynb 4
class OrderPipeline():
    """
    Class to handle the order pipeline in the inventory environments. It is used to keep track of the orders
    that are placed. It can account for fixed and variable lead times.

    Internally two arrays of shape ``(max(max_lead_time), num_units)`` are kept:

    - ``pipeline``: order quantities still in transit; the most recently placed orders are in the last row.
    - ``lead_time_realized``: for each entry of ``pipeline``, the number of periods left before it arrives.
      Entries that have reached 0 are delivered (and removed) at the start of the next ``step`` call.
    """

    def __init__(self,
        num_units: int,  # number of units (SKUs)
        lead_time_mean: Parameter | np.ndarray | List | int | float,  # mean lead time
        lead_time_stochasticity: Literal["fixed", "gamma", "normal_absolute", "normal_relative"] = "fixed", # "fixed", "gamma", "normal_absolute", "normal_relative"
        lead_time_variance: Parameter | np.ndarray | List | int | float | None = None,  # spread of the lead time (used as the normal scale, see `draw_lead_times`)
        max_lead_time: Parameter | np.ndarray | List | int | float | None = None,  # maximum lead time (required for stochastic lead times)
        min_lead_time: Parameter | np.ndarray | List | int | float | None = 1,  # minimum lead time (must be at least 1)
        ) -> None:

        self.set_param('num_units', num_units, shape=(1,), new=True)
        n_units = self.num_units[0]
        self.set_param('lead_time_mean', lead_time_mean, shape=(n_units,), new=True)
        self.set_param('lead_time_variance', lead_time_variance, shape=(n_units,), new=True)
        self.lead_time_stochasticity = lead_time_stochasticity
        self.check_stochasticity(max_lead_time)

        # check_stochasticity guarantees max_lead_time is given for stochastic lead times; for fixed
        # lead times it defaults to the (deterministic) lead time itself.
        if max_lead_time is None:
            max_lead_time = lead_time_mean
        self.set_param('max_lead_time', max_lead_time, shape=(n_units,), new=True)
        self.set_param('min_lead_time', min_lead_time, shape=(n_units,), new=True)

        if self.max_lead_time is None:
            self.max_lead_time = self.lead_time_mean
        if self.min_lead_time is None:
            self.min_lead_time = 1

        self.check_max_min_mean_lt()

        # Use the normalized unit count (as `reset` does) rather than the raw `num_units` argument,
        # so that num_units passed as a Parameter/array is handled consistently.
        self.pipeline = np.zeros((np.max(self.max_lead_time), n_units))
        self.lead_time_realized = np.zeros((np.max(self.max_lead_time), n_units))

    def get_pipeline(self) -> np.ndarray:
        """ Get the current pipeline (order quantities in transit, shape ``(max lead time, num_units)``) """

        return self.pipeline

    def reset(self) -> None:
        """ Reset the pipeline: discard all orders in transit and their remaining lead times """

        self.pipeline = np.zeros((np.max(self.max_lead_time), self.num_units[0]))
        self.lead_time_realized = np.zeros((np.max(self.max_lead_time), self.num_units[0]))

    def step(self,
        orders: np.ndarray,
        ) -> np.ndarray:
        """ Add orders to the pipeline and return the orders that are arriving

        Orders whose remaining lead time is 0 are delivered before the new ``orders`` are added. An
        order placed with a (drawn) lead time of L is therefore returned by the L-th subsequent call.

        Returns an array of shape ``(num_units,)`` with the quantities arriving in this period.
        """

        orders_arriving = self.get_orders_arriving()
        lead_times = self.draw_lead_times()

        # Shift all rows one period towards the front; the row that wraps around to the back is
        # overwritten with the new orders and their freshly drawn lead times.
        self.pipeline = np.roll(self.pipeline, -1, axis=0)
        self.lead_time_realized = np.roll(self.lead_time_realized, -1, axis=0)
        self.pipeline[-1, :] = orders
        self.lead_time_realized[-1, :] = lead_times

        # One period elapses for every order in the pipeline, including the ones just placed.
        self.lead_time_realized = np.clip(self.lead_time_realized - 1, 0, None)

        return orders_arriving

    def get_orders_arriving(self) -> np.ndarray:
        """ Get the orders that are arriving in the current period

        Sums, per unit, all pipeline entries whose remaining lead time is 0 and removes those entries
        from the pipeline. Returns an array of shape ``(num_units,)``.
        """

        arriving = self.lead_time_realized == 0
        orders_arriving = np.where(arriving, self.pipeline, 0.0).sum(axis=0)
        self.pipeline[arriving] = 0

        return orders_arriving

    def draw_lead_times(self) -> np.ndarray:
        """ Draw lead times for the orders

        - "fixed": ``lead_time_mean`` is used as is.
        - "gamma": gamma distribution with shape ``lead_time_mean`` and scale 1, i.e. mean and
          variance both equal ``lead_time_mean``.
        - "normal_absolute": normal distribution with scale ``lead_time_variance``.
        - "normal_relative": normal distribution with scale ``lead_time_mean * lead_time_variance``.

        NOTE: despite its name, ``lead_time_variance`` is passed as the ``scale`` (standard deviation)
        argument of ``np.random.normal``.

        The draws are clipped to ``[min_lead_time, max_lead_time]`` and rounded to integers.
        """

        if self.lead_time_stochasticity == "fixed":
            lead_times = self.lead_time_mean
        elif self.lead_time_stochasticity == "gamma":
            lead_times = np.random.gamma(self.lead_time_mean, 1, self.num_units[0])
        elif self.lead_time_stochasticity == "normal_absolute":
            lead_times = np.random.normal(self.lead_time_mean, self.lead_time_variance, self.num_units[0])
        elif self.lead_time_stochasticity == "normal_relative":
            lead_times = np.random.normal(self.lead_time_mean, self.lead_time_mean * self.lead_time_variance, self.num_units[0])
        else:
            raise ValueError("Invalid lead time stochasticity")

        lead_times = np.clip(lead_times, self.min_lead_time, self.max_lead_time)
        lead_times = np.round(lead_times).astype(int)

        return lead_times

    def check_stochasticity(self, max_lead_time):
        """ Check that params for stochastic lead times are set correctly

        Raises:
            ValueError: if the lead time mean is missing, if ``lead_time_variance`` is given for
                "fixed"/"gamma" or missing for the normal variants, if the stochasticity type is
                unknown, or if ``max_lead_time`` is missing for non-fixed lead times.
        """

        # lead time mean to be set in any case (it will be the deterministic lead time if lead_time_stochasticity is fixed)
        if self.lead_time_mean is None:
            raise ValueError("Lead time mean is not set")
        if self.lead_time_stochasticity in ("fixed", "gamma"):
            if self.lead_time_variance is not None:
                raise ValueError("Lead time variance must be None for fixed lead times (no variance) or gamma lead times (variance is set by the gamma distribution)")
        elif self.lead_time_stochasticity in ("normal_absolute", "normal_relative"):
            if self.lead_time_variance is None:
                raise ValueError("Lead time variance must be set for normal lead times")
        else:
            raise ValueError("Invalid lead time stochasticity")

        # Stochastic lead times need an explicit upper bound to size the pipeline.
        if self.lead_time_stochasticity != "fixed" and max_lead_time is None:
            raise ValueError("Max lead time must be set for stochastic lead times")

    def check_max_min_mean_lt(self):
        """ Check that min, mean and max lead times are consistent

        Requires ``1 <= min_lead_time <= lead_time_mean <= max_lead_time`` (element-wise).

        Raises:
            ValueError: describing the first violated condition.
        """
        if np.any(self.max_lead_time < self.lead_time_mean):
            raise ValueError("Max lead time must be greater than or equal to the lead time mean")
        if np.any(self.min_lead_time < 1):
            raise ValueError("Min lead time must be at least 1")
        if np.any(self.min_lead_time > self.lead_time_mean):
            raise ValueError("Min lead time must be less than or equal to the lead time mean")
        if np.any(self.max_lead_time < self.min_lead_time):
            raise ValueError("Max lead time must be greater than or equal to the min lead time")
        if np.any(self.max_lead_time < 1):
            raise ValueError("Max lead time must be at least 1")
        if np.any(self.lead_time_mean < 1):
            raise ValueError("Lead time mean must be at least 1")

    def set_param(self,
                        name: str, # name of the parameter (will become the attribute name)
                        input: Parameter | int | float | np.ndarray | List, # input value of the parameter
                        shape: tuple = (1,), # shape of the parameter
                        new: bool = False # whether to create a new parameter or update an existing one
                        ) -> None: #

        """
        Set a parameter for the environment by delegating to ``set_param`` from the package utils.
        It converts scalar values to numpy arrays and ensures that environment parameters are either
        of the Parameter class or NumPy arrays. If new is set to True, the function will create a new
        parameter or update an existing one otherwise. If new is set to False, the function will raise
        an error if the parameter does not exist.
        """

        set_param(self, name, input, shape, new)

    @property
    def shape(self) -> Tuple:
        """ Get the shape of the pipeline """

        return self.pipeline.shape

