import json
import numpy as np
import re

from .core import (Model, Layer, LayerType, LayerParams, AkidaUnsupervised, __version__,
                   evaluate_bitwidth)
from .statistics import Statistics


def model_str(self):
    """Build the informal, one-line description of the Model.

    Returns:
        str: the layer count, the sequence count and the output shape of the
            Model. An empty list stands for the output shape when the Model
            has no layers.
    """
    # The output shape is only read on a Model that has layers
    shape = self.output_shape if self.get_layer_count() else []
    fields = (
        f"layer_count={self.get_layer_count()!s}",
        f"sequence_count={len(self.sequences)!s}",
        f"output_shape={shape!s}",
    )
    return "akida.Model, " + ", ".join(fields)


def model_repr(self):
    """Build the detailed representation of the Model.

    Returns:
        str: the layer count, the output shape and the ``repr`` of the
            sequences, enclosed in angle brackets. An empty list stands for
            the output shape when the Model has no layers.
    """
    # The output shape is only read on a Model that has layers
    shape = self.output_shape if self.get_layer_count() else []
    return (f"<akida.Model, layer_count={self.get_layer_count()!s}"
            f", output_shape={shape!s}"
            f", sequences={self.sequences!r}>")


def model_to_dict(self):
    """Serialize the Model into a dictionary.

    The dictionary holds the Akida version, the IP version name, the
    dictionary of each layer, the learning attributes (None when the Model
    has no learning object) and the input and output shapes.

    Returns:
        dict: a Model dictionary.
    """
    # Learning attributes are only collected when a learning object is set
    learning = None
    if self.learning:
        learning = {}
        for name in dir(self.learning):
            learning[name] = getattr(self.learning, name)
    # Fill the dictionary one entry at a time, in the serialized key order
    model_dict = {"Akida version": __version__}
    model_dict["IP version"] = self.ip_version.name
    model_dict["layers"] = [layer.to_dict() for layer in self.layers]
    model_dict["learning"] = learning
    model_dict["input_shape"] = self.input_shape
    model_dict["output_shape"] = self.output_shape
    return model_dict


def model_from_dict(model_dict):
    """Instantiate a Model from a dict representation

    The layers are first re-created and added to an empty Model, the Model is
    then compiled if learning parameters were serialized, and the layer
    variables are loaded last.

    Args:
        model_dict(dict): a Model dictionary.

    Returns:
        :obj:`Model`: a Model.

    Raises:
        ValueError: if the major or minor version of the serialized model
            differs from the current one, if a layer has no parameters, or if
            a variable declares a bitwidth lower than the one evaluated from
            its data.
    """
    # Check major and minor version: only the first two components are
    # compared, so versions with extra components (e.g. "2.1.0.dev1") are
    # accepted instead of failing on unpacking
    model_version = model_dict["Akida version"]
    if __version__.split('.')[:2] != model_version.split('.')[:2]:
        raise ValueError(f"Serialized model was generated by version {model_version}, which"
                         f" is incompatible with current version {__version__}")
    # Instantiate an empty Model
    model = Model()
    # Add layers
    layers = model_dict["layers"]
    for layer_dict in layers:
        layer_name = layer_dict["name"]
        # Extract layer parameters
        layer_params = layer_dict["parameters"]
        if layer_params is None:
            raise ValueError("Cannot deserialize a Layer without parameters")
        # Work on a copy to leave the caller dictionary untouched
        layer_params = layer_params.copy()
        # Evaluate the Layer type from its serialized name, removing it from
        # the parameters at the same time
        layer_type = getattr(LayerType, layer_params.pop("layer_type"))
        # Instantiate layer
        layer = Layer(LayerParams(layer_type, layer_params), layer_name)
        # Evaluate the layer inbounds
        inbounds = [model.get_layer(ib) for ib in layer_dict["inbounds"]]
        # Add it to the model
        model.add(layer, inbounds)
    # If needed compile it to set learning parameters and variables
    if model_dict["learning"] is not None:
        learning = model_dict["learning"].copy()
        # Some parameters must be integer
        for name in ["num_weights", "num_classes"]:
            learning[name] = int(learning[name])
        model.compile(AkidaUnsupervised(**learning))
    # Now that the Model is fully initialized, load layer variables
    for layer_dict in layers:
        layer_name = layer_dict["name"]
        # Get corresponding Layer in model
        layer = model.get_layer(layer_name)
        # Iterate over variables
        for name, variable_dict in layer_dict["variables"].items():
            variable = np.array(variable_dict["data"]).astype(variable_dict["dtype"])
            # The declared bitwidth must be able to hold the actual values
            bitwidth = variable_dict["bitwidth"]
            actual_bitwidth = evaluate_bitwidth(variable)
            if bitwidth < actual_bitwidth:
                # Adjacent literals (not backslash continuations) keep the
                # source indentation out of the message
                raise ValueError(f"The specified bitwidth ({bitwidth}) must be higher or equal to "
                                 f"the actual bitwidth ({actual_bitwidth}) of {name} "
                                 f"from the layer '{layer_name}'.")
            layer.variables[name] = variable
    return model


