"""Tests for the sql query algorithm."""
import logging
from typing import TYPE_CHECKING, Optional
from unittest.mock import Mock, create_autospec

from _pytest.logging import LogCaptureFixture
import numpy as np
import pandas as pd
from pandas.core.dtypes.common import pandas_dtype
from pytest import fixture
from pytest_mock import MockerFixture

from bitfount.data.datasource import DataSource
from bitfount.data.schema import BitfountSchema
from bitfount.data.types import ContinuousRecord
from bitfount.federated.algorithms.base import (
    _BaseAlgorithm,
    _BaseModellerAlgorithm,
    _BaseWorkerAlgorithm,
)
from bitfount.federated.algorithms.private_sql_query import (
    PrivateSqlQuery,
    _ModellerSide,
    _WorkerSide,
)
from bitfount.federated.modeller import _Modeller
from bitfount.federated.privacy.differential import DPPodConfig
from bitfount.hub import BitfountHub
from tests.utils.helper import TABLE_NAME, create_datasource, create_schema, unit_test


class TestPrivateSqlQuery:
    """Test PrivateSqlQuery algorithm."""

    # Query used by most worker tests: a private mean over column G.
    _AVG_G_QUERY = "SELECT AVG(G) AS AVG_OF_G FROM df.df"
    # Non-private value of AVG(G) on the test datasource; DP results are
    # compared against it.
    _AVG_G_TRUE_VALUE = 0.500129

    @fixture
    def datasource(self) -> DataSource:
        """Fixture for datasource."""
        return create_datasource(classification=True)

    @fixture
    def column_ranges(self) -> dict:
        """Fixture for the column ranges."""
        return {
            "A": {"lower": 1, "upper": 1000},  # will be int (A-D)
            "B": {"lower": 1, "upper": 1000},
            "C": {"lower": 1, "upper": 1000},
            "D": {"lower": 1, "upper": 1000},
            "E": {"lower": 0, "upper": 1},  # will be float (E-H)
            "F": {"lower": 0, "upper": 1},
            "G": {"lower": 0, "upper": 1},
            "H": {"lower": 0, "upper": 1},
            "I": {},  # will be string (I-L)
            "J": {},
            "K": {},
            "L": {},
            "TARGET": {"lower": 0, "upper": 1},  # will be int
        }

    @fixture
    def pod_schema(self) -> BitfountSchema:
        """Fixture for schema."""
        return create_schema(classification=True)

    @staticmethod
    def _make_worker(
        pod_schema: BitfountSchema,
        datasource: DataSource,
        column_ranges: dict,
        query: str,
        epsilon: float = 0.1,
        delta: float = 0.00001,
        pod_identifier: Optional[str] = "test-pod",
        pod_dp: Optional[DPPodConfig] = None,
    ) -> _WorkerSide:
        """Build a worker-side PrivateSqlQuery ready to `run()`.

        The hub is an autospec mock whose `get_pod_schema` returns
        `pod_schema`. The datasource, pod identifier and pod DP config are
        assigned directly instead of calling `initialise`.

        Args:
            pod_schema: Schema returned by the mocked hub.
            datasource: Datasource the query runs against.
            column_ranges: Per-column bounds passed to PrivateSqlQuery.
            query: SQL query to execute.
            epsilon: Requested DP epsilon.
            delta: Requested DP delta.
            pod_identifier: Identifier assigned to the worker (may be None).
            pod_dp: Pod-level DP config; defaults to
                `DPPodConfig(max_epsilon=1, max_target_delta=0.0001)`.

        Returns:
            The configured worker-side algorithm.
        """
        mock_hub = create_autospec(BitfountHub, instance=True)
        mock_hub.get_pod_schema.return_value = pod_schema

        algorithm_factory = PrivateSqlQuery(
            query=query,
            epsilon=epsilon,
            delta=delta,
            column_ranges=column_ranges,
        )
        worker = algorithm_factory.worker(hub=mock_hub)
        worker.datasource = datasource
        worker.pod_identifier = pod_identifier
        worker.pod_dp = (
            pod_dp
            if pod_dp is not None
            else DPPodConfig(max_epsilon=1, max_target_delta=0.0001)
        )
        return worker

    @classmethod
    def _mean_deviation(cls, worker: _WorkerSide, runs: int = 10) -> float:
        """Return the mean absolute deviation of `runs` results from AVG(G)."""
        deviations = [
            abs(cls._AVG_G_TRUE_VALUE - float(worker.run()[1][0]))
            for _ in range(runs)
        ]
        return float(np.mean(deviations))

    @unit_test
    def test_modeller_types(self, column_ranges: dict) -> None:
        """Test modeller method."""
        algorithm_factory = PrivateSqlQuery(
            query="SELECT * from df.df",
            epsilon=0.1,
            delta=0.00001,
            column_ranges=column_ranges,
        )
        algorithm = algorithm_factory.modeller()
        for type_ in [
            _BaseAlgorithm,
            _BaseModellerAlgorithm,
        ]:
            assert isinstance(algorithm, type_)

    @unit_test
    def test_worker_types(self, column_ranges: dict) -> None:
        """Test worker method."""
        algorithm_factory = PrivateSqlQuery(
            query="SELECT * from df.df",
            epsilon=0.1,
            delta=0.00001,
            column_ranges=column_ranges,
        )
        algorithm = algorithm_factory.worker(
            hub=create_autospec(BitfountHub, instance=True)
        )
        for type_ in [
            _BaseAlgorithm,
            _BaseWorkerAlgorithm,
        ]:
            assert isinstance(algorithm, type_)

    @unit_test
    def test_worker_init_datasource(
        self, datasource: DataSource, column_ranges: dict
    ) -> None:
        """Test worker init with datasource."""
        kwargs = {"hub": None}
        algorithm_factory = PrivateSqlQuery(
            query="SELECT * from df.df",
            epsilon=0.1,
            delta=0.00001,
            column_ranges=column_ranges,
            **kwargs
        )
        algorithm_factory.worker(**kwargs).initialise(
            datasource=datasource,
            pod_identifier="test-pod",
            pod_dp=DPPodConfig(max_epsilon=1, max_target_delta=0.0001),
        )

    @unit_test
    def test_worker_init_missingargs(
        self, datasource: DataSource, column_ranges: dict
    ) -> None:
        """Test worker init without all arguments."""
        kwargs = {"hub": None}
        algorithm_factory = PrivateSqlQuery(
            query="SELECT * from df.df",
            epsilon=0.1,
            delta=0.00001,
            column_ranges=column_ranges,
            **kwargs
        )
        algorithm_factory.worker(**kwargs).initialise(datasource=datasource)

    @unit_test
    def test_modeller_init(self, column_ranges: dict) -> None:
        """Test modeller init method."""
        algorithm_factory = PrivateSqlQuery(
            query="SELECT * from df.df",
            epsilon=0.1,
            delta=0.00001,
            column_ranges=column_ranges,
        )
        algorithm_factory.modeller().initialise()

    @unit_test
    def test_bad_sql_no_table(
        self,
        datasource: DataSource,
        column_ranges: dict,
        pod_schema: BitfountSchema,
        caplog: LogCaptureFixture,
    ) -> None:
        """Test that a query without a FROM clause logs an execution error."""
        worker = self._make_worker(
            pod_schema,
            datasource,
            column_ranges,
            query="SELECT MAX(G) AS MAX_OF_G",
        )
        worker.run()
        assert "Error executing PrivateSQL query" in caplog.text

    @unit_test
    def test_no_pod_identifier(
        self,
        datasource: DataSource,
        column_ranges: dict,
        caplog: LogCaptureFixture,
        pod_schema: BitfountSchema,
    ) -> None:
        """Test that having no pod identifier logs an error."""
        worker = self._make_worker(
            pod_schema,
            datasource,
            column_ranges,
            query="SELECT MAX(G) AS MAX_OF_G",
            pod_identifier=None,
        )
        worker.run()
        assert "No pod identifier - cannot get schema" in caplog.text

    @unit_test
    def test_schema(self, column_ranges: dict) -> None:
        """Tests that schema returns parent class."""
        schema_cls = PrivateSqlQuery.get_schema()
        schema = schema_cls()
        factory = schema.recreate_factory(  # type: ignore[attr-defined]  # Reason: test will fail if wrong type  # noqa: B950
            data={
                "query": "SELECT * from df.df",
                "epsilon": 0.1,
                "delta": 0.00001,
                "column_ranges": column_ranges,
            }
        )
        assert isinstance(factory, PrivateSqlQuery)

    @unit_test
    def test_bad_sql_query_statement(
        self,
        datasource: DataSource,
        column_ranges: dict,
        pod_schema: BitfountSchema,
        caplog: LogCaptureFixture,
    ) -> None:
        """Test that a bad operator in SQL query errors out."""
        worker = self._make_worker(
            pod_schema,
            datasource,
            column_ranges,
            query="SELECTOR MAX(G) AS MAX_OF_G from df.df",
        )
        worker.run()
        assert "Error executing PrivateSQL query" in caplog.text

    @unit_test
    def test_bad_sql_query_column(
        self,
        datasource: DataSource,
        column_ranges: dict,
        pod_schema: BitfountSchema,
        caplog: LogCaptureFixture,
    ) -> None:
        """Test that an invalid column in SQL query errors out."""
        worker = self._make_worker(
            pod_schema,
            datasource,
            column_ranges,
            query="SELECT MAX(BITFOUNT_TEST) AS MAX_OF_BITFOUNT_TEST from df.df",
        )
        worker.run()
        assert "Error executing PrivateSQL query:" in caplog.text

    @unit_test
    def test_worker_gets_sql_results(
        self, datasource: DataSource, column_ranges: dict, pod_schema: BitfountSchema
    ) -> None:
        """Test that a SQL query returns a (noised) result."""
        worker = self._make_worker(
            pod_schema, datasource, column_ranges, query=self._AVG_G_QUERY
        )
        results = worker.run()

        assert results is not None
        assert len(results) > 0
        # DP noise means the result should not equal the true value.
        assert float(results[1][0]) != self._AVG_G_TRUE_VALUE

    @unit_test
    def test_schema_mapping_unknown_column(
        self,
        datasource: DataSource,
        column_ranges: dict,
        pod_schema: BitfountSchema,
        caplog: LogCaptureFixture,
    ) -> None:
        """Test that a column range for a column absent from the schema is reported."""  # noqa: B950
        caplog.set_level(logging.INFO)

        # Build a *copy* with an extra column not in the schema, so the
        # fixture dict itself is left untouched.
        column_ranges_bad = {**column_ranges, "Z": {"lower": 0, "upper": 1}}

        worker = self._make_worker(
            pod_schema, datasource, column_ranges_bad, query=self._AVG_G_QUERY
        )
        worker.run()

        # Check that both the mapping check, and SQL reader raise errors
        assert "No field named 'Z' present in the schema" in caplog.text
        assert "got error ['Z']" in caplog.text

    @unit_test
    def test_schema_mapping_unknown_type(
        self,
        datasource: DataSource,
        column_ranges: dict,
        pod_schema: BitfountSchema,
        caplog: LogCaptureFixture,
    ) -> None:
        """Test schema with unsupported type raises error."""
        caplog.set_level(logging.INFO)

        # Unsupported schema
        pod_schema.get_table_schema(TABLE_NAME).features["continuous"][
            "G"
        ] = ContinuousRecord("G", pandas_dtype("datetime64"))

        worker = self._make_worker(
            pod_schema, datasource, column_ranges, query=self._AVG_G_QUERY
        )
        worker.run()

        # Check that both the mapping check, and SQL reader raise errors
        assert "Type datetime64 for column 'G' is not supported" in caplog.text
        assert "must be over numeric or boolean, got string in AVG" in caplog.text

    @unit_test
    def test_modeller_gets_sql_results(self, column_ranges: dict) -> None:
        """Test SQL query returns a result to modeller."""
        algorithm_factory = PrivateSqlQuery(
            query=self._AVG_G_QUERY,
            epsilon=0.1,
            delta=0.00001,
            column_ranges=column_ranges,
        )
        modeller = algorithm_factory.modeller()
        results = pd.DataFrame({"AVG_OF_G": [self._AVG_G_TRUE_VALUE]})
        returned_results = modeller.run(results=[results])
        assert np.isclose(
            returned_results[0].AVG_OF_G[0], results.AVG_OF_G[0], atol=1e-4
        )

    @unit_test
    def test_different_privacy_different_results(
        self, datasource: DataSource, column_ranges: dict, pod_schema: BitfountSchema
    ) -> None:
        """Test different DP levels provide different results."""
        # Good privacy (i.e. low epsilon and delta)
        worker_good = self._make_worker(
            pod_schema, datasource, column_ranges, query=self._AVG_G_QUERY
        )
        mean_deviation_good = self._mean_deviation(worker_good)

        # Bad privacy (i.e. high epsilon and delta)
        worker_bad = self._make_worker(
            pod_schema,
            datasource,
            column_ranges,
            query=self._AVG_G_QUERY,
            epsilon=20.0,
            delta=1.0,
            pod_dp=DPPodConfig(max_epsilon=20, max_target_delta=1.0),
        )
        mean_deviation_bad = self._mean_deviation(worker_bad)

        # Check that deviation is higher with more DP noise
        assert mean_deviation_good > mean_deviation_bad

    @unit_test
    def test_different_privacy_applied(
        self, datasource: DataSource, column_ranges: dict, pod_schema: BitfountSchema
    ) -> None:
        """Test that DP yields different results over multiple queries."""
        worker = self._make_worker(
            pod_schema, datasource, column_ranges, query=self._AVG_G_QUERY
        )

        # Run query 10 times, ensure results are unique
        seen_results: set = set()
        for _ in range(10):
            result_value = float(worker.run()[1][0])
            assert result_value not in seen_results
            seen_results.add(result_value)

    @unit_test
    def test_pod_dp_prevails(
        self,
        datasource: DataSource,
        column_ranges: dict,
        pod_schema: BitfountSchema,
        caplog: LogCaptureFixture,
    ) -> None:
        """Test Pod DP is preferred over modeller DP."""
        caplog.set_level(logging.INFO)

        # Modeller requests weaker privacy than the pod allows.
        worker = self._make_worker(
            pod_schema,
            datasource,
            column_ranges,
            query=self._AVG_G_QUERY,
            epsilon=3.0,
            delta=0.1,
            pod_dp=DPPodConfig(max_epsilon=1.0, max_target_delta=0.0001),
        )
        worker.run()

        assert "Requested DP max epsilon (3.0) exceeds maximum" in caplog.text
        assert "Requested DP target delta (0.1) exceeds maximum" in caplog.text

    @unit_test
    def test_private_sql_execute(
        self, mock_bitfount_session: Mock, mocker: MockerFixture, column_ranges: dict
    ) -> None:
        """Test execute syntactic sugar."""
        query = PrivateSqlQuery(
            query="SELECT * FROM df.df",
            epsilon=0.1,
            delta=0.00001,
            column_ranges=column_ranges,
        )
        pod_identifiers = ["username/pod-id"]

        mock_modeller_run_method = mocker.patch.object(_Modeller, "run")
        query.execute(pod_identifiers=pod_identifiers)
        mock_modeller_run_method.assert_called_once_with(
            pod_identifiers=pod_identifiers
        )


