from .core import Layer


class Observer(object):
    """Container attached to a ``Model`` that allows to retrieve output spikes and
    potentials for a given layer.

    Results are collected by a callback registered on the model backend and
    stored in dictionaries keyed by the ``source_id`` the backend passes to
    that callback.

    """

    def __init__(self, model, layer):
        """Creates an ``Observer``.

        Args:
            model (:obj:`Model`): the model attached to the ``Observer``.
            layer (:obj:`Layer`): a ``Layer`` object that you want to observe,
                or a value accepted by ``model.get_layer`` (e.g. a layer name
                or id) identifying it.

        Raises:
            ValueError: if ``model.get_layer`` finds no such layer, or if a
                ``Layer`` object is given that is not the one the model holds
                under that name.

        """
        self._model = model
        self._layer = layer
        # Create the storage before registering the callback, so that the
        # callback never runs against missing attributes even if the backend
        # invokes it during or right after ``register_observer``.
        self._spikes = {}
        self._potentials = {}

        # NOTE(review): this callback closes over ``self``; if the backend keeps
        # a strong reference to it, this Observer (and therefore its
        # ``__del__``) lives as long as the backend does -- TODO confirm.
        def observer(layer, source_id, spikes, potentials):
            try:
                # Note that spikes will be None if there is no activation
                self._spikes[source_id] = spikes
                # Potentials will be None for InputData layers
                if potentials is None:
                    self._potentials[source_id] = None
                else:
                    # Convert CHW potentials Col Major Dense
                    # to a WHC Row Major numpy array
                    self._potentials[source_id] = potentials.to_numpy()
            except Exception as e:
                # We swallow any python exception because otherwise it would
                # crash the calling library
                print("Exception in observer callback: " + str(e))

        if isinstance(layer, Layer):
            # Verify the layer belongs to the model
            layer_obj = model.get_layer(layer.name)
            if layer_obj != layer:
                layer_obj = None
        else:
            # If we were passed a string or a layer id, get a layer object
            layer_obj = model.get_layer(layer)
        if layer_obj is None:
            raise ValueError(f"No layer '{layer}' in the model")
        self._id = model.backend.register_observer(layer_obj, observer)

    def __del__(self):
        # ``__init__`` may have raised before registration completed (e.g. the
        # ValueError for an unknown layer). In that case ``self._id`` was never
        # set and there is nothing to unregister.
        if not hasattr(self, "_id"):
            return
        self._model.backend.unregister_observer(self._id)

    def __repr__(self):
        data = "<akida.Observer, id=" + str(self._id) + ",\n"
        data += "model=" + str(self._model) + ",\n"
        data += "layer=" + str(self._layer) + ">"
        return data

    @property
    def spikes(self):
        """Get generated spikes.

        Returns a dictionary of spikes generated by the attached layer indexed
        by their source id. A value is None when the layer produced no
        activation for that source.

        Returns:
            a dictionary of :obj:`Sparse` objects of shape (w, h, c).

        """
        return self._spikes

    @property
    def potentials(self):
        """Get generated potentials.

        Returns a dictionary of potentials generated by the attached layer
        indexed by their source id. A value is None when the backend reports
        no potentials (e.g. for InputData layers).

        Returns:
            a dictionary of `numpy.ndarray` objects of shape (w, h, c).

        """
        return self._potentials

    def clear(self):
        """Clear the recorded spikes and potentials.

        Fresh dictionaries are bound, so dictionaries previously returned by
        ``spikes`` / ``potentials`` keep their content.

        """
        self._spikes = {}
        self._potentials = {}
