#  Copyright 2021 Collate
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Validator for column value to be in set test case
"""

import traceback
from abc import abstractmethod
from ast import literal_eval
from typing import Union

from sqlalchemy import Column

from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.generated.schema.tests.basic import (
    TestCaseResult,
    TestCaseStatus,
    TestResultValue,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.utils.logger import test_suite_logger
from metadata.utils.sqa_like_column import SQALikeColumn

# Module-level logger, obtained from the test-suite logging helper.
logger = test_suite_logger()

# Name under which the computed count is reported in the test result values.
ALLOWED_VALUE_COUNT = "allowedValueCount"


class BaseColumnValuesToBeInSetValidator(BaseTestValidator):
    """Validator for column value to be in set test case.

    Concrete subclasses supply the backend-specific pieces by implementing
    ``_get_column_name``, ``_run_results`` and ``compute_row_count``.
    """

    def run_validation(self) -> TestCaseResult:
        """Run validation for the given test case.

        Reads the ``allowedValues`` and ``matchEnum`` test case parameters,
        counts the rows whose column value is in the allowed set and builds
        the test case result:

        * without ``matchEnum`` the check is satisfied when at least one row
          is in the allowed set;
        * with ``matchEnum`` the check is satisfied only when no row falls
          outside the allowed set.

        Returns:
            TestCaseResult: the result of the validation. When the metrics
            computation raises ``ValueError`` or ``RuntimeError`` the result
            is reported as ``Aborted`` with the error message.
        """
        allowed_values = self.get_test_case_param_value(
            self.test_case.parameterValues,  # type: ignore
            "allowedValues",
            literal_eval,
        )

        match_enum = utils.get_bool_test_case_param(
            self.test_case.parameterValues, "matchEnum"
        )

        try:
            column: Union[SQALikeColumn, Column] = self._get_column_name()
            count_in_set = self._run_results(
                Metrics.COUNT_IN_SET, column, values=allowed_values
            )
            # `res` is the value reported and used for the status decision.
            res = count_in_set
            if match_enum:
                count = self._run_results(
                    Metrics.ROW_COUNT, column, values=allowed_values
                )
                # In strict mode the reported value is the number of rows
                # that are NOT in the allowed set.
                res = count - count_in_set
        except (ValueError, RuntimeError) as exc:
            msg = f"Error computing {self.test_case.fullyQualifiedName}: {exc}"  # type: ignore
            logger.debug(traceback.format_exc())
            logger.warning(msg)
            return self.get_test_case_result_object(
                self.execution_date,
                TestCaseStatus.Aborted,
                msg,
                [TestResultValue(name=ALLOWED_VALUE_COUNT, value=None)],
            )

        if self.test_case.computePassedFailedRowCount:
            row_count = self.get_row_count()
        else:
            row_count = None

        return self.get_test_case_result_object(
            self.execution_date,
            self.get_test_case_status(res == 0 if match_enum else res >= 1),
            f"Found countInSet={res}.",
            [TestResultValue(name=ALLOWED_VALUE_COUNT, value=str(res))],
            row_count=row_count,
            # Passed rows are always the rows found in the allowed set. With
            # `matchEnum` enabled `res` holds the rows outside the set, so it
            # must not be reported as the passed row count.
            passed_rows=count_in_set,
        )

    @abstractmethod
    def _get_column_name(self):
        """Return the column the test case runs against.

        Raises:
            NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def _run_results(
        self, metric: Metrics, column: Union[SQALikeColumn, Column], **kwargs
    ):
        """Compute ``metric`` for ``column``.

        Args:
            metric (Metrics): metric to compute
            column (Union[SQALikeColumn, Column]): column to compute the metric on
            **kwargs: extra metric arguments (``run_validation`` passes ``values``)

        Raises:
            NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def compute_row_count(self, column: Union[SQALikeColumn, Column]):
        """Compute row count for the given column

        Args:
            column (Union[SQALikeColumn, Column]): column to compute row count for

        Raises:
            NotImplementedError:
        """
        raise NotImplementedError

    def get_row_count(self) -> int:
        """Get row count for the column under test

        Returns:
            int: value returned by ``compute_row_count`` for the test column
        """
        return self.compute_row_count(self._get_column_name())