# Static tests for algorithm-protocol compatibility.
# Nothing below executes at test time: `TYPE_CHECKING` is False at runtime, so
# these assignments are only evaluated by a static type checker. Annotating each
# object with a ResultsOnly protocol type makes the checker verify that the
# PrivateSqlQuery factory and its modeller/worker sides structurally satisfy it.
if TYPE_CHECKING:
    from typing import cast

    from bitfount.federated.protocols.results_only import (
        _ResultsOnlyCompatibleAlgoFactory_,
        _ResultsOnlyCompatibleModeller,
        _ResultsOnlyDataIncompatibleWorker,
    )

    # Check compatible with ResultsOnly.
    # `cast(T, object())` supplies a placeholder of static type T; only the
    # types matter here, never the values.
    _algo_factory: _ResultsOnlyCompatibleAlgoFactory_ = PrivateSqlQuery(
        query=cast(str, object()),
        epsilon=cast(float, object()),
        delta=cast(float, object()),
        column_ranges=cast(dict, object()),
    )
    # The modeller side is constructed with no arguments.
    _modeller_side: _ResultsOnlyCompatibleModeller = _ModellerSide()
    # The worker side is checked against the data-incompatible worker protocol.
    _worker_side: _ResultsOnlyDataIncompatibleWorker = _WorkerSide(
        query=cast(str, object()),
        epsilon=cast(float, object()),
        delta=cast(float, object()),
        column_ranges=cast(dict, object()),
    )
