from copy import deepcopy
import numpy as np
import pandas as pd
import polars as pl
from typing import overload


class Ingredients:
    """Wrapper around a polars.DataFrame that also stores column roles (e.g., predictor).

    The DataFrame is not subclassed; it is stored in the ``data`` attribute.

    Args:
        data: A polars.DataFrame, a pandas.DataFrame (converted to polars), or
            another Ingredients object whose DataFrame is reused (not copied).
        copy: If None or True, the roles dict is deep-copied; if False it is
            shared with its source. Only affects the roles, never the DataFrame.
            Defaults to None.
        roles: Mapping from column name to the list of roles of that column.
            If None, roles are taken from ``data`` when it is an Ingredients
            object, otherwise an empty dict is used. Defaults to None.
        check_roles: If set to false, doesn't check whether the roles match
            existing columns. Defaults to True.

    Raises:
        TypeError: If ``data`` is not a supported type or ``roles`` is not a dict.
        ValueError: If ``check_roles`` is set and ``roles`` names unknown columns.

    See also: polars.DataFrame

    Attributes:
        data (polars.DataFrame): the wrapped DataFrame
        roles (dict): dictionary of column roles
        schema: schema of ``data``, refreshed whenever this object replaces or
            mutates ``data``
        dtypes: alias of ``schema``
    """

    _metadata = ["roles"]

    def __init__(
        self,
        data=None,
        copy: bool = None,
        roles: dict = None,
        check_roles: bool = True,
    ):
        # Resolve the wrapped frame first so every later step can rely on self.data.
        if isinstance(data, Ingredients):
            self.data = data.data
        elif isinstance(data, pd.DataFrame):
            self.data = pl.DataFrame(data)
        elif isinstance(data, pl.DataFrame):
            self.data = data
        else:
            raise TypeError(f"expected DataFrame, got {data.__class__}")
        # Read the schema from the polars frame: a pandas input has no `.schema`.
        self._refresh_schema()

        if roles is None:
            if not isinstance(data, Ingredients):
                self.roles = {}
                return
            # Inherit the roles of the wrapped Ingredients (not re-validated).
            roles = data.roles
        else:
            if not isinstance(roles, dict):
                raise TypeError(f"Expected dict object for roles, got {roles.__class__}")
            if check_roles:
                known = set(self.data.columns)
                # Keys are column names; compare them whole, not character by character.
                missing = [name for name in roles if name not in known]
                if missing:
                    raise ValueError(
                        f"Roles contains variable names that are not in the data: {missing}. "
                        f"Available columns: {self.data.columns}."
                    )
        # Todo: do we want to allow ingredients without grouping columns?

        if copy is None or copy is True:
            self.roles = deepcopy(roles)
        else:
            self.roles = roles

    def _refresh_schema(self):
        """Re-read ``schema`` and ``dtypes`` from the wrapped DataFrame."""
        self.schema = self.data.schema
        self.dtypes = self.schema

    @property
    def _constructor(self):
        return Ingredients

    @property
    def columns(self):
        """Column names of the wrapped DataFrame."""
        return self.data.columns

    def to_df(self) -> "pl.DataFrame":
        """Return the underlying polars.DataFrame.

        Returns:
            The wrapped DataFrame (not a copy).
        """
        return self.data

    def _check_column(self, column):
        """Raise ValueError if column is not a string naming an existing column."""
        if not isinstance(column, str):
            raise ValueError(f"Expected string, got {column}")
        if column not in self.columns:
            raise ValueError(f"{column} does not exist in this Data object")

    def _check_role(self, new_role):
        """Raise TypeError if new_role is not a string."""
        if not isinstance(new_role, str):
            raise TypeError(f"new_role must be string, was {new_role.__class__}")

    def add_role(self, column: str, new_role: str):
        """Adds an additional role for a column that already has roles.

        Args:
            column: The column to receive additional roles.
            new_role: The role to add to the column.

        Raises:
            RuntimeError: If the column has no role yet.
        """
        self._check_column(column)
        self._check_role(new_role)
        if column not in self.roles:
            raise RuntimeError(f"{column} has no roles yet, use update_role instead.")
        self.roles[column].append(new_role)

    def update_role(self, column: str, new_role: str, old_role: str = None):
        """Adds a new role for a column without roles or changes an existing role to a different one.

        Args:
            column: The column to update the roles of.
            new_role: The role to add or change to.
            old_role: Defaults to None. The role to be changed.

        Raises:
            ValueError: If old_role is given but column has no roles.
                If old_role is given but column has no role old_role.
                If no old_role is given but column has multiple roles already.
        """
        self._check_column(column)
        self._check_role(new_role)

        if old_role is None:
            # A missing entry, an empty role list, or a single role can all be
            # replaced unambiguously; only multiple roles require old_role.
            if column in self.roles and len(self.roles[column]) > 1:
                raise ValueError(
                    f"Attempted to update role of {column} to {new_role} but "
                    f"{column} has more than one current roles: {self.roles[column]}"
                )
            self.roles[column] = [new_role]
            return

        if column not in self.roles:
            raise ValueError(
                f"Attempted to update role of {column} from {old_role} to {new_role} "
                f"but {column} does not have a role yet."
            )
        if old_role not in self.roles[column]:
            raise ValueError(
                f"Attempted to set role of {column} from {old_role} to {new_role} "
                f"but {old_role} not among current roles: {self.roles[column]}."
            )
        self.roles[column].remove(old_role)
        self.roles[column].append(new_role)

    def select_dtypes(self, include=None):
        """Return the names of the columns whose dtype name is in ``include``.

        Args:
            include: A dtype name (e.g. "Int64") or a container of dtype names,
                compared against the string form of each column's dtype.

        Returns:
            List of matching column names, in column order.

        Raises:
            TypeError: If include is None.
        """
        if include is None:
            raise TypeError("include must be a dtype name or a container of dtype names, got None")
        if isinstance(include, str):
            # A bare string would otherwise be matched by substring
            # (e.g. "Int8" is contained in "UInt8").
            include = [include]
        return [name for name, dtype in self.get_str_dtypes().items() if dtype in include]

    def get_dtypes(self):
        """Return the dtypes of the wrapped DataFrame as a list, in column order."""
        return list(self.data.schema.values())

    def get_str_dtypes(self):
        """Helper function for polars dataframes to return the schema with dtypes as strings.

        Returns:
            Dict mapping column name to the string form of its dtype.
        """
        return {name: str(dtype) for name, dtype in self.data.schema.items()}

    def get_df(self):
        """Return the wrapped DataFrame (not a copy)."""
        # TODO: Check if preferred way to get df
        return self.data

    def set_df(self, df):
        """Replace the wrapped DataFrame and refresh ``schema``/``dtypes``.

        Args:
            df: The new DataFrame. A pandas.DataFrame is converted to polars,
                as in the constructor; anything else is stored as given.
        """
        self.data = pl.DataFrame(df) if isinstance(df, pd.DataFrame) else df
        self._refresh_schema()

    def groupby(self, by):
        """Group the wrapped DataFrame.

        Args:
            by: Column(s) to group by, passed on to ``DataFrame.group_by``.

        Returns:
            The result of ``self.data.group_by(by)``.
        """
        return self.data.group_by(by)

    def __setitem__(self, idx, val):
        """Delegate item assignment to the wrapped DataFrame and refresh the schema."""
        self.data[idx] = val
        self._refresh_schema()

    @overload
    def __getitem__(self, idx: "list[str]") -> "pl.DataFrame": ...

    @overload
    def __getitem__(self, idx: str) -> "pl.Series": ...

    # NOTE(review): the int -> Series overload is carried over from the original
    # annotation; confirm against the polars version in use.
    @overload
    def __getitem__(self, idx: int) -> "pl.Series": ...

    def __getitem__(self, idx):
        """Delegate indexing to the wrapped DataFrame."""
        return self.data[idx]