def model_to_json(self):
    """Serialize the Model into a JSON string.

    The Model dictionary is pretty-printed with an indentation of two spaces,
    then the arrays of numbers are folded back on a single line.

    Returns:
        str: a JSON-formatted string corresponding to a Model.
    """
    pretty = json.dumps(self.to_dict(), indent=2)
    # Match a line jump and its indentation when they are followed by:
    # - a signed integer or float number with an optional opening square bracket,
    # - an opening or closing square bracket
    array_item = re.compile(r'\n\s+(\[?-?[\d\.]+,?|\]|\[)')
    # Keep only the captured token, dropping the line jump and indentation
    return array_item.sub(r'\g<1>', pretty)


def model_from_json(model_str):
    """Build a Model out of its JSON representation.

    The string is first decoded into a dictionary, which is then handed over
    to ``Model.from_dict``.

    Args:
        model_str(str): a JSON-formatted string corresponding to a Model.

    Returns:
        :obj:`Model`: a Model.
    """
    model_dict = json.loads(model_str)
    return Model.from_dict(model_dict)


@property
def statistics(self):
    """Statistics of this model, by sequence.

    Returns:
        :obj:`Statistics`: a statistics object created from this model,
            presumably giving access to :obj:`SequenceStatistics` indexed by
            sequence name.

    """
    model_statistics = Statistics(model=self)
    return model_statistics


def summary(self):
    """Prints a string summary of the model.

    This method prints a summary of the model with details for every layer,
    grouped by sequences:

    - name and type in the first column
    - output shape
    - kernel shape

    When at least one sequence holds a program, a column with the number of
    NPs is added to both the model table and the layers table.

    If the model has more than one layer and learning is enabled, a learning
    summary is printed for the last layer, with these details:

    - name of layer
    - number of incoming connections
    - number of weights per neuron

    """

    def _has_program(sequences):
        # True as soon as one of the sequences holds a program
        return any(s.program is not None for s in sequences)

    def _model_summary(model):
        # prepare headers
        headers = ['Input shape', 'Output shape', 'Sequences', 'Layers']
        # shapes are only available on a model that has layers
        if model.layers:
            shapes = [str(model.input_shape), str(model.output_shape)]
        else:
            shapes = ['N/A', 'N/A']
        row = shapes + [str(len(model.sequences)), str(len(model.layers))]
        # add the number of NPs if the model is mapped
        if _has_program(model.sequences):
            headers.append('NPs')
            nb_nps = 0
            for sequence in model.sequences:
                for current_pass in sequence.passes:
                    for layer in current_pass.layers:
                        # InputConvolutional layers and unmapped layers do
                        # not contribute to the NPs count
                        if (layer.parameters.layer_type != LayerType.InputConvolutional
                                and layer.mapping is not None):
                            nb_nps += len(layer.mapping.nps)
            row.append(nb_nps)

        print_table([headers, row], "Model Summary")

    def _get_backend_info(sequence):
        # Only keep the last dotted component of the backend representation
        backend = str(sequence.backend).split('.')[-1]
        sequence_info = sequence.name + " (" + backend + ")"
        if sequence.program is not None:
            sequence_info += " - size: " + str(len(sequence.program)) + " bytes"
        return sequence_info

    def _layers_summary(sequences):
        # Prepare headers
        headers = ['Layer (type)', 'Output shape', 'Kernel shape']
        has_program = _has_program(sequences)
        if has_program:
            headers.append('NPs')
        # prepare an empty table
        table = [headers]
        # new_splits is kept aligned with the table rows: it holds a banner
        # text for the first row of a sequence or pass, False otherwise
        new_splits = []
        has_multi_pass = len(sequences[0].passes) > 1
        nb_pass = 0
        for s in sequences:
            new_splits.append(_get_backend_info(s))
            for p in s.passes:
                if has_multi_pass:
                    nb_pass += 1
                    # The first pass shares the banner of its sequence
                    if nb_pass > 1:
                        new_splits.append(f"pass {nb_pass}")
                for layer in p.layers:
                    nps = None if layer.mapping is None else layer.mapping.nps
                    # layer name (type)
                    layer_type = layer.parameters.layer_type
                    # kernel shape
                    if "weights" in layer.get_variable_names():
                        kernel_shape = layer.get_variable("weights").shape
                    elif layer_type == LayerType.StatefulRecurrent:
                        # No single kernel: list the shape of each variable
                        kernel_shape = "  ".join(
                            str(layer.get_variable(name).shape)
                            for name in ("in_proj", "A_real", "A_imag", "out_proj"))
                    else:
                        kernel_shape = "N/A"
                    # Prepare row and add it
                    row = [str(layer), str(layer.output_dims), str(kernel_shape)]
                    if has_program:
                        if layer_type == LayerType.InputConvolutional or nps is None:
                            row.append('N/A')
                        else:
                            row.append(len(nps))
                    table.append(row)
                    # Rows that do not start a sequence or pass get no banner
                    if len(table) - 1 > len(new_splits):
                        new_splits.append(False)
                    if layer_type == LayerType.SeparableConvolutional:
                        # Add pointwise weights on a second line
                        kernel_pw_shape = layer.get_variable("weights_pw").shape
                        row = ['', '', kernel_pw_shape]
                        if has_program:
                            row.append('')
                        table.append(row)
                        new_splits.append(False)
        print_table(table, None, new_splits)

    def _learning_summary(model):
        layer = model.layers[-1]
        # Prepare headers
        headers = ["Learning Layer", "# Input Conn.", "# Weights"]
        table = [headers]
        name = layer.name
        # Input connections is the product of input dims
        input_connections = np.prod(layer.input_dims)
        # Num non zero weights per neuron (counted on first neuron)
        weights = layer.get_variable("weights")
        incoming_conn = np.count_nonzero(weights[:, :, :, 0])
        # Prepare row and add it
        row = [name, str(input_connections), incoming_conn]
        table.append(row)
        print()
        print_table(table, "Learning Summary")

    # Print first the general Model summary
    _model_summary(self)
    # Print sequences summary
    if self.sequences:
        print()
        _layers_summary(self.sequences)
    # Print learning summary if the model has more than one layer and
    # learning is enabled
    if len(self.layers) > 1 and self.learning:
        # Only the last layer of a model can learn
        print()
        _learning_summary(self)


