import tensorflow as tf
import numpy as np
from unittest import TestCase

from kolibri.synthetic_data.ctgan.utils import get_test_variables
from kolibri.synthetic_data.ctgan.layers import GenActivation


class TestGenActLayer(TestCase):
    """Unit tests for the CTGAN generator activation layer ``GenActivation``."""

    def setUp(self):
        self._vars = get_test_variables()
        # Column spec handed to GenActivation. Judging by the reference
        # activations in ``test_gen_act_layer``, each entry presumably reads
        # [start, end, kind]: columns 0:1 get tanh, columns 1:output_dim get a
        # Gumbel-softmax. NOTE(review): confirm against GenActivation.call.
        self._output_tensor = [
            tf.constant([0, 1, 0], dtype=tf.int32),
            tf.constant([1, self._vars['output_dim'], 1], dtype=tf.int32)
        ]

    def tearDown(self):
        del self._vars
        del self._output_tensor

    def _assert_rows_are_distributions(self, values):
        """Assert every row of ``values`` lies in [0, 1] and sums to 1.

        Tolerances follow ``self._vars['decimal']`` like the other checks.
        """
        values = np.asarray(values)
        decimal = self._vars['decimal']
        tolerance = 10.0 ** -decimal
        self.assertTrue(np.all(values >= -tolerance))
        self.assertTrue(np.all(values <= 1.0 + tolerance))
        np.testing.assert_almost_equal(
            values.sum(axis=1), np.ones(values.shape[0]), decimal=decimal)

    def _assert_rows_are_one_hot(self, values):
        """Assert every row of ``values`` is (numerically) a one-hot vector."""
        values = np.asarray(values)
        rounded = np.round(values)
        # A hard sample may carry float rounding noise (e.g. if it is built
        # as ``hard - soft + soft``), so compare to the nearest 0/1 first.
        np.testing.assert_almost_equal(
            values, rounded, decimal=self._vars['decimal'])
        self.assertTrue(np.all((rounded == 0.0) | (rounded == 1.0)))
        np.testing.assert_array_equal(
            rounded.sum(axis=1), np.ones(values.shape[0]))

    def test_gumbel_softmax(self):
        tf.random.set_seed(0)
        inputs = tf.random.uniform(
            [self._vars['batch_size'], self._vars['input_dim']])
        gen_act_layer = GenActivation(
            self._vars['input_dim'],
            self._vars['output_dim'],
            self._output_tensor,
            self._vars['tau'])

        # --- Soft sample -------------------------------------------------
        outputs = gen_act_layer._gumbel_softmax(inputs, tau=self._vars['tau'])
        # Reference values from a seeded run. Only their shape is asserted:
        # the exact values depend on the sampled Gumbel noise.
        # NOTE(review): these may not reproduce across TF versions/backends
        # (the previous revision only printed them) -- restore an exact
        # assert_almost_equal once reproducibility is confirmed.
        expected_outputs = tf.constant(
            [[2.5203612e-08, 5.1191260e-09, 6.2580289e-06, 8.8759303e-02,
              8.6855674e-05, 1.3614854e-01, 1.5237300e-05, 2.3315281e-03,
              7.7256054e-01, 9.1768146e-05],
             [2.5267678e-07, 1.2717593e-09, 3.5713683e-08, 2.3282080e-10,
              6.3239231e-10, 1.6829775e-08, 9.9999976e-01, 1.3369175e-08,
              1.0020080e-08, 4.1065191e-09],
             [9.9970883e-01, 7.0147536e-08, 1.4887677e-04, 3.9260326e-06,
              1.3550435e-04, 6.5569994e-09, 3.2505940e-07, 2.2628283e-06,
              1.5529440e-10, 1.2366169e-07],
             [7.1061681e-05, 7.2914936e-09, 8.0928497e-04, 9.9911875e-01,
              1.2028343e-07, 7.6647319e-07, 1.2323967e-08, 4.7224105e-09,
              6.2253469e-09, 1.0437786e-09],
             [3.0685652e-07, 1.9758712e-01, 7.8216940e-01, 1.2924423e-03,
              1.2296512e-03, 1.6724426e-02, 1.3887612e-05, 7.7307220e-07,
              9.7094814e-04, 1.1089712e-05],
             [1.3105995e-06, 1.0000190e-02, 9.8894638e-01, 1.0700265e-07,
              9.9610642e-04, 6.7987878e-09, 3.2664647e-09, 5.5520373e-05,
              3.6499455e-07, 1.2366691e-07],
             [5.7631564e-06, 2.3302120e-04, 9.9975842e-01, 3.4796191e-10,
              8.5701252e-10, 1.9418301e-06, 1.2495369e-08, 1.8805335e-09,
              1.4094038e-08, 8.5269869e-07],
             [8.7884471e-02, 4.9140328e-01, 1.9537224e-01, 1.0472466e-04,
              2.2511473e-01, 5.8175503e-10, 9.1088725e-05, 1.3922374e-05,
              2.3651278e-06, 1.3103404e-05],
             [6.6259997e-07, 2.5967475e-10, 4.7318011e-10, 3.2092371e-09,
              3.9947822e-13, 9.9999928e-01, 1.4720424e-12, 4.2286508e-10,
              7.9255608e-10, 1.4511153e-12],
             [3.1757179e-05, 1.8895147e-07, 5.4502758e-09, 1.7631198e-03,
              9.2034124e-06, 8.0049807e-01, 9.7241864e-05, 1.9759937e-01,
              1.1126443e-06, 6.7034933e-09]],
            dtype=tf.float32)
        soft = outputs.numpy()
        self.assertEqual(soft.shape, expected_outputs.numpy().shape)
        self._assert_rows_are_distributions(soft)

        # --- Hard (one-hot) sample ---------------------------------------
        outputs = gen_act_layer._gumbel_softmax(
            inputs, tau=self._vars['tau'], hard=True)
        # Reference one-hot rows from the same seeded run; see the note above.
        expected_outputs = tf.constant(
            [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
             [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
             [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
             [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
             [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]], dtype=tf.float32)
        hard = outputs.numpy()
        self.assertEqual(hard.shape, expected_outputs.numpy().shape)
        self._assert_rows_are_one_hot(hard)

    def test_gen_act_layer(self):
        tf.random.set_seed(0)
        inputs = tf.random.uniform(
            [self._vars['batch_size'], self._vars['input_dim']])
        gen_act_layer = GenActivation(
            self._vars['input_dim'],
            self._vars['output_dim'],
            self._output_tensor,
            self._vars['tau'])

        outputs, outputs_act = gen_act_layer(inputs)
        expected_outputs = tf.constant(
            [[0.06399402, -0.15808228, -0.11558336, 0.05494805, -0.33126852,
              -0.17637017, 0.12200963, 0.39402956, 0.32634318, -0.33070335],
             [0.25087643, -0.19576189, -0.23913635, 0.15009138, -0.3564161,
              -0.21825123, -0.14072102, 0.6937534, 0.24022676, -0.45597944],
             [0.6220621, 0.1469637, -0.18215235, 0.5548569, -0.0987303,
              0.01361492, 0.06346679, 0.3133853, 0.06415256, -0.6227628],
             [-0.10288254, -0.34491983, -0.1632171, -0.08494198, -0.4720165,
              -0.45313528, 0.10734999, 0.4955495, 0.34978482, -0.15236723],
             [0.20744193, 0.10826425, -0.32170045, 0.3576493, -0.44370794,
              0.22804666, -0.10140005, 0.2980532, -0.36202985, -0.46871606],
             [0.2961299, 0.27615362, -0.02876888, 0.39636207, -0.09771496,
              0.13101867, 0.19671272, 0.31471086, 0.11839052, -0.18321344],
             [0.71194685, -0.10758564, -0.14955036, 0.64336705, -0.34915873,
              -0.00202614, -0.26932833, 0.69632256, -0.03363119, -0.7745644],
             [0.07209513, 0.0680066, -0.678686, 0.36803997, -0.25703406,
              0.08380148, -0.12455904, 0.17812183, 0.10595389, -0.2080901],
             [0.44230077, 0.2303819, -0.3798572, 0.3656776, -0.15100443,
              0.18006304, -0.18976846, 0.5092122, 0.00852871, -0.5495614],
             [0.28090134, -0.11573267, 0.08267065, 0.24265254, -0.28548294,
              -0.32438374, -0.20357645, 0.4609607, -0.08798116, -0.3462641]],
            dtype=tf.float32)
        # Reference activations. Column 0 equals tanh of column 0 of
        # ``expected_outputs``; the remaining columns are a Gumbel-softmax
        # sample whose exact values depend on random noise.
        expected_outputs_act = tf.constant(
            [[6.39068037e-02, 2.63707761e-07, 1.01531981e-07, 5.62491841e-05,
              1.01640113e-01, 4.44595469e-04, 4.38252151e-01, 9.21413826e-04,
              8.05049948e-03, 4.50634688e-01],
             [2.45742306e-01, 4.96961761e-08, 6.37695408e-08, 2.84826851e-09,
              9.97520555e-09, 3.66198682e-10, 1.10903176e-09, 6.71532121e-07,
              9.99999166e-01, 1.45172085e-09],
             [5.52562118e-01, 3.02950931e-09, 1.01592095e-08, 9.99993443e-01,
              1.00154507e-08, 1.06622167e-06, 2.96606260e-08, 5.48157823e-06,
              1.59207494e-10, 8.31109279e-11],
             [-1.02521062e-01, 1.94926830e-08, 3.26989998e-12, 1.15498361e-10,
              3.92803088e-07, 2.20820046e-11, 4.47764760e-05, 9.99954820e-01,
              4.19329673e-08, 1.05896170e-09],
             [2.04516679e-01, 2.63936181e-06, 1.62419127e-08, 8.94743209e-07,
              1.30295604e-07, 4.11487093e-07, 4.00995556e-03, 9.95974243e-01,
              1.10770225e-05, 6.07314917e-07],
             [2.87766963e-01, 2.93322868e-04, 1.01649192e-07, 3.00192449e-09,
              3.78971066e-07, 4.04820044e-09, 4.23168694e-06, 6.44652843e-01,
              3.55048865e-01, 2.33219822e-07],
             [6.11896217e-01, 2.67488271e-04, 5.80619552e-09, 3.99584849e-07,
              1.00834231e-05, 1.85956912e-06, 1.90315816e-07, 9.19687911e-04,
              1.23315968e-03, 9.97567177e-01],
             [7.19704702e-02, 5.38084987e-06, 1.43931715e-07, 3.66261043e-02,
              3.71713122e-06, 5.51273843e-05, 3.10739444e-04, 1.17430696e-02,
              6.52445734e-01, 2.98810005e-01],
             [4.15549695e-01, 5.89263905e-03, 1.92697556e-07, 9.93969083e-01,
              6.02961639e-11, 6.03810549e-05, 3.23741119e-06, 1.66977500e-07,
              4.29663078e-06, 7.00083183e-05],
             [2.73739070e-01, 4.73451070e-11, 1.36624198e-10, 3.52620644e-09,
              3.20546270e-14, 1.00000000e+00, 1.74026958e-13, 2.13664126e-08,
              3.58262253e-10, 2.74014925e-12]],
            dtype=tf.float32)

        # The dense projection is deterministic under the fixed seed.
        np.testing.assert_almost_equal(
            outputs.numpy(), expected_outputs.numpy(),
            decimal=self._vars['decimal'])

        act = outputs_act.numpy()
        reference_act = expected_outputs_act.numpy()
        self.assertEqual(act.shape, reference_act.shape)

        # Column ranges taken from the spec built in setUp.
        tanh_spec = self._output_tensor[0].numpy()
        softmax_spec = self._output_tensor[1].numpy()
        tanh_cols = slice(int(tanh_spec[0]), int(tanh_spec[1]))
        softmax_cols = slice(int(softmax_spec[0]), int(softmax_spec[1]))

        # The tanh columns depend only on ``outputs``, so they can be
        # compared exactly against the reference.
        np.testing.assert_almost_equal(
            act[:, tanh_cols], reference_act[:, tanh_cols],
            decimal=self._vars['decimal'])
        # The Gumbel-softmax columns are random: only check that every row
        # is a valid probability distribution.
        self._assert_rows_are_distributions(act[:, softmax_cols])
