"""Infrastructure layer for representing agent observations

Maintains a synchronized + serialized representation of agent observations in
flat tensors. This allows for fast observation processing as a set of tensor
slices instead of a lengthy traversal over hundreds of game properties.

Synchronization bugs are notoriously difficult to track down: make sure
to follow the correct instantiation protocol, e.g. as used for defining
agent/tile observations, when adding new types of observations to the code"""
import copy
from collections import defaultdict

import marlben
import numpy as np


class DataType:
    """Numpy element types used by the attribute tables in this module"""
    # Default element type of DiscreteTable
    DISCRETE = np.int32
    # Default element type of ContinuousTable
    CONTINUOUS = np.float32


class Index:
    """Two-way lookup between keys and the table rows assigned to them

    Rows are drawn from a pool of free row numbers. The initial pool
    starts at 1, so row 0 is never assigned by a freshly built Index."""

    def __init__(self, prealloc):
        # Forward (key -> row) and reverse (row -> key) maps
        self.index = {}
        self.back = {}
        # Rows 1 .. prealloc - 1 are available for assignment
        self.free = set(range(1, prealloc))

    def full(self):
        """Return True when no free rows remain"""
        return not self.free

    def remove(self, key):
        """Forget key, return its row to the free pool, and return that row

        Raises KeyError if key is not registered."""
        released = self.index.pop(key)
        del self.back[released]
        self.free.add(released)
        return released

    def update(self, key):
        """Return the row for key, assigning an arbitrary free row if new

        Raises KeyError (from set.pop) if key is new and no row is free."""
        if key not in self.index:
            assigned = self.free.pop()
            self.index[key] = assigned
            self.back[assigned] = key

        return self.index[key]

    def get(self, key):
        """Return the row assigned to key (KeyError if unknown)"""
        return self.index[key]

    def teg(self, row):
        """Inverse of get: return the key stored at row (KeyError if unused)"""
        return self.back[row]

    def expand(self, cur, nxt):
        """Add rows cur .. nxt - 1 to the free pool"""
        self.free.update(range(cur, nxt))


class ContinuousTable:
    '''Flat tensor representation for a set of continuous attributes

    Holds one 2D array: one row per indexed object and one column per
    attribute that was accepted by initAttr.'''

    def __init__(self, config, obj, prealloc, dtype=DataType.CONTINUOUS):
        '''Build the column layout from obj and allocate prealloc rows

        Args:
           config: Stored on the instance; not otherwise used by this class
           obj: Iterable yielding ((name,), attr) pairs
           prealloc: Number of rows to allocate up front
           dtype: Element type of the backing array
        '''
        self.config = config
        self.dtype = dtype
        self.cols = {}
        self.nCols = 0
        # Record the allocated row count from the start; it used to come
        # into existence only after the first call to expand()
        self.nRows = prealloc

        for (attribute,), attr in obj:
            self.initAttr(attribute, attr)

        self.data = self.initData(prealloc, self.nCols)

    def initAttr(self, key, attr):
        '''Assign the next free column to key if attr is flagged CONTINUOUS'''
        if attr.CONTINUOUS:
            self.cols[key] = self.nCols
            self.nCols += 1

    def initData(self, nRows, nCols):
        '''Return a zero-filled (nRows, nCols) array of this table's dtype'''
        return np.zeros((nRows, nCols), dtype=self.dtype)

    def update(self, row, attr, val):
        '''Write val into the column of attribute name attr at the given row

        Raises KeyError if attr has no column in this table.'''
        col = self.cols[attr]
        self.data[row, col] = val

    def expand(self, cur, nxt):
        '''Grow the table to nxt rows, keeping the existing contents

        cur is expected to equal the current number of rows; the new
        rows are zero-filled.'''
        data = self.initData(nxt, self.nCols)
        data[:cur] = self.data

        self.data = data
        self.nRows = nxt

    def get(self, rows, pad=None):
        '''Return a copy of the requested rows, optionally zero-padded

        Args:
           rows: Sequence of integer row numbers (may be empty)
           pad: If given, rows of zeros are appended so that the result
              has exactly pad rows. np.pad rejects a negative width, so
              requesting more rows than pad raises ValueError.

        Returns:
           A new array; the table itself is not modified
        '''
        # Callers pass a plain Python list. Comparing a list with 0 gives
        # the scalar False, which turns the masking line below into a
        # silent no-op, so convert to an index array first.
        rows = np.asarray(rows, dtype=np.intp)

        # Integer-array indexing copies, so the writes below stay local
        data = self.data[rows]
        # Entries that requested row 0 always read back as zeros
        data[rows == 0] = 0

        if pad is not None:
            data = np.pad(data, ((0, pad - len(data)), (0, 0)))

        return data


