import tensorflow as tf
import numpy as np
from unittest import TestCase

from kolibri.synthetic_data.ctgan.utils import get_test_variables
from kolibri.synthetic_data.ctgan.models import Generator


class TestGenerator(TestCase):
    """Unit tests for the CTGAN ``Generator`` model."""

    # Reference values compared against by ``test_call_model``. They are only
    # valid for the exact random-op sequence used there (``tf.random.set_seed``
    # followed by ``tf.random.uniform`` for the inputs, then constructing and
    # building the generator); reordering those calls changes the random draws.
    # NOTE(review): the literals are 10 x 10, so they assume get_test_variables()
    # yields batch_size == 10 and a 10-wide generator output -- confirm.
    _EXPECTED_OUTPUTS = np.array(
        [[-0.09601411, -0.05729139, 0.19020572, -0.1797896, 0.02286813,
          0.07439799, 0.12995596, 0.0488571, -0.0749413, 0.10059232],
         [-0.18619692, -0.06627501, 0.1862518, -0.2113035, 0.01581551,
          -0.01421695, 0.09985986, 0.03720004, -0.01467223, 0.06440569],
         [-0.13894776, -0.12020721, 0.14429338, -0.09382746, -0.0797215,
          0.04900312, 0.06119634, 0.02027673, -0.18325481, 0.03599714],
         [-0.05591032, -0.07698131, 0.20061807, -0.20924883, 0.03713575,
          0.04930973, 0.10697968, 0.04430075, -0.06608526, 0.13931122],
         [-0.16535681, -0.03096866, 0.2077156, -0.05521879, -0.01992304,
          0.0924558, 0.01926909, 0.12916182, -0.06599644, 0.04315265],
         [-0.2171669, 0.03691592, 0.10808319, -0.09851374, -0.05920613,
          -0.05774612, 0.12822376, 0.02041179, 0.02793016, -0.0140909],
         [-0.23017089, -0.04269604, 0.1598379, -0.08609627, 0.01311586,
          -0.00599252, -0.01298555, 0.03262053, -0.07506532, 0.05415451],
         [-0.19203907, -0.05534588, 0.15888636, -0.16571346, -0.0952314,
          -0.12779173, 0.1111488, 0.02334097, 0.07275349, 0.03363734],
         [-0.21607769, -0.05887552, 0.1601755, -0.17244826, -0.05100354,
          -0.07022924, 0.09414856, 0.03434329, 0.02077717, -0.0324632],
         [-0.16389519, -0.01955811, 0.1468435, -0.22203475, 0.03108735,
          -0.06563827, 0.03660776, 0.05814479, 0.00825485, 0.00584181]],
        dtype=np.float32)

    _EXPECTED_OUTPUTS_ACT = np.array(
        [[-0.09572015, -0.05722879, 0.18794465, -0.17787711, 0.02286414,
          0.07426102, 0.12922926, 0.04881826, -0.07480131, 0.10025439],
         [-0.18407457, -0.06617814, 0.1841276, -0.2082138, 0.01581419,
          -0.01421599, 0.09952924, 0.03718289, -0.01467118, 0.06431677],
         [-0.1380604, -0.11963154, 0.14330022, -0.09355307, -0.07955302,
          0.04896393, 0.06112006, 0.02027395, -0.1812306, 0.0359816],
         [-0.05585213, -0.07682958, 0.19796923, -0.20624737, 0.03711868,
          0.0492698, 0.10657341, 0.04427179, -0.06598922, 0.13841692],
         [-0.163866, -0.03095876, 0.2047789, -0.05516273, -0.0199204,
          0.09219324, 0.0192667, 0.12844831, -0.06590078, 0.04312588],
         [-0.2138161, 0.03689916, 0.10766426, -0.09819627, -0.05913705,
          -0.05768201, 0.12752563, 0.02040895, 0.0279229, -0.01408997],
         [-0.22619049, -0.04267012, 0.15849046, -0.08588416, 0.0131151,
          -0.00599245, -0.01298482, 0.03260897, -0.07492464, 0.05410162],
         [-0.18971263, -0.05528943, 0.15756269, -0.16421305, -0.09494455,
          -0.12710059, 0.11069333, 0.02333673, 0.0726254, 0.03362466],
         [-0.21277645, -0.05880757, 0.15881957, -0.1707589, -0.05095936,
          -0.07011399, 0.09387136, 0.03432979, 0.02077418, -0.03245179],
         [-0.16244327, -0.01955561, 0.14579704, -0.21845654, 0.03107734,
          -0.06554417, 0.03659142, 0.05807935, 0.00825466, 0.00584175]],
        dtype=np.float32)

    def setUp(self):
        """Load the shared test variables and the output-tensor description."""
        self._vars = get_test_variables()
        # Two int32 triples passed straight to Generator.
        # NOTE(review): these presumably describe output column spans
        # (e.g. start, end, type); their meaning is defined by Generator --
        # confirm there.
        self._output_tensor = [
            tf.constant([0, 1, 0], dtype=tf.int32),
            tf.constant([1, self._vars['output_dim'], 0], dtype=tf.int32)
        ]

    def tearDown(self):
        """Drop per-test fixtures so they are not kept alive by the suite."""
        del self._vars
        del self._output_tensor

    def _build_generator(self):
        """Construct a ``Generator`` from the test variables and build it.

        Returns:
            The generator, built for input shape ``(batch_size, input_dim)``.
        """
        generator = Generator(
            self._vars['input_dim'],
            self._vars['layer_dims'],
            self._vars['output_dim'],
            self._output_tensor,
            self._vars['tau'])
        generator.build((self._vars['batch_size'], self._vars['input_dim']))
        return generator

    def test_build_model(self):
        """Building yields one layer per entry of ``layer_dims`` plus one."""
        generator = self._build_generator()
        self.assertIsNotNone(generator)
        # NOTE(review): the extra layer is presumably the output layer.
        self.assertEqual(
            len(generator.layers), len(self._vars['layer_dims']) + 1)

    def test_call_model(self):
        """Both generator outputs match the recorded reference values."""
        # Seed before any random op. The inputs (and presumably the
        # generator's weight initialisation) draw from the seeded RNG, so
        # the order "seed -> inputs -> build" must not change.
        tf.random.set_seed(0)
        inputs = tf.random.uniform(
            [self._vars['batch_size'], self._vars['input_dim']])
        generator = self._build_generator()

        outputs, outputs_act = generator(inputs)

        decimal = self._vars['decimal']
        np.testing.assert_almost_equal(
            outputs.numpy(), self._EXPECTED_OUTPUTS, decimal=decimal)
        np.testing.assert_almost_equal(
            outputs_act.numpy(), self._EXPECTED_OUTPUTS_ACT, decimal=decimal)
