import tensorflow as tf
import numpy as np
from unittest import TestCase

from kolibri.synthetic_data.ctgan.utils import get_test_variables
from kolibri.synthetic_data.ctgan.layers import ResidualLayer


class TestResidualLayer(TestCase):
    """Regression test comparing ``ResidualLayer`` output to a recorded snapshot."""

    def setUp(self):
        """Fetch the shared test configuration from ``get_test_variables``."""
        self._vars = get_test_variables()

    def tearDown(self):
        """Discard the per-test configuration."""
        del self._vars

    def test_residual_layer(self):
        """Run a seeded ``ResidualLayer`` forward pass and match it to a snapshot.

        Reads ``batch_size``, ``input_dim``, ``output_dim`` and ``decimal``
        from the test configuration. The snapshot literal below is 10 x 20.
        NOTE(review): presumably this means batch_size == 10 and the layer
        emits 20 features for the configured dims. Confirm against
        ``get_test_variables``.
        """
        config = self._vars
        batch_size = config['batch_size']
        input_dim = config['input_dim']
        output_dim = config['output_dim']

        # Fix the global seed before any random op. The uniform inputs (and,
        # presumably, the layer's initialisation) must then be created in
        # exactly this order to reproduce the snapshot.
        tf.random.set_seed(0)
        inputs = tf.random.uniform([batch_size, input_dim])
        layer = ResidualLayer(input_dim, output_dim)
        actual = layer(inputs).numpy()

        # Recorded float32 output for the seeded run above.
        expected = np.array(
            [[6.39936924e-02, 0.00000000e+00, 0.00000000e+00, 5.49477674e-02,
              0.00000000e+00, 0.00000000e+00, 1.22009017e-01, 3.94027561e-01,
              3.26341540e-01, 0.00000000e+00, 2.91975141e-01, 2.06566453e-01,
              5.35390735e-01, 5.61257482e-01, 4.16674495e-01, 8.07827950e-01,
              4.93225098e-01, 9.98129249e-01, 6.96735144e-01, 1.25373602e-01],
             [2.50875145e-01, 0.00000000e+00, 0.00000000e+00, 1.50090620e-01,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 6.93749905e-01,
              2.40225554e-01, 0.00000000e+00, 7.09816694e-01, 6.62415624e-01,
              5.72256565e-01, 3.64753485e-01, 4.20518279e-01, 6.30056977e-01,
              9.13812995e-01, 6.61647201e-01, 8.33473563e-01, 8.39580297e-02],
             [6.22058928e-01, 1.46962956e-01, 0.00000000e+00, 5.54854095e-01,
              0.00000000e+00, 1.36148538e-02, 6.34664744e-02, 3.13383728e-01,
              6.41522408e-02, 0.00000000e+00, 2.79759407e-01, 1.55231953e-02,
              7.26373553e-01, 7.65538692e-01, 6.79866672e-01, 5.32727957e-01,
              7.56514072e-01, 4.74219322e-02, 5.03714085e-02, 7.51743436e-01],
             [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 1.07349448e-01, 4.95546997e-01,
              3.49783063e-01, 0.00000000e+00, 1.72712803e-01, 3.11935186e-01,
              2.91373849e-01, 1.00512385e-01, 1.65670753e-01, 7.69665122e-01,
              5.85679770e-01, 9.82009649e-01, 9.14832711e-01, 1.41665339e-01],
             [2.07440868e-01, 1.08263694e-01, 0.00000000e+00, 3.57647479e-01,
              0.00000000e+00, 2.28045493e-01, 0.00000000e+00, 2.98051685e-01,
              0.00000000e+00, 0.00000000e+00, 9.75655317e-02, 6.06278419e-01,
              1.77921772e-01, 5.18051982e-01, 9.82121110e-01, 1.75779462e-01,
              4.56316471e-02, 5.97541451e-01, 5.62954307e-01, 8.05074334e-01],
             [2.96128422e-01, 2.76152223e-01, 0.00000000e+00, 3.96360070e-01,
              0.00000000e+00, 1.31017998e-01, 1.96711719e-01, 3.14709246e-01,
              1.18389919e-01, 0.00000000e+00, 5.83926678e-01, 1.03126526e-01,
              9.44904447e-01, 2.82598138e-01, 5.52851439e-01, 2.79494762e-01,
              7.95048475e-02, 3.89494061e-01, 6.98117018e-02, 4.19343710e-02],
             [7.11943209e-01, 0.00000000e+00, 0.00000000e+00, 6.43363774e-01,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 6.96319044e-01,
              0.00000000e+00, 0.00000000e+00, 9.23898578e-01, 8.75213504e-01,
              4.67960358e-01, 4.69916582e-01, 6.27748251e-01, 7.30225563e-01,
              9.35058832e-01, 3.57794881e-01, 2.06411958e-01, 6.03211164e-01],
             [7.20947608e-02, 6.80062547e-02, 0.00000000e+00, 3.68038088e-01,
              0.00000000e+00, 8.38010535e-02, 0.00000000e+00, 1.78120926e-01,
              1.05953358e-01, 0.00000000e+00, 4.01760459e-01, 5.88148952e-01,
              8.69272232e-01, 8.18386197e-01, 7.30836391e-03, 2.41001368e-01,
              2.00944304e-01, 4.06260490e-02, 9.78007674e-01, 1.70186639e-01],
             [4.42298532e-01, 2.30380744e-01, 0.00000000e+00, 3.65675747e-01,
              0.00000000e+00, 1.80062130e-01, 0.00000000e+00, 5.09209633e-01,
              8.52867030e-03, 0.00000000e+00, 6.97713017e-01, 7.36728072e-01,
              8.43186617e-01, 7.35882163e-01, 7.31128097e-01, 1.87684178e-01,
              7.35530376e-01, 1.88523889e-01, 5.82885981e-01, 3.86672020e-02],
             [2.80899912e-01, 0.00000000e+00, 8.26702341e-02, 2.42651299e-01,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 4.60958362e-01,
              0.00000000e+00, 0.00000000e+00, 6.71606898e-01, 5.86572170e-01,
              8.91327858e-04, 4.74029779e-01, 9.53284144e-01, 7.33208656e-01,
              5.91633201e-01, 3.41272116e-01, 7.69475222e-01, 5.93137741e-02]],
            dtype=np.float32)

        # Tolerance comes from the shared config rather than being hard-coded.
        np.testing.assert_almost_equal(
            actual, expected, decimal=config['decimal'])