def print_table(table, title, new_splits=None):
    """Print a table on the standard output.

    Columns are as wide as their longest cell and separated by two spaces. A
    separator line is printed before every body row but the first one, unless
    the row starts with a banner.

    Args:
        table (list): a list of rows of identical length, the first one being
            the headers. Cells are converted to strings.
        title (str): a title centered above the table, or None for no title.
        new_splits (list, optional): one entry per body row. A non-empty
            string is printed as a banner line before the row, a falsy value
            means no banner. Banners that do not contain "pass" are
            surrounded by blank lines. Defaults to no banner at all.
    """
    # Convert every cell to a string, in an array of objects
    to_str = np.vectorize(str, otypes=['O'])
    table = to_str(table)
    # get column lengths
    str_len_f = np.vectorize(lambda cell: len(str(cell)))
    str_lens = np.amax(str_len_f(table), 0)
    line_len = np.sum(str_lens)
    # Prepare format rows: each cell is padded and truncated to its column width
    size_formats = np.vectorize(lambda cell: f"{{:{cell}.{cell}}}")
    format_strings = size_formats(str_lens)
    format_row = "  ".join(format_strings)
    # Generate separators
    separator_len = line_len + 2 * len(table[0])
    separator = "_" * separator_len
    double_separator = "=" * separator_len

    # Print header
    center_format = f"{{:^{separator_len}}}"
    if title is not None:
        print(center_format.format(title))
    print(separator)
    print(format_row.format(*table[0]))

    rows = table[1:, :]
    if new_splits is None:
        new_splits = [False] * len(rows)
    assert len(rows) == len(new_splits)
    # Without any banner, the headers are underlined instead
    if not any(new_splits):
        print(double_separator)
    # Print body
    for index, (row, new_split) in enumerate(zip(rows, new_splits)):
        if new_split:
            # Display a line break only for sequences
            if "pass" not in new_split:
                print()
            # Compute the number of char on each side of the text
            space_len = max((separator_len - len(new_split)) / 2., 1.)
            space_left = "=" * int(np.ceil(space_len - 1))
            space_right = "=" * int(np.floor(space_len - 1))
            print(space_left, new_split, space_right)
            # Display a line break only for sequences
            if "pass" not in new_split:
                print()
        # Don't use a separator line on first row. The row position is
        # checked rather than its content, so that a later row with the same
        # first cell as the first row still gets its separator
        elif index > 0:
            print(separator)
        print(format_row.format(*row))
    print(separator)


def predict_classes(self, inputs, num_classes=0, batch_size=0):
    """Predicts the class labels for the specified inputs.

    The label of an input is the index of the highest value along the last
    axis of the model outputs. When `num_classes` is specified and differs
    from the number of output neurons, that index is divided by the number of
    neurons per class (number of neurons // `num_classes`) to obtain the
    label.

    Args:
        inputs (:obj:`numpy.ndarray`): a (n, x, y, c) uint8 tensor
        num_classes (int, optional): the number of output classes, 0 meaning
            one class per output neuron
        batch_size (int, optional): maximum number of inputs that should be
            processed at a time

    Returns:
        :obj:`numpy.ndarray`: an array of class labels

    Raises:
        ValueError: if `num_classes` is negative or higher than the number of
            output neurons.

    """

    outputs = self.predict(inputs, batch_size)
    num_neurons = outputs.shape[-1]
    # More classes than neurons would lead to zero neurons per class, hence
    # to a division by zero when evaluating the labels
    if num_classes < 0 or num_classes > num_neurons:
        raise ValueError(f"num_classes ({num_classes}) must be in the range "
                         f"[0, {num_neurons}], {num_neurons} being the number "
                         "of output neurons.")
    classes = np.argmax(outputs, axis=-1).flatten()
    if num_classes not in (0, num_neurons):
        # Several neurons per class: convert neuron indices to class labels
        neurons_per_class = num_neurons // num_classes
        classes = classes // neurons_per_class
    return classes
