"""This module contains the Estimator abstract class"""

import logging
from abc import ABC, abstractmethod
from typing import Any

import pandas as pd

from causal_testing.testing.base_test_case import BaseTestCase

logger = logging.getLogger(__name__)


class Estimator(ABC):
    # pylint: disable=too-many-instance-attributes
    """An estimator contains all of the information necessary to compute a causal estimate for the effect of changing
    a set of treatment variables to a set of values.

    All estimators must implement the following two methods:

    1) add_modelling_assumptions: The validity of a model-assisted causal inference result depends on whether
    the modelling assumptions imposed by a model actually hold. Therefore, for each model, is important to state
    the modelling assumption upon which the validity of the results depend. To achieve this, the estimator object
    maintains a list of modelling assumptions (as strings). If a user wishes to implement their own estimator, they
    must implement this method and add all assumptions to the list of modelling assumptions.

    2) estimate_ate: All estimators must be capable of returning the average treatment effect as a minimum. That is, the
    average effect of the intervention (changing treatment from control to treated value) on the outcome of interest
    adjusted for all confounders.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        # pylint: disable=R0801
        self,
        base_test_case: BaseTestCase,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        df: pd.DataFrame = None,
        effect_modifiers: dict[str, Any] = None,
        alpha: float = 0.05,
        query: str = "",
    ):
        self.base_test_case = base_test_case
        self.treatment_value = treatment_value
        self.control_value = control_value
        self.adjustment_set = adjustment_set
        self.alpha = alpha
        self.df = df.query(query) if query else df

        if effect_modifiers is None:
            self.effect_modifiers = {}
        else:
            self.effect_modifiers = effect_modifiers
        self.modelling_assumptions = []
        if query:
            self.modelling_assumptions.append(query)
        self.add_modelling_assumptions()
        logger.debug("Effect Modifiers: %s", self.effect_modifiers)

    @abstractmethod
    def add_modelling_assumptions(self):
        """
        Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
        must hold if the resulting causal inference is to be considered valid.
        """

    def compute_confidence_intervals(self) -> list[float, float]:
        """
        Estimate the 95% Wald confidence intervals for the effect of changing the treatment from control values to
        treatment values on the outcome.
        :return: 95% Wald confidence intervals.
        """