class DiscreteTable(ContinuousTable):
    '''Flat tensor representation for a set of discrete attributes

    Each attribute gets a column plus a flat offset. Stored values are
    shifted by that offset, which is the running sum of the value-range
    sizes of the attributes registered before it.'''

    def __init__(self, config, obj, prealloc, dtype=DataType.DISCRETE):
        # Both must exist before the parent constructor runs, because it
        # calls initAttr for every attribute in obj
        self.discrete = {}
        self.cumsum = 0
        super().__init__(config, obj, prealloc, dtype)

    def initAttr(self, key, attr):
        '''Assign a column and a flat offset to key if attr is DISCRETE'''
        if attr.DISCRETE:
            self.cols[key] = self.nCols

            # Instantiate the attribute to read its min/max bounds
            instance = attr(None, None, 0, config=self.config)
            self.discrete[key] = self.cumsum

            span = instance.max - instance.min + 1
            self.cumsum += span
            self.nCols += 1

    def update(self, row, attr, val):
        '''Store val, shifted by the flat offset of attr, at the given row'''
        column = self.cols[attr]
        offset = self.discrete[attr]
        self.data[row, column] = val + offset


class Grid:
    '''Flat representation of tile/agent positions

    A 2D int32 array holding one value per map cell. A value of 0 is
    treated as an empty cell: zero() writes it and window() skips it.'''

    def __init__(self, R, C):
        '''Allocate an all-zero grid with R rows and C columns'''
        self.data = np.zeros((R, C), dtype=np.int32)

    def zero(self, pos):
        '''Clear the cell at pos, given as an (r, c) pair'''
        r, c = pos
        self.data[r, c] = 0

    def set(self, pos, val):
        '''Store val in the cell at pos, given as an (r, c) pair'''
        r, c = pos
        self.data[r, c] = val

    def move(self, pos, nxt, row):
        '''Clear the cell at pos, then store row in the cell at nxt'''
        self.zero(pos)
        self.set(nxt, row)

    def window(self, rStart, rEnd, cStart, cEnd):
        '''Return the nonzero values inside a rectangular crop

        Bounds are half-open, [rStart, rEnd) x [cStart, cEnd). The part
        of the rectangle that falls outside the grid is ignored.

        Returns:
           A list of the nonzero cell values in row-major order
        '''
        # A negative slice bound means "counted from the end" in Python,
        # so a window overhanging the top/left edge would otherwise
        # produce an empty or wrong crop. Clamp to the grid origin;
        # bounds beyond the far edge are already truncated by slicing.
        rStart, rEnd = max(rStart, 0), max(rEnd, 0)
        cStart, cEnd = max(cStart, 0), max(cEnd, 0)

        crop = self.data[rStart:rEnd, cStart:cEnd].ravel()
        return list(crop[crop != 0])


class GridTables:
    '''Combines a Grid + Index + Continuous and Discrete tables

    The Index assigns a table row to every key, the Grid stores that
    row number at the object's map position, and the two tables hold
    the attribute values for each row. Together they provide a flat
    tensor representation of one class of observations, such as
    agents or tiles'''

    def __init__(self, config, obj, pad, prealloc=1000, expansion=2):
        self.pad = pad
        self.radius = config.NSTIM
        self.expansion = expansion
        self.nRows = prealloc

        side = config.TERRAIN_SIZE
        self.grid = Grid(side, side)
        self.continuous = ContinuousTable(config, obj, prealloc)
        self.discrete = DiscreteTable(config, obj, prealloc)
        self.index = Index(prealloc)

    def get(self, ent, radius=None, entity=False):
        '''Collect the table rows found in a square window around ent

        Args:
           ent: Object with a pos attribute (and entID when entity=True)
           radius: Half-width of the window; defaults to self.radius
           entity: If True, put ent's own row first and also return the
              keys of all rows found

        Returns:
           A dict with 'Continuous' and 'Discrete' arrays padded to
           self.pad rows; with entity=True, a (dict, keys) tuple
        '''
        reach = self.radius if radius is None else radius

        row, col = ent.pos
        center = self.grid.data[row, col]
        assert center != 0

        found = self.grid.window(
            row - reach, row + reach + 1,
            col - reach, col + reach + 1)

        # Self entity first
        if entity:
            found.remove(center)
            found.insert(0, center)

        values = {'Continuous': self.continuous.get(found, self.pad),
                  'Discrete': self.discrete.get(found, self.pad)}

        if not entity:
            return values

        keys = [self.index.teg(item) for item in found]
        assert keys[0] == ent.entID
        return values, keys

    def update(self, obj, val):
        '''Write val for attribute node obj, growing all tables if full'''
        if self.index.full():
            # Grow the index and both tables by the same factor so that
            # their row numbering stays aligned
            previous = self.nRows
            self.nRows = previous * self.expansion

            self.index.expand(previous, self.nRows)
            self.continuous.expand(previous, self.nRows)
            self.discrete.expand(previous, self.nRows)

        name = obj.attr
        slot = self.index.update(obj.key)
        if obj.DISCRETE:
            # Stored relative to the attribute's minimum value
            self.discrete.update(slot, name, val - obj.min)
        if obj.CONTINUOUS:
            self.continuous.update(slot, name, val)

    def move(self, key, pos, nxt):
        '''Relocate key's row number on the grid from pos to nxt'''
        self.grid.move(pos, nxt, self.index.get(key))

    def init(self, key, pos):
        '''Place key's row number on the grid at pos'''
        self.grid.set(pos, self.index.get(key))

    def remove(self, key, pos):
        '''Release key's row in the index and clear the grid cell at pos'''
        self.index.remove(key)
        self.grid.zero(pos)


