import numpy as np
from .layers import *
from .losses import *
from .metrics import *
from .optimizers import *
from .callbacks import *
from .preprocessing import *
from math import sqrt
from copy import deepcopy
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import pickle
import h5py
import sys
from datetime import datetime
import os

class Sequential:
    """
    A class for creating and training Feedforward Neural Networks (FNN) or Convolutional Neural Networks (CNN).

    Attributes
    ----------
    train_inputs : ndarray
        Input data for training the model.
    train_input_batch : ndarray
        A batch of training inputs.
    train_targets : ndarray
        Target data corresponding to training inputs.
    train_target_batch : ndarray
        A batch of training targets.
    train_outputs : ndarray
        Final output predictions during the training phase.
    train_output_batch : ndarray
        A batch of outputs during training.

    val_inputs : ndarray
        Input data for validating the model (optional).
    val_input_batch : ndarray
        A batch of validation inputs.
    val_targets : ndarray
        Target data corresponding to validation inputs (optional).
    val_target_batch : ndarray
        A batch of validation targets.
    val_outputs : ndarray
        Final output predictions during the validation phase (not computed until further improvements).
    val_output_batch : ndarray
        A batch of outputs during validation.

    predictions : ndarray
        Predictions generated by the model for a given input array.

    train_layers : list
        List of layers used during training.
    val_layers : list
        List of layers used for validation.

    train_loss : float
        Calculated loss for the training dataset.
    val_loss : float
        Calculated loss for the validation dataset.
    train_loss_history : list
        History of training loss values for each epoch.
    val_loss_history : list
        History of validation loss values for each epoch.

    train_accuracy : float
        Accuracy computed on the training dataset.
    val_accuracy : float
        Accuracy computed on the validation dataset.
    train_accuracy_history : list
        History of training accuracy values for each epoch.
    val_accuracy_history : list
        History of validation accuracy values for each epoch.

    cost_function : Loss
        Cost function used for training.

    runtime : float
        Total time taken to train the model.

    optimizer : Optimizer
        Optimizer instance used to update model parameters.
    validation : bool
        Indicates whether a validation dataset is used.

    Methods
    -------
    add(layer : layer, activation : Activation)
        Adds a new layer to the model.
    forward()
        Performs the forward pass for both training and validation layers.
    backward()
        Executes the backward pass to compute gradients for training layers.
    verbose(verbose : int, epoch : int, epochs : int, start_time : float)
        Displays training and validation metrics based on the specified verbosity level.
    compile(cost_function : Loss, optimizer : Optimizer)
        Compiles the model with the specified cost function and optimizer.
    fit(epochs : int, batch_size : int, verbose : int, callbacks : list)
        Trains the model using the provided cost function, optimizer, and other parameters.
    predict(X : ndarray)
        Generates predictions for the input data `X`.
    results()
        Visualizes loss and accuracy curves for both training and validation phases.
    save_params(name : str, extension : str)
        Saves the model's weights and biases to a file with the specified name and extension.
    load_params(path : str, parameters : dict)
        Loads model weights and biases from a file or dictionary.
    save_histories(name : str, extension : str)
        Saves the training and validation histories to a file with the specified name and extension.
    get_flatten_length()
        Computes the length of the flattened output from the `Flatten` layer.
    summary()
        Prints a summary of the model architecture.
    """

    def __init__(self, train_inputs, train_targets, val_inputs=None, val_targets=None):
        """
        Initializes the architecture with training and optional validation data.

        Parameters
        ----------
        train_inputs : ndarray
            The input data used to train the model.
        train_targets : ndarray
            The target labels or values corresponding to the `train_inputs`.
        val_inputs : ndarray, optional
            The input data used for validation. Default is 'None'.
        val_targets : ndarray, optional
            The target labels or values corresponding to the `val_inputs`. Default is 'None'.

        Raises
        ------
        TypeError
            If `train_inputs` and `train_targets` are not of type 'ndarray'.
            If `val_inputs` and `val_targets` are provided and not of type 'ndarray'.
        """
        self.train_inputs = train_inputs
        self.train_input_batch = None
        self.train_targets = train_targets
        self.train_target_batch = None
        self.train_outputs = None
        self.train_output_batch = None

        self.val_inputs = val_inputs
        self.val_input_batch = None
        self.val_targets = val_targets
        self.val_target_batch = None
        self.val_outputs = None
        self.val_output_batch = None

        self.batch_size = None
        self.val_batch_size = None

        self.predictions = None

        self.train_layers = []
        self.val_layers = []

        self.train_loss = 0.0
        self.val_loss = 0.0
        self.train_loss_history = []
        self.val_loss_history = []

        self.train_accuracy = 0.0
        self.val_accuracy = 0.0
        self.train_accuracy_history = []
        self.val_accuracy_history = []

        self.cost_function = None
        self.optimizer = None

        self.__is_trained__ = False
        self.__is_compiled__ = False
        self.runtime = 0.0
        self.validation = False

        if self.val_inputs is not None and self.val_targets is not None:
            self.validation = True

        if not isinstance(train_inputs, np.ndarray) or not isinstance(train_targets, np.ndarray):
            raise TypeError('`train_inputs` and `train_targets` must be of type ndarray.')
        if self.val_inputs is not None and self.val_targets is not None:
            if not isinstance(val_inputs, np.ndarray) or not isinstance(val_targets, np.ndarray):
                raise TypeError('`val_inputs` and `val_targets` must be of type ndarray.')

    def get_flatten_length(self):
        """
        Computes the length of the flattened output from the `Flatten` layer.

        This method iterates through the training layers to find the `Flatten` layer and computes the length of its output.

        Returns
        -------
        int
            The length of the flattened output.

        Raises
        ------
        ValueError
            If there is no `Flatten` layer in the training layers.
        """
        self.train_layers[0].inputs = self.train_inputs[0].reshape(1, *self.train_inputs[0].shape)
        for i in range(len(self.train_layers)):
            if isinstance(self.train_layers[i], Flatten):
                self.train_layers[i].forward()
                self.train_layers[0].inputs = self.train_inputs
                return self.train_layers[i].outputs.shape[1]
            elif i == len(self.train_layers) - 1:
                self.train_layers[0].inputs = self.train_inputs
                raise ValueError("There is no `Flatten` layer.")
            else:
                self.train_layers[i + 1].inputs = self.train_layers[i].forward()

    def add(self, layer, activation=None):
        """
        Append a layer, optionally followed by its activation, to the training layers.

        When no activation is given, `layer` is appended as-is without any type check.

        Parameters
        ----------
        layer : Dense or Convolution2D
            Layer to append. It must be a Dense or Convolution2D instance whenever
            an `activation` accompanies it.
        activation : Activation, optional
            Activation appended right after `layer`. Default is 'None'.

        Raises
        ------
        TypeError
            If `activation` is given and `layer` is not a Dense or Convolution2D instance.
            If `activation` is given and is not an Activation instance.
        """
        # Without an activation there is nothing to validate.
        if activation is None:
            self.train_layers.append(layer)
            return

        if not isinstance(layer, (Dense, Convolution2D)):
            raise TypeError("`layer` must be of type `Dense` or `Convolution2D`.")
        if not isinstance(activation, Activation):
            raise TypeError("`activation` must be of type `Activation`.")

        self.train_layers.extend((layer, activation))

    def forward(self):
        """
        Performs the forward pass through the layers of the architecture for both
        training and validation.

        This method computes the output for both the training and validation data.
        During training, the input batch goes through the layers in sequence,
        and the output is stored in `train_output_batch`. If validation data is
        provided, the validation input batch goes through the validation layers
        and the output is stored in `val_output_batch`.

        Returns
        -------
        None
        """
        self.train_layers[0].inputs = self.train_input_batch
        if self.validation:
            self.val_layers[0].inputs = self.val_input_batch
        for i in range(len(self.train_layers)):
            if i + 1 == len(self.train_layers):
                self.train_output_batch = self.train_layers[i].forward()
            else:
                if isinstance(self.train_layers[i], Dropout):
                    self.train_layers[i].training = True
                self.train_layers[i + 1].inputs = self.train_layers[i].forward()
            if self.validation:
                if i + 1 == len(self.val_layers):
                    self.val_output_batch = self.val_layers[i].forward()
                else:
                    self.val_layers[i + 1].inputs = self.val_layers[i].forward()

    def predict(self, X):
        """
        Makes predictions for the given input `X` using the trained model.

        This method runs a forward pass with the provided input `X` and returns
        the predicted output based on the current state of the trained network.

        Be careful, you might encounter an error if the number of dimensions of `X`
        doesn't match the number of dimensions of `self.train_inputs`.

        Parameters
        ----------
        X : ndarray
            The input data for which predictions are to be made.

        Returns
        -------
        predictions : ndarray
            The predicted output for the given input `X`.

        Raises
        ------
        TypeError
            If `X` is not of type 'ndarray'.
        """
        if not isinstance(X, np.ndarray):
            raise TypeError("`X` must be of type 'ndarray'.")
        if X.ndim != self.train_inputs.ndim:
            raise TypeError("`X`must have the same number of dimensions as `self.train_inputs`.")
        self.train_layers[0].inputs = X
        for i in range(len(self.train_layers)):
            if isinstance(self.train_layers[i], Dropout):
                self.train_layers[i].training = False
            if i + 1 == len(self.train_layers):
                self.predictions = self.train_layers[i].forward()
                return self.predictions
            self.train_layers[i + 1].inputs = self.train_layers[i].forward()

    def backward(self):
        """
        Performs the backward pass through the layers to compute gradients.

        This method calculates the gradients of the loss function with respect to
        the parameters (weights and biases) of the network. The gradients are propagated backward through
        each layer.

        Returns
        -------
        None
        """
        self.dX = self.cost_function.derivative(self.train_target_batch, self.train_output_batch)
        for layer in reversed(self.train_layers):
            self.dX = layer.backward(self.dX)

    def verbose(self, verbose, epoch, epochs, start_time):
        """
        Prints training and validation metrics during training at specified intervals.

        This method provides feedback during training based on the verbosity level
        chosen. It can print the loss and accuracy for both the training and validation
        datasets, and also the training time. The method adjusts the print frequency
        based on the `verbose` parameter.

        Parameters
        ----------
        verbose : int
            The verbosity level (0, 1, or 2).
            0 - No output.
            1 - Print at the end of the training.
            2 - Print at the end of each epoch.
        epoch : int
            The current epoch number.
        epochs : int
            The total number of epochs.
        start_time : float
            The starting time of training to calculate the runtime.

        Raises
        ------
        ValueError
            If the `verbose` parameter is not 0, 1, or 2.
        TypeError
            If `verbose`, `epochs`, or `start_time` are not of the correct type.

        Returns
        -------
        None
        """
        if not isinstance(verbose, int) and verbose is not None:
            raise TypeError('`verbose` must be of type int or None.')
        if not isinstance(epochs, int):
            raise TypeError('`epochs` must be of type int.')
        if not isinstance(start_time, float):
            raise TypeError('`start_time` must be of type float.')

        if verbose == 2:
            if epoch + 1 < epochs:
                if self.validation:
                    print(f"[TRAINING METRICS] train_loss: {np.around(self.train_loss, 5)} · "
                          f"train_accuracy: {np.around(self.train_accuracy, 5)}\n" +
                          f"[VALIDATION METRICS] val_loss: {np.around(self.val_loss, 5)} · "
                          f"val_accuracy: {np.around(self.val_accuracy, 5)}\n\n")
                else:
                    print(f"[TRAINING METRICS] train_loss: {np.around(self.train_loss, 5)} | " +
                          f"train_accuracy: {np.around(self.train_accuracy, 5)}\n\n")
            elif epoch + 1 == epochs:
                if self.validation:
                    self.runtime = time.time() - start_time
                    string1 = f"| [TRAINING METRICS] train_loss: {np.around(self.train_loss, 5)} · " + \
                              f"train_accuracy: {np.around(self.train_accuracy, 5)} |"
                    string2 = f"| [VALIDATION METRICS] val_loss: {np.around(self.val_loss, 5)} · " + \
                              f"val_accuracy: {np.around(self.val_accuracy, 5)} |"
                    string1_length = len(string1)
                    string2_length = len(string2)
                    print("\n" + string1_length * "-" + "\n" + string1 + "\n" +
                          string1_length * "-" + "\n" +
                          string2 + (string1_length - string2_length - 1) * " " + "\n" +
                          string1_length * "-")
                    print(f"{round(self.runtime, 5)} seconds")
                else:
                    self.runtime = time.time() - start_time
                    string = f"| [TRAINING METRICS] train_loss: {np.around(self.train_loss, 5)} · " + \
                             f"train_accuracy: {np.around(self.train_accuracy, 5)} |"
                    string_length = len(string)
                    print("\n" + string_length * "-" + "\n" + string + "\n" + string_length * "-")
                    print(f"{round(self.runtime, 5)} seconds")
        elif verbose == 1:
            if epoch + 1 == epochs:
                if self.validation:
                    string1 = f"| [TRAINING METRICS] train_loss: {np.around(self.train_loss, 5)} · " + \
                              f"train_accuracy: {np.around(self.train_accuracy, 5)} |"
                    string2 = f"| [VALIDATION METRICS] val_loss: {np.around(self.val_loss, 5)} · " + \
                              f"val_accuracy: {np.around(self.val_accuracy, 5)} |"
                    string1_length = len(string1)
                    string2_length = len(string2)
                    print("\n" + string1_length * "-" + "\n" + string1 + "\n" + string1_length * "-" + "\n" +
                          string2 + (string1_length - string2_length - 1) * " " + "\n" + string1_length * "-")
                else:
                    string = f"| [TRAINING METRICS] train_loss: {np.around(self.train_loss, 5)} · " + \
                             f"train_accuracy: {np.around(self.train_accuracy, 5)} |"
                    string_length = len(string)
                    print("\n" + string_length * "-" + "\n" + string + "\n" + string_length * "-")
        elif verbose == 0 or verbose is None:
            return
        else:
            raise ValueError("`verbose` must be 0, 1, or 2.")

    def compile(self, cost_function, optimizer=None):
        """
        Compiles the model with the specified cost function and optimizer.

        Parameters
        ----------
        cost_function : Loss
            The cost function to be used for training.
        optimizer : Optimizer, optional
            The optimizer to be used for updating the model parameters. When omitted
            (or None), a fresh SGD optimizer with a learning rate of 0.01 is created
            for this model.

        Raises
        ------
        ValueError
            If `cost_function` is not an instance of Loss.
            If `optimizer` is not an instance of Optimizer.

        Returns
        -------
        None
        """
        if not isinstance(cost_function, Loss):
            raise ValueError("`cost_function` must be of type `Loss`.")

        # Build the default here rather than in the signature: a default argument is
        # evaluated once at definition time, so every model compiled without an
        # explicit optimizer would otherwise share one and the same SGD instance.
        if optimizer is None:
            optimizer = SGD(learning_rate=0.01)
        if not isinstance(optimizer, Optimizer):
            raise ValueError("`optimizer` must be of type `Optimizer`.")

        self.__is_compiled__ = True
        self.optimizer = optimizer
        self.cost_function = cost_function

    def fit(self, epochs=1000, batch_size=None, verbose=1, callbacks=[]):
        """
        Trains the model using the provided cost function, optimizer, and other parameters.

        This method performs the training process for the neural network. It includes
        forward and backward passes, loss and accuracy computation and parameter updates
        using the chosen optimizer. During training, it also keeps track of the loss and
        accuracy for both the training and validation datasets (if validation data is provided).

        Parameters
        ----------
        epochs : int, optional
            The number of epochs for training. The default is 1000.
        batch_size : int, optional
            The number of samples per batch. If None, the entire dataset is used for each pass.
        verbose : int, optional
            The verbosity level for printing metrics during training. Default is 1.
        callbacks : list of Callback, optional
            A list of callback functions to extend training functionality. Each callback should be a callable
            object that implements the following methods:
            - `on_train_start(model)`: Called at the start of the training loop.
            - `on_epoch_start(model)`: Called at the start of each epoch.
            - `on_epoch_end(model)`: Called at the end of each epoch. If the callback is an instance
              of `LiveMetrics`, a `figure` parameter should be passed.
            - `on_train_end(model)`: Called at the end of the training loop.
            The default is an empty list.

        Raises
        ------
        ValueError
            - If the model has not been compiled before fitting.
            - If `batch_size` is not of type `int` and is not `None`.
            - If `epochs` is not of type `int`.
            - If `verbose` is not of type `int` and is not `None`.
        TypeError
            - If `callbacks` is not a list of Callback objects.

        Returns
        -------
        None
        """
        if not self.__is_compiled__:
            raise ValueError("The model must be compiled before fitting.")
        if not isinstance(batch_size, int) and batch_size is not None:
            raise ValueError("`batch_size` must be of type `int` or `None`.")
        if not isinstance(epochs, int):
            raise ValueError("`epochs` must be of type `int`.")
        if not isinstance(verbose, int) and verbose is not None:
            raise ValueError("`verbose` must be of type `int` or `None`.")
        if not isinstance(callbacks, list) or not all(isinstance(callback, Callback) for callback in callbacks):
            raise TypeError("`callbacks` must be a list of Callback objects.")

        self.__is_trained__ = True

        start_time = time.time()

        if self.validation:
            self.val_layers = deepcopy(self.train_layers)

        self.batch_size = batch_size

        if self.batch_size is None:
            steps = 1
        else:
            steps = self.train_inputs.shape[0] // self.batch_size
            if steps * self.batch_size < self.train_inputs.shape[0]:
                steps += 1
            if self.validation:
                self.val_batch_size = self.val_inputs.shape[0] // steps
                if self.val_batch_size <= 0:
                    raise ValueError(
                        f"Validation batch size must be at least 1. (currently {self.val_batch_size}). "
                        f"Ensure validation input size is greater than {steps} or increase training batch size to "
                        f"{self.train_inputs.shape[0] // self.val_inputs.shape[0]}."
                    )

        if epochs > 1000:
            update = 25
        elif 1000 >= epochs > 100:
            update = 10
        elif epochs <= 100:
            update = 1

        loss = 0.0
        acc = 0.0

        tqdm_epochs = False
        tqdm_steps = False

        if verbose == 1:
            tqdm_epochs = True
        elif verbose is not None and verbose != 0 and verbose != 1:
            tqdm_steps = True

        for callback in callbacks:
            if isinstance(callback, LiveMetrics):
                figure = plt.figure()
            callback.on_train_start(self)

        for epoch in (epoch_bar := tqdm(range(epochs), disable=not tqdm_epochs, file=sys.stdout)):
            epoch_bar.set_description(f"Epoch [{epoch + 1}/{epochs}]")
            epoch_bar.set_postfix({"loss": loss, "accuracy": acc})

            for callback in callbacks:
                callback.on_epoch_start(self)

            train_accumulated_loss = 0
            train_accumulated_accuracy = 0

            val_accumulated_loss = 0
            val_accumulated_accuracy = 0

            for step in (step_bar := tqdm(range(steps), disable=not tqdm_steps, file=sys.stdout)):
                step_bar.set_description(f"Epoch [{epoch + 1}/{epochs}]")
                step_bar.set_postfix({"loss": loss, "accuracy": acc})

                if self.batch_size is None:
                    self.train_input_batch = self.train_inputs
                    self.train_target_batch = self.train_targets

                    if self.validation:
                        self.val_input_batch = self.val_inputs
                        self.val_target_batch = self.val_targets
                else:
                    self.train_input_batch = self.train_inputs[step * self.batch_size:(step + 1) * self.batch_size]
                    self.train_target_batch = self.train_targets[step * self.batch_size:(step + 1) * self.batch_size]

                    if self.validation:
                        self.val_input_batch = self.val_inputs[
                                               step * self.val_batch_size:(step + 1) * self.val_batch_size]
                        self.val_target_batch = self.val_targets[
                                                step * self.val_batch_size:(step + 1) * self.val_batch_size]

                self.forward()
                self.backward()

                for i in range(len(self.train_layers)):
                    self.train_layers[i] = self.optimizer.update_params(self.train_layers[i])

                    if self.validation:
                        if isinstance(self.val_layers[i], Dense) or isinstance(self.val_layers[i], Convolution2D):
                            self.val_layers[i].weights = self.train_layers[i].weights
                            self.val_layers[i].biases = self.train_layers[i].biases

                loss = self.cost_function.loss(self.train_target_batch, self.train_output_batch)
                acc = accuracy(self.train_target_batch, self.train_output_batch)

                train_accumulated_loss += loss
                train_accumulated_accuracy += acc

                if self.validation:
                    val_accumulated_loss += self.cost_function.loss(self.val_target_batch, self.val_output_batch)
                    val_accumulated_accuracy += accuracy(self.val_target_batch, self.val_output_batch)

            self.train_loss = train_accumulated_loss / steps
            self.train_accuracy = train_accumulated_accuracy / steps

            if epoch % update == 0 or epoch == 1:
                self.train_loss_history.append(self.train_loss)
                self.train_accuracy_history.append(self.train_accuracy)

            if self.validation:
                self.val_loss = val_accumulated_loss / steps
                self.val_accuracy = val_accumulated_accuracy / steps

                if epoch % update == 0 or epoch == 1:
                    self.val_loss_history.append(self.val_loss)
                    self.val_accuracy_history.append(self.val_accuracy)

            self.verbose(verbose, epoch, epochs, start_time)

            for callback in callbacks:
                if isinstance(callback, LiveMetrics):
                    if epoch % update == 0 or epoch == 1:
                        callback.on_epoch_end(self, figure)
                else:
                    callback.on_epoch_end(self)

        if self.batch_size is not None:
            self.train_outputs = self.predict(self.train_inputs[:self.batch_size])
            for step in range(1, steps):
                self.train_outputs = np.concatenate((self.train_outputs, self.predict(
                    self.train_inputs[step * self.batch_size:(step + 1) * self.batch_size])), axis=0)
        else:
            self.train_outputs = self.predict(self.train_inputs)

        for callback in callbacks:
            callback.on_train_end(self)

    def results(self):
        """
        Plots the loss and accuracy evolution over epochs.

        This method generates plots to visualize the progress of the training
        process, showing how the loss and accuracy change for both training
        and validation datasets (if validation data is provided). This helps
        assess whether the model is learning and if it is overfitting.

        Raises
        ------
        RuntimeError
            If the model has not been trained yet.

        Returns
        -------
        None
        """
        if self.__is_trained__:
            figure, axs = plt.subplots(1, 2)

            axs[0].set_title("Loss Evolution")
            axs[0].set_xlabel("Epoch")
            axs[0].set_ylabel("Loss")
            axs[0].plot(self.train_loss_history, label="training dataset")
            if self.validation:
                axs[0].plot(self.val_loss_history, label="validation dataset")

            axs[1].set_title("Accuracy Evolution")
            axs[1].set_xlabel("Epoch")
            axs[1].set_ylabel("Accuracy")
            axs[1].plot(self.train_accuracy_history, label="training dataset")
            if self.validation:
                axs[1].plot(self.val_accuracy_history, label="validation dataset")

            axs[0].legend()
            axs[1].legend()

            plt.show()
        else:
            raise RuntimeError("The model has not been trained yet.")

    def save_params(self, name="parameters", extension="h5"):
        """
        Saves the trained model parameters (weights and biases) to a file.

        This method saves the weights and biases of all the layers of the trained model into a file with the specified name and extension.
        The parameters are saved with a timestamp appended to the base name.
        This allows the model to be restored later for further use or evaluation.

        Parameters
        ----------
        name : str, optional
            The base name of the file to save the parameters. Default is "parameters".
        extension : str, optional
            The file extension to use for saving the parameters. The default is "h5".

        Returns
        -------
        parameters : dict
            A dictionary containing the model's weights and biases. The dictionary includes the following keys:
            - "weights": List of weights for each layer.
            - "biases": List of biases for each layer.

        Raises
        ------
        RuntimeError
            If the model has not been trained yet.
        TypeError
            If `name` is not a string.
            If `extension` is not a string.
        ValueError
            If `extension` is not either 'h5' or 'pkl'.

        Notes
        -----
        The parameters are saved in a file with a timestamp appended to the base name. The file format can be either HDF5 (.h5) or pickle (.pkl).
        """
        if not isinstance(name, str):
            raise TypeError("`name` must be a string")
        if not isinstance(extension, str):
            raise TypeError("`extension` must be a string")
        if extension not in ("h5", "pkl"):
            raise ValueError("`extension` must be either 'h5' or 'pkl'")

        if self.__is_trained__:
            date_time = datetime.now().strftime("%m_%d_%Y-%H_%M_%S")
            parameters = {"weights": None, "biases": None}
            weights = []
            biases = []
            for layer in self.train_layers:
                if isinstance(layer, Dense) or isinstance(layer, Convolution2D):
                    weights.append(layer.weights)
                    biases.append(layer.biases)
                else:
                    weights.append(np.array(0.0))
                    biases.append(np.array(0.0))

            parameters["weights"] = weights
            parameters["biases"] = biases

            if extension == "pkl":
                with open(name + f"_{date_time}.pkl", 'wb') as f:
                    pickle.dump(parameters, f)

            elif extension == "h5":
                with h5py.File(name + f"_{date_time}.h5", 'w') as f:
                    weights_group = f.create_group("weights", track_order=True)
                    biases_group = f.create_group("biases", track_order=True)
                    for i in range(len(self.train_layers)):
                        weights_group.create_dataset(f"layer{i}", data=weights[i])
                        biases_group.create_dataset(f"layer{i}", data=biases[i])

            return parameters
        raise RuntimeError("The model has not been trained yet.")

    def load_params(self, path="parameters.h5", parameters=None):
        """
        Loads the model parameters (weights and biases) from a dictionary or a file.

        The loaded weights and biases are assigned to the `Dense` and `Convolution2D` layers
        of `self.train_layers`. Entry `i` of the stored weights/biases is assigned to layer `i`;
        entries that belong to any other kind of layer are ignored.

        Parameters
        ----------
        path : str, optional
            The file path from which to load the parameters. Default is "parameters.h5".
            It is only read when `parameters` is None, but it is always validated.
        parameters : dict, optional
            A dictionary with the keys "weights" and "biases", each holding one entry per layer
            of the model. Default is None.

        Raises
        ------
        TypeError
            If `path` is not a string.
            If `parameters` is neither None nor a dictionary.
        ValueError
            If the file extension of `path` is not either '.pkl' or '.h5'.

        Returns
        -------
        None

        Notes
        -----
        If `parameters` is provided, it is used directly, whatever the extension of `path` is.
        Otherwise, the parameters are read from the file specified by `path`, which can be either
        an HDF5 (.h5) or a pickle (.pkl) file.
        """
        if not isinstance(path, str):
            raise TypeError("`path` must be a string")

        _, extension = os.path.splitext(path)
        if extension not in (".pkl", ".h5"):
            raise ValueError("`extension` must be either '.pkl' or '.h5'")

        if parameters is not None and not isinstance(parameters, dict):
            raise TypeError("`parameters` must be a dictionary")

        if parameters is None and extension == ".h5":
            # The datasets must be read while the file is still open.
            with h5py.File(path, 'r') as f:
                weights = f["weights"]
                biases = f["biases"]
                for i, layer in enumerate(self.train_layers):
                    if isinstance(layer, (Dense, Convolution2D)):
                        layer.weights = weights[f"layer{i}"].astype(np.float64)[:]
                        layer.biases = biases[f"layer{i}"].astype(np.float64)[:]
            return

        if parameters is None:
            # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
            with open(path, 'rb') as f:
                parameters = pickle.load(f)

        # A dictionary given by the caller takes priority over `path`, so it is
        # applied here even when `path` still holds its default '.h5' value.
        for i, layer in enumerate(self.train_layers):
            if isinstance(layer, (Dense, Convolution2D)):
                layer.weights = parameters["weights"][i]
                layer.biases = parameters["biases"][i]

    def save_histories(self, name="metrics_history", extension="h5"):
        """
        Writes the recorded loss and accuracy histories to disk and returns them.

        The output file is named `<name>_<timestamp>.<extension>`, where the timestamp is the
        current local time formatted as month_day_year-hour_minute_second.

        Parameters
        ----------
        name : str, optional
            The base name of the output file. The default is "metrics_history".
        extension : str, optional
            The output format, either "h5" (HDF5) or "pkl" (pickle). The default is "h5".

        Returns
        -------
        histories : dict
            The histories that were written, stored under the following keys:
            - "train_loss": training loss history.
            - "train_accuracy": training accuracy history.
            - "val_loss": validation loss history (only if `self.validation` is truthy).
            - "val_accuracy": validation accuracy history (only if `self.validation` is truthy).

        Raises
        ------
        TypeError
            If `name` is not a string.
            If `extension` is not a string.
        ValueError
            If `extension` is not either 'h5' or 'pkl'.

        Notes
        -----
        In the HDF5 file every history is stored as a top-level dataset named after its key.
        In the pickle file the whole dictionary is stored as a single object.
        """
        if not isinstance(name, str):
            raise TypeError("`name` must be a string")
        if not isinstance(extension, str):
            raise TypeError("`extension` must be a string")
        if extension != "h5" and extension != "pkl":
            raise ValueError("`extension` must be either 'h5' or 'pkl'")

        stamp = datetime.now().strftime("%m_%d_%Y-%H_%M_%S")
        out_path = f"{name}_{stamp}.{extension}"

        histories = dict(
            train_loss=self.train_loss_history,
            train_accuracy=self.train_accuracy_history,
        )
        if self.validation:
            histories.update(
                val_loss=self.val_loss_history,
                val_accuracy=self.val_accuracy_history,
            )

        if extension == "h5":
            # One dataset per recorded history, named after its dictionary key.
            with h5py.File(out_path, 'w') as f:
                for key, values in histories.items():
                    f.create_dataset(key, data=values)
        else:
            with open(out_path, 'wb') as f:
                pickle.dump(histories, f)

        return histories

    def summary(self):
        """
        Prints the output shape of every layer in the model.

        The first training sample is passed through `self.predict` as a batch of size one.
        Then, for each layer in `self.train_layers`, the class name of the layer and the shape
        of its `outputs` attribute are printed on one line.

        Returns
        -------
        None
        """
        sample = self.train_inputs[0]
        # Add a leading axis of size one so that the sample is passed as a batch.
        self.predict(sample.reshape(1, *sample.shape))
        for layer in self.train_layers:
            print(f"{type(layer).__name__}: {layer.outputs.shape}")
