# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date:   2017-10-17 16:47:32
# @Last Modified by:   Jie Yang,     Contact: jieynlp@gmail.com
# @Last Modified time: 2018-10-18 11:12:13
from __future__ import print_function
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np

class CharBiGRU(nn.Module):
    """Character-level GRU encoder over a batch of char-id sequences (one word per row).

    When ``bidirect_flag`` is set, the requested ``hidden_dim`` is split evenly
    across the two directions. The concatenated output width is then
    ``2 * (hidden_dim // 2)``.
    """

    def __init__(self, alphabet_size, pretrain_char_embedding, embedding_dim, hidden_dim, dropout, device, bidirect_flag = True):
        """Build the embedding table, dropout and single-layer GRU.

        :param alphabet_size: number of rows in the character embedding table.
        :param pretrain_char_embedding: numpy array copied into the
            (alphabet_size, embedding_dim) embedding weight (must be
            shape-compatible), or None for uniform random initialisation.
        :param embedding_dim: size of each character embedding.
        :param hidden_dim: total output width. It is halved per direction
            when ``bidirect_flag`` is True, and ``self.hidden_dim`` stores
            the per-direction size.
        :param dropout: dropout probability applied to the char embeddings.
        :param device: target passed to ``.to()`` for every submodule.
        :param bidirect_flag: run the GRU in both directions.
        """
        super(CharBiGRU, self).__init__()
        self.hidden_dim = hidden_dim
        if bidirect_flag:
            # Split the budget so the concatenated output is ~hidden_dim wide.
            self.hidden_dim = hidden_dim // 2
        self.char_drop = nn.Dropout(dropout).to(device)
        self.char_embeddings = nn.Embedding(alphabet_size, embedding_dim).to(device)
        if pretrain_char_embedding is not None:
            init_weights = pretrain_char_embedding
        else:
            init_weights = self.random_embedding(alphabet_size, embedding_dim)
        self.char_embeddings.weight.data.copy_(torch.from_numpy(init_weights))
        # Named ``char_lstm`` although it is a GRU. The attribute name is kept
        # because it determines the state_dict keys of saved checkpoints.
        self.char_lstm = nn.GRU(embedding_dim, self.hidden_dim, num_layers=1, batch_first=True, bidirectional=bidirect_flag).to(device)

    def random_embedding(self, vocab_size, embedding_dim):
        """Return a (vocab_size, embedding_dim) float64 array drawn from U(-s, s).

        The bound is ``s = sqrt(3 / embedding_dim)``.

        :param vocab_size: number of rows.
        :param embedding_dim: number of columns.
        :return: numpy array sampled from the global NumPy RNG.
        """
        scale = np.sqrt(3.0 / embedding_dim)
        # A single vectorised draw consumes the global NumPy RNG in the same
        # row-major order as a per-row loop, so seeded results are unchanged.
        return np.random.uniform(-scale, scale, [vocab_size, embedding_dim])

    @staticmethod
    def _prepare_lengths(seq_lengths):
        """Return ``seq_lengths`` as a flat 1-D CPU int64 tensor.

        ``pack_padded_sequence`` rejects CUDA length tensors (torch >= 1.7)
        and non-1-D lengths. This normalises lists, numpy arrays and tensors,
        including a (batch_size, 1) column.
        """
        if torch.is_tensor(seq_lengths):
            lengths = seq_lengths.detach().cpu()
        else:
            lengths = torch.as_tensor(np.asarray(seq_lengths))
        return lengths.reshape(-1).long()

    def _run_gru(self, input, seq_lengths):
        """Embed, apply dropout, pack and run the GRU.

        :return: ``(packed_output, final_hidden)`` exactly as returned by
            ``self.char_lstm``.
        """
        char_embeds = self.char_drop(self.char_embeddings(input))
        pack_input = pack_padded_sequence(char_embeds, self._prepare_lengths(seq_lengths), batch_first=True)
        return self.char_lstm(pack_input, None)

    def get_last_hiddens(self, input, seq_lengths):
        """Encode each word and return its final GRU state(s).

        :param input: tensor of char ids, shape (batch_size, word_length).
        :param seq_lengths: true length of each row, every length >= 1.
            May be a list, a numpy array or a tensor (CPU or GPU); a
            (batch_size, 1) column is also accepted. Rows must be sorted by
            decreasing length.
        :return: tensor of shape (batch_size, num_directions * self.hidden_dim)
            that concatenates the last hidden state of each direction.
        """
        batch_size = input.size(0)
        _, char_hidden = self._run_gru(input, seq_lengths)
        # (num_directions, batch, hidden) -> (batch, num_directions * hidden)
        return char_hidden.transpose(1, 0).contiguous().view(batch_size, -1)

    def get_all_hiddens(self, input, seq_lengths):
        """Encode each word and return the GRU output at every time step.

        :param input: tensor of char ids, shape (batch_size, word_length).
        :param seq_lengths: true length of each row, with the same accepted
            forms and ordering requirement as ``get_last_hiddens``.
        :return: tensor of shape
            (batch_size, max(seq_lengths), num_directions * self.hidden_dim).
            Positions past a row's length are zero-padded.
        """
        packed_out, _ = self._run_gru(input, seq_lengths)
        char_rnn_out, _ = pad_packed_sequence(packed_out, batch_first=True)
        return char_rnn_out

    def forward(self, input, seq_lengths):
        """Alias for ``get_all_hiddens``: per-step outputs for every word."""
        return self.get_all_hiddens(input, seq_lengths)