class Dataframe:
    '''Infrastructure wrapper class

    Holds one GridTables per entry of marlben.Serialized, keyed by the
    entry's name, and forwards updates and queries to the right one.'''

    def __init__(self, config):
        '''Create a GridTables for every serialized object type'''
        self.config, self.data = config, defaultdict(dict)
        for (objKey,), obj in marlben.Serialized:
            self.data[objKey] = GridTables(config, obj, pad=obj.N(config))

    def update(self, node, val):
        '''Write val for attribute node into the tables named by node.obj'''
        self.data[node.obj].update(node, val)

    def remove(self, obj, key, pos):
        '''Remove key, located at pos, from the tables of class obj'''
        self.data[obj.__name__].remove(key, pos)

    def init(self, obj, key, pos):
        '''Place key on the grid of class obj at pos'''
        self.data[obj.__name__].init(key, pos)

    def move(self, obj, key, pos, nxt):
        '''Move key on the grid of class obj from pos to nxt'''
        self.data[obj.__name__].move(key, pos, nxt)

    def get(self, ent):
        '''Build the observation dict for ent

        Side effect: ent.targets is set to the list of entity keys
        found in the window around ent.

        Returns:
           A dict with an 'Entity' entry (tables plus the count 'N' of
           entities found) and a 'Tile' entry with visibility applied
        '''
        stim = {}

        stim['Entity'], ents = self.data['Entity'].get(ent, entity=True)
        stim['Entity']['N'] = np.array([len(ents)], dtype=np.int32)

        ent.targets = ents
        stim['Tile'] = self.__apply_visibility_to_tiles(
            ent, self.data['Tile'].get(ent))

        return stim

    def __apply_visibility_to_tiles(self, ent, tiles):
        '''Return a copy of tiles with color columns cleared

        A tile whose visibility color is not in ent.visible_colors has
        its Index column overwritten with 5. For every tile, the
        VisibilityColor and AccessibilityColor columns are then reset:
        to their flat offset in the discrete table and to 0 in the
        continuous one.

        Args:
           ent: Object exposing a visible_colors container
           tiles: Dict with 2D 'Discrete' and 'Continuous' arrays that
              share the same row order

        Returns:
           A new dict with the same two keys; the input arrays are left
           untouched and shape/dtype are preserved, including for
           inputs with zero rows
        '''
        tile_tables = self.data["Tile"]
        d_cols = tile_tables.discrete.cols
        d_offsets = tile_tables.discrete.discrete
        c_cols = tile_tables.continuous.cols

        visibility_col = d_cols["VisibilityColor"]
        visibility_offset = d_offsets["VisibilityColor"]
        accessibility_col = d_cols["AccessibilityColor"]
        accessibility_offset = d_offsets["AccessibilityColor"]
        visibility_c_col = c_cols["VisibilityColor"]
        accessibility_c_col = c_cols["AccessibilityColor"]
        index_col = d_cols["Index"]
        index_c_col = c_cols["Index"]

        # One copy per table instead of a deepcopy per row
        discrete = np.array(tiles["Discrete"])
        continuous = np.array(tiles["Continuous"])

        # Undo the flat offset to recover the raw color of each tile.
        # Membership is tested with the container's own `in`, so any
        # container type that worked before keeps working.
        colors = discrete[:, visibility_col] - visibility_offset
        hidden = np.array(
            [color not in ent.visible_colors for color in colors],
            dtype=bool)

        # NOTE(review): 5 is presumably the index of a placeholder
        # material shown instead of hidden tiles - confirm against the
        # material definitions
        discrete[hidden, index_col] = 5
        continuous[hidden, index_c_col] = 5.

        discrete[:, visibility_col] = visibility_offset
        discrete[:, accessibility_col] = accessibility_offset
        continuous[:, visibility_c_col] = 0.
        continuous[:, accessibility_c_col] = 0.

        return {"Continuous": continuous, "Discrete": discrete}
