from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional

from atoti_core import MeasureIdentifier, is_array_type, keyword_only_dataclass

from .._java_api import JavaApi
from .._measure_convertible import MeasureConvertible
from .._measure_description import MeasureDescription, convert_to_measure_description
from .._measure_metadata import MeasureMetadata
from ..column import Column
from ..type import DOUBLE_ARRAY
from .utils import convert_measure_args


@keyword_only_dataclass
@dataclass(eq=False, frozen=True)
class SumProductFieldsMeasure(MeasureDescription):
    """Sum of the product of factors for table fields."""

    _factors: Sequence[Column]

    def _do_distil(
        self,
        identifier: Optional[MeasureIdentifier] = None,
        /,
        *,
        cube_name: str,
        java_api: JavaApi,
        measure_metadata: Optional[MeasureMetadata] = None,
    ) -> MeasureIdentifier:
        """Create a ``SUM_PRODUCT_UDAF`` measure over the factor columns.

        Args:
            identifier: Identifier forwarded as-is to ``java_api.create_measure``.
            cube_name: Name of the cube whose selection fields are checked and in
                which the measure is created.
            java_api: Bridge used to read the selection fields and create the measure.
            measure_metadata: Optional metadata forwarded to ``java_api.create_measure``.

        Returns:
            The identifier returned by ``java_api.create_measure``.

        Raises:
            ValueError: If some factor columns are not among the cube's selection
                fields; the message lists their names.
            TypeError: If a factor is an array column whose type is not ``DOUBLE_ARRAY``.
        """
        # The UDAF needs every factor column to be in the cube's selection; other
        # columns have to be converted into measures first, so reject them here.
        selection_fields = java_api.get_selection_fields(cube_name)
        missing_columns = [
            factor.name
            for factor in self._factors
            if factor._identifier not in selection_fields
        ]
        if missing_columns:
            raise ValueError(
                f"The columns {missing_columns}"
                " cannot be used in a sum product aggregation without first being converted into measures."
            )

        # Map each factor's identifier to its data type for the Java side.
        factors_and_type = {}
        for factor in self._factors:
            # Array columns are only accepted as DOUBLE_ARRAY; non-array types are
            # not restricted by this check.
            if is_array_type(factor.data_type) and factor.data_type != DOUBLE_ARRAY:
                raise TypeError(
                    f"Only array columns of type `{DOUBLE_ARRAY}` are supported and `{factor._identifier!r}` is not."
                )
            factors_and_type[factor._identifier] = factor.data_type
        return java_api.create_measure(
            identifier,
            "SUM_PRODUCT_UDAF",
            [factor._identifier for factor in self._factors],
            factors_and_type,
            cube_name=cube_name,
            measure_metadata=measure_metadata,
        )


@keyword_only_dataclass
@dataclass(eq=False, frozen=True)
class SumProductEncapsulationMeasure(MeasureDescription):
    """Intermediate measure wrapping the sum-product factors.

    The created measure is meant to be aggregated with the key "ATOTI_SUM_PRODUCT".
    """

    _factors: Sequence[MeasureConvertible]

    def _do_distil(
        self,
        identifier: Optional[MeasureIdentifier] = None,
        /,
        *,
        cube_name: str,
        java_api: JavaApi,
        measure_metadata: Optional[MeasureMetadata] = None,
    ) -> MeasureIdentifier:
        """Create the ``SUM_PRODUCT_ENCAPSULATION`` measure and return its identifier.

        Each factor is first turned into a measure description and then converted
        with ``convert_measure_args`` before being passed to ``java_api.create_measure``.
        """
        descriptions = tuple(map(convert_to_measure_description, self._factors))
        java_args = convert_measure_args(
            java_api=java_api,
            cube_name=cube_name,
            args=descriptions,
        )
        return java_api.create_measure(
            identifier,
            "SUM_PRODUCT_ENCAPSULATION",
            java_args,
            cube_name=cube_name,
            measure_metadata=measure_metadata,
        )
