# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/models_components__autocorrelation.ipynb (unless otherwise specified).

# Names exported by a star-import of this module.
# NOTE(review): TemporalAutoBasis and TemporalCNNAutoBasis are defined further down in this
# file but are not listed here; the file header marks this module as generated from a
# notebook, so confirm whether the omission is intended before relying on star-imports.
__all__ = ['AutoCorrelation', 'AutoCorrelationLayer']

# Cell
import math

import torch
import torch.nn as nn

# Cell
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    ``forward`` expects ``queries`` of shape (B, L, H, E) and ``values`` of shape
    (B, S, H, D); ``keys`` are padded/truncated together with ``values`` so that both
    end up with length L. ``mask_flag``, ``scale`` and the dropout module are stored
    on the instance but are not used by any method of this class, and ``attn_mask``
    is accepted by ``forward`` only for interface compatibility.
    """
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _top_k(self, length, n_lags):
        """
        Number of delays to aggregate: ``int(factor * ln(length))`` clamped to ``[1, n_lags]``.

        The lower clamp avoids an empty selection for very short sequences (ln(2) < 1),
        the upper clamp keeps ``torch.topk`` from being asked for more lags than exist.
        """
        return max(1, min(n_lags, int(self.factor * math.log(length))))

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        The top-k delays are chosen from the correlation averaged over batch, heads and
        channels, so every sample in the batch is rolled by the same delays; only the
        softmax weights are per-sample.

        values: 4-D tensor, last dim is the time axis (forward passes (B, H, D, L)).
        corr:   4-D tensor, last dim indexes the lag (forward passes (B, H, E, L)).
        Returns a tensor shaped like ``values``.
        """
        length = values.shape[3]
        # find top k
        top_k = self._top_k(length, corr.shape[-1])
        mean_value = corr.mean(dim=1).mean(dim=1)  # (B, n_lags)
        index = torch.topk(mean_value.mean(dim=0), top_k, dim=-1)[1]
        weights = mean_value[:, index]  # (B, top_k)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values, dtype=torch.float)
        # tolist() converts all delays in one go instead of one int() conversion per iteration.
        for i, shift in enumerate(index.tolist()):
            pattern = torch.roll(values, -shift, -1)
            # (B, 1, 1, 1) weight broadcasts over heads, channels and time.
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Unlike the training variant, the top-k delays are selected per sample (from the
        correlation averaged over heads and channels) and applied with a gather on a
        time-doubled copy of ``values``, which is equivalent to a circular shift.

        values: 4-D tensor, last dim is the time axis (forward passes (B, H, D, L)).
        corr:   4-D tensor, last dim indexes the lag (forward passes (B, H, E, L)).
        Returns a tensor shaped like ``values``.
        """
        batch, head, channel, length = values.shape
        # index init
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length)
        # find top k
        top_k = self._top_k(length, corr.shape[-1])
        mean_value = corr.mean(dim=1).mean(dim=1)  # (B, n_lags)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)  # doubled along time so t + delay never wraps
        delays_agg = torch.zeros_like(values, dtype=torch.float)
        for i in range(top_k):
            tmp_delay = (init_index + delay[:, i, None, None, None]).expand(batch, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Delays and weights are selected independently for every (batch, head, channel)
        row of ``corr``. This variant is not called by ``forward``.

        values: 4-D tensor, last dim is the time axis.
        corr:   4-D tensor broadcastable against ``values`` on the first three dims,
                last dim indexes the lag.
        Returns a tensor shaped like ``values``.
        """
        batch, head, channel, length = values.shape
        # index init
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length)
        # find top k
        top_k = self._top_k(length, corr.shape[-1])
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)  # doubled along time so t + delay never wraps
        delays_agg = torch.zeros_like(values, dtype=torch.float)
        for i in range(top_k):
            tmp_delay = (init_index + delay[..., i].unsqueeze(-1)).expand(batch, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """
        Compute the auto-correlation between ``queries`` and ``keys`` and aggregate
        time-delayed copies of ``values``.

        queries: (B, L, H, E); keys / values: length S on dim 1, values is (B, S, H, D).
        attn_mask: unused.
        Returns ``(V, corr)`` where ``V`` is (B, L, H, D) and ``corr`` is the
        correlation permuted to (B, L, H, E) when ``output_attention`` is set, else ``None``.
        """
        _, L, _, _ = queries.shape
        _, S, _, _ = values.shape
        if L > S:
            # Pad each tensor with zeros of its own trailing shape: keys and values may
            # use different per-head widths, so one shared pad block is not valid.
            pad = L - S
            values = torch.cat([values, values.new_zeros(values.shape[0], pad, *values.shape[2:])], dim=1)
            keys = torch.cat([keys, keys.new_zeros(keys.shape[0], pad, *keys.shape[2:])], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        # n=L is required: without it irfft returns an even length, i.e. L - 1 lags for odd L.
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        time_major_values = values.permute(0, 2, 3, 1).contiguous()
        if self.training:
            V = self.time_delay_agg_training(time_major_values, corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(time_major_values, corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)


class AutoCorrelationLayer(nn.Module):
    """
    Multi-head wrapper around an inner correlation callable.

    Inputs are linearly projected, split into ``n_heads`` heads, handed to
    ``correlation(queries, keys, values, attn_mask)`` and the first element of its
    result is merged back and projected to ``d_model``. When ``d_keys`` / ``d_values``
    are falsy they default to ``d_model // n_heads``.
    """
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super().__init__()

        key_width = d_keys if d_keys else d_model // n_heads
        value_width = d_values if d_values else d_model // n_heads

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, key_width * n_heads)
        self.key_projection = nn.Linear(d_model, key_width * n_heads)
        self.value_projection = nn.Linear(d_model, value_width * n_heads)
        self.out_projection = nn.Linear(value_width * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """
        queries: (B, L, d_model); keys: (B, S, d_model); values are reshaped with the
        same (B, S) as keys. Returns ``(output, attn)`` where ``output`` is
        (B, L, d_model) and ``attn`` is whatever the inner correlation returned second.
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Project, then split the feature axis into (heads, per-head width).
        q_heads = self.query_projection(queries).view(batch, q_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, kv_len, heads, -1)

        mixed, attn = self.inner_correlation(q_heads, k_heads, v_heads, attn_mask)

        # Merge the heads back into one feature axis before the output projection.
        merged = mixed.view(batch, q_len, -1)
        return self.out_projection(merged), attn

class TemporalAutoBasis(nn.Module):
    """
    Basis module that maps an input of shape (B, n_dim, hidden_dim) to a
    (backcast, forecast) pair of shapes (B, n_dim, dim_in) and (B, n_dim, dim_out).

    Pipeline: a linear layer resizes the last axis to ``dim_in``; three kernel-size-1
    convolutions lift the ``n_dim`` channels to ``hidden_dim`` to form queries, keys and
    values; an AutoCorrelationLayer mixes them; a linear decoder maps back to ``n_dim``
    channels; after a ReLU a final linear layer expands the time axis to
    ``dim_in + dim_out``, which is then split into backcast and forecast.

    ``bias`` controls the bias of the three convolutions only. ``drop`` is forwarded as
    ``attention_dropout`` to the inner AutoCorrelation.
    """
    def __init__(self, dim_in, dim_out, n_dim, bias=False, factor=1, hidden_dim=256, n_heads=8, drop=0):
        super(TemporalAutoBasis, self).__init__()
        self.auto = AutoCorrelationLayer(AutoCorrelation(False, factor=factor, attention_dropout=drop, output_attention=False), d_model=hidden_dim, n_heads=n_heads)
        self.conv_q = nn.Conv1d(in_channels=n_dim, out_channels=hidden_dim, kernel_size=(1,), padding=0, bias=bias)
        self.conv_k = nn.Conv1d(in_channels=n_dim, out_channels=hidden_dim, kernel_size=(1,), padding=0, bias=bias)
        self.conv_v = nn.Conv1d(in_channels=n_dim, out_channels=hidden_dim, kernel_size=(1,), padding=0, bias=bias)
        self.re = nn.Linear(hidden_dim, dim_in)
        self.decoder = nn.Linear(hidden_dim, n_dim)
        self.predict_linear = nn.Linear(dim_in, dim_in + dim_out)
        self.relu = nn.ReLU()

        self.dim_in = dim_in
        self.dim_out = dim_out
        self.n_dim = n_dim

    def forward(self, x, insample_x_t=None, outsample_x_t=None):
        """
        x: (B, n_dim, hidden_dim). ``insample_x_t`` and ``outsample_x_t`` are accepted
        but not used. Returns ``(backcast, forecast)``.

        Raises ValueError when ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D input (batch, n_dim, hidden_dim), got shape {tuple(x.shape)}"
            )
        x = self.re(x)  # (B, n_dim, dim_in)
        # Convolutions give (B, hidden_dim, dim_in); the attention layer wants time on dim 1.
        q = self.conv_q(x).transpose(1, 2)
        k = self.conv_k(x).transpose(1, 2)
        v = self.conv_v(x).transpose(1, 2)
        att, _ = self.auto(q, k, v, None)
        att = self.decoder(att).transpose(1, 2)  # back to (B, n_dim, dim_in)
        att = self.relu(att)
        att = self.predict_linear(att)  # (B, n_dim, dim_in + dim_out)

        backcast = att[:, :, :self.dim_in]
        # Slice from dim_in rather than with -dim_out: a negative-zero start would
        # select the whole axis when dim_out == 0.
        forecast = att[:, :, self.dim_in:]
        return backcast, forecast

class TemporalCNNAutoBasis(nn.Module):
    """
    Basis module that maps an input of shape (B, n_dim, hidden_dim) to a
    (backcast, forecast) pair of shapes (B, n_dim, dim_in) and (B, n_dim, dim_out).

    Pipeline: a linear layer resizes the last axis to ``dim_in``; three kernel-size-1
    convolutions lift the ``n_dim`` channels to ``hidden_dim`` to form queries, keys and
    values; an AutoCorrelationLayer mixes them; a linear decoder maps back to ``n_dim``
    channels; after a ReLU a final linear layer expands the time axis to
    ``dim_in + dim_out``, which is then split into backcast and forecast.

    ``bias`` controls the bias of the three convolutions only. ``drop`` is forwarded as
    ``attention_dropout`` to the inner AutoCorrelation.
    """
    def __init__(self, dim_in, dim_out, n_dim, bias=False, factor=1, hidden_dim=256, n_heads=8, drop=0):
        super(TemporalCNNAutoBasis, self).__init__()
        self.auto = AutoCorrelationLayer(AutoCorrelation(False, factor=factor, attention_dropout=drop, output_attention=False), d_model=hidden_dim, n_heads=n_heads)
        self.conv_q = nn.Conv1d(in_channels=n_dim, out_channels=hidden_dim, kernel_size=(1,), padding=0, bias=bias)
        self.conv_k = nn.Conv1d(in_channels=n_dim, out_channels=hidden_dim, kernel_size=(1,), padding=0, bias=bias)
        self.conv_v = nn.Conv1d(in_channels=n_dim, out_channels=hidden_dim, kernel_size=(1,), padding=0, bias=bias)
        self.re = nn.Linear(hidden_dim, dim_in)
        self.decoder = nn.Linear(hidden_dim, n_dim)
        self.predict_linear = nn.Linear(dim_in, dim_in + dim_out)
        self.relu = nn.ReLU()

        self.dim_in = dim_in
        self.dim_out = dim_out
        self.n_dim = n_dim

    def forward(self, x, insample_x_t=None, outsample_x_t=None):
        """
        x: (B, n_dim, hidden_dim). ``insample_x_t`` and ``outsample_x_t`` are accepted
        but not used. Returns ``(backcast, forecast)``.

        Raises ValueError when ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D input (batch, n_dim, hidden_dim), got shape {tuple(x.shape)}"
            )
        x = self.re(x)  # (B, n_dim, dim_in)
        # Convolutions give (B, hidden_dim, dim_in); the attention layer wants time on dim 1.
        q = self.conv_q(x).transpose(1, 2)
        k = self.conv_k(x).transpose(1, 2)
        v = self.conv_v(x).transpose(1, 2)
        att, _ = self.auto(q, k, v, None)
        att = self.decoder(att).transpose(1, 2)  # back to (B, n_dim, dim_in)
        att = self.relu(att)
        att = self.predict_linear(att)  # (B, n_dim, dim_in + dim_out)

        backcast = att[:, :, :self.dim_in]
        # Slice from dim_in rather than with -dim_out: a negative-zero start would
        # select the whole axis when dim_out == 0.
        forecast = att[:, :, self.dim_in:]
        return backcast, forecast



if __name__ == '__main__':
    # Smoke test: push one random batch of shape (64, 7, 256) through TemporalAutoBasis
    # and print the sizes of the returned backcast and forecast tensors.
    model = TemporalAutoBasis(dim_in=96, dim_out=96, n_dim=7)
    sample = torch.rand(64, 7, 256)  # named `sample` so the builtin `input` is not shadowed
    backcast, forecast = model(sample)
    print(backcast.size())
    print(forecast.size())
