from dataclasses import dataclass
from itertools import combinations, chain
import json
from ratesb_python.common import util

import os
import re
import sympy

VALID_FILE_EXTENSION = ".json"

# Placeholder symbols a classification entry may list in "optional_symbols".
VALID_OPTIONAL_SYMBOLS = [
    "compartment",
    "parameter",
    "reactant1",
    "reactant2",
    "reactant3",
    "product1",
    "product2",
    "product3",
    "enzyme",
]

# Species placeholders only (every optional symbol except the compartment and
# parameter placeholders); these may appear in "power_limited_species".
VALID_POWER_LIMITED_SPECIES = [
    symbol for symbol in VALID_OPTIONAL_SYMBOLS
    if symbol not in ("compartment", "parameter")
]

@dataclass
class DefaultClassification:
    """Boolean flags, one per default rate-law classification code.

    NOTE(review): this dataclass is not referenced in the visible part of the
    module; the exact meaning of each code should be confirmed against the
    rate-law classification JSON it is meant to mirror.
    """
    ZERO: bool  # presumably zeroth-order rate law — confirm
    UNDR: bool  # presumably uni-reactant, uni-directional mass action — confirm
    UNMO: bool  # presumably uni-reactant, reversible (modular) mass action — confirm
    BIDR: bool  # presumably bi-reactant, uni-directional mass action — confirm
    BIMO: bool  # presumably bi-reactant, reversible (modular) mass action — confirm
    MM: bool  # presumably Michaelis-Menten — confirm
    MMCAT: bool  # presumably Michaelis-Menten with an explicit catalyst/enzyme — confirm
    HILL: bool  # presumably Hill equation — confirm


class _DefaultClassifier:
    """Default classifier for rate laws.

    Loads rate-law templates from a JSON file, validates them, and classifies a
    reaction's kinetic law by comparing it against every accepted template.
    Adapted from CustomClassifier, it favors thoroughness over speed: the
    kinetic law is tried under every ordering of its reactants and of its
    products (num_reactants! * num_products! candidate strings), and each
    template is expanded into 2 ** len(optional_symbols) variants.

    Attributes:
        rate_law_classifications_path (str): Path to the JSON classification file.
        custom_classifications (list): Classification entries (dicts with keys
            ``name``, ``expression``, ``optional_symbols`` and
            ``power_limited_species``) that passed validation.
        warning_message (str): Human-readable description of rejected entries;
            empty when every entry was accepted.
    """

    def __init__(self, rate_law_classifications_path):
        """Load and validate the rate-law classifications.

        Args:
            rate_law_classifications_path (str): Path to the JSON file of
                rate-law classifications.

        Raises:
            ValueError: If the path does not have the ``.json`` extension.
        """
        self.rate_law_classifications_path = rate_law_classifications_path
        self.custom_classifications = []
        self.warning_message = ""
        self.validate()

    def validate(self):
        """Load the classification file and keep only well-formed entries.

        An entry is kept when it is a dict with a string ``name``, a string
        ``expression`` that evaluates once every placeholder symbol is set to 1,
        and lists ``optional_symbols`` / ``power_limited_species`` whose items
        come from VALID_OPTIONAL_SYMBOLS / VALID_POWER_LIMITED_SPECIES.
        Accepted entries are appended to ``self.custom_classifications``; the
        problems found in rejected entries are collected into
        ``self.warning_message``.

        Raises:
            ValueError: If the file does not have the expected extension.
        """
        _, ext = os.path.splitext(self.rate_law_classifications_path)
        if ext != VALID_FILE_EXTENSION:
            raise ValueError(f"Invalid file format, accepting {VALID_FILE_EXTENSION}")

        json_str = util.get_json_str(self.rate_law_classifications_path)
        unchecked_custom_classifications = json.loads(json_str)

        # Collects every problem found across all entries.
        warnings = []
        for index, item in enumerate(unchecked_custom_classifications):
            item_warnings = self._item_warnings(index, item)
            if item_warnings:
                warnings.extend(item_warnings)
            else:
                self.custom_classifications.append(item)

        if warnings:
            self.warning_message = 'Some items in your JSON file were invalid and have been removed.\nDetails:\n'
            self.warning_message += '\n'.join(warnings)

    def _item_warnings(self, index, item):
        """Return the validation problems of one classification entry.

        Args:
            index (int): Position of the entry in the JSON list; used in the
                message when the entry has no usable name.
            item: The decoded JSON value of the entry.

        Returns:
            list of str: Warning messages; empty when the entry is valid.
        """
        # Confirm the entry is a dict before any key lookup: for a string entry
        # `'name' in item` would be a substring test and item['name'] would
        # raise TypeError.
        has_str_name = isinstance(item, dict) and isinstance(item.get('name'), str)
        if not (has_str_name
                and isinstance(item.get('expression'), str)
                and isinstance(item.get('optional_symbols'), list)
                and isinstance(item.get('power_limited_species'), list)):
            if has_str_name:
                return [f"Rate law {item['name']} does not follow the correct structure."]
            return [f"Item at index {index} does not follow the correct structure."]

        name = item['name']
        if not self._is_valid_expression(item['expression']):
            return [f"Rate law {name} has an invalid expression."]

        problems = self._symbol_warnings(
            name, "optional_symbols", item['optional_symbols'], VALID_OPTIONAL_SYMBOLS)
        if problems:
            return problems
        return self._symbol_warnings(
            name, "power_limited_species", item['power_limited_species'], VALID_POWER_LIMITED_SPECIES)

    @staticmethod
    def _symbol_warnings(name, field, symbols, allowed):
        """Return messages for items of ``symbols`` that are not allowed strings.

        Args:
            name (str): Rate-law name used in the messages.
            field (str): JSON key being checked, used in the messages.
            symbols (list): The values listed under ``field``.
            allowed (list of str): Permitted values.

        Returns:
            list of str: One or two messages per offending item; empty if all are valid.
        """
        problems = []
        for symbol in symbols:
            if not isinstance(symbol, str):
                problems.append(f"{field} in rate law {name} should be a list of strings.")
            if symbol not in allowed:
                problems.append(f"Invalid item in {field} in rate law {name}, should only contain {', '.join(allowed)}.")
        return problems

    @staticmethod
    def _is_valid_expression(expression):
        """Return True if ``expression`` evaluates after substituting 1 for every placeholder.

        ``^`` is rewritten to ``**`` first, matching how custom_classify converts
        the template before handing it to sympy (in Python ``^`` is XOR). Any
        exception raised while evaluating (syntax error, unknown name, ...)
        marks the expression as invalid.

        Args:
            expression (str): The template expression from the JSON entry.

        Returns:
            bool: Whether the expression evaluated without raising.
        """
        candidate = re.sub(
            r'compartment|parameter|reactant1|reactant2|reactant3|product1|product2|product3|enzyme',
            '1',
            expression.replace('^', '**'),
        )
        try:
            # NOTE(security): eval executes arbitrary Python code taken from the
            # classification file; only load classification files you trust.
            eval(candidate)
        except Exception:
            return False
        return True

    def permute(self, arr):
        """Return every ordering of ``arr``.

        Args:
            arr (list): Items to permute.

        Returns:
            list of list: All permutations, in ``itertools.permutations`` order.
            An empty list is returned for an empty ``arr`` (replace_occurrences
            substitutes ``[[]]`` in that case).
        """
        from itertools import permutations

        if not arr:
            return []
        return [list(ordering) for ordering in permutations(arr)]

    def replace_occurrences(self, reactants_in_kinetic_law, products_in_kinetic_law, enzyme_list, compartment_in_kinetic_law, parameters_in_kinetic_law_only, kinetics_sim):
        """Rewrite the kinetic law in terms of the standard placeholder symbols.

        Reactants become ``reactant1``, ``reactant2``, ... and products become
        ``product1``, ... (one output string per reactant ordering x product
        ordering); enzymes, compartments and parameters become ``enzyme``,
        ``compartment`` and ``parameter``. Only whole identifiers are replaced,
        in a single pass, so ``S1`` does not match inside ``S10`` and a
        parameter named e.g. ``t`` cannot corrupt an inserted ``reactant1``.
        If a symbol appears in several categories, the earlier category in the
        order reactants, products, enzymes, compartments, parameters wins.

        Args:
            reactants_in_kinetic_law (list): Reactants referenced by the kinetic law.
            products_in_kinetic_law (list): Products referenced by the kinetic law.
            enzyme_list (list): Other species referenced by the kinetic law.
            compartment_in_kinetic_law (list): Compartments referenced by the kinetic law.
            parameters_in_kinetic_law_only (list): Parameters referenced only by the kinetic law.
            kinetics_sim (str): The kinetic law as a string.

        Returns:
            list of str: The rewritten kinetic law, one entry per ordering.
        """
        permuted_reactants = self.permute(reactants_in_kinetic_law) or [[]]
        permuted_products = self.permute(products_in_kinetic_law) or [[]]

        names = {
            name for name in chain(
                reactants_in_kinetic_law, products_in_kinetic_law, enzyme_list,
                compartment_in_kinetic_law, parameters_in_kinetic_law_only)
            if name
        }
        pattern = None
        if names:
            # Longest names first; the lookarounds additionally require a match
            # to be a whole identifier (not preceded/followed by a word char).
            alternatives = '|'.join(re.escape(name) for name in sorted(names, key=len, reverse=True))
            pattern = re.compile(r'(?<!\w)(?:' + alternatives + r')(?!\w)')

        ret = []
        for reactant_perm in permuted_reactants:
            for product_perm in permuted_products:
                if pattern is None:
                    ret.append(kinetics_sim)
                    continue
                mapping = {}
                for i, symbol in enumerate(reactant_perm, start=1):
                    mapping.setdefault(symbol, 'reactant' + str(i))
                for i, symbol in enumerate(product_perm, start=1):
                    mapping.setdefault(symbol, 'product' + str(i))
                for symbol in enzyme_list:
                    mapping.setdefault(symbol, 'enzyme')
                for symbol in compartment_in_kinetic_law:
                    mapping.setdefault(symbol, 'compartment')
                for symbol in parameters_in_kinetic_law_only:
                    mapping.setdefault(symbol, 'parameter')
                ret.append(pattern.sub(lambda match: mapping[match.group(0)], kinetics_sim))

        return ret

    def custom_classify(self, **kwargs):
        """Classify one kinetic law against every accepted rate-law template.

        Keyword Args:
            reactant_list (list): Reactants of the reaction.
            product_list (list): Products of the reaction.
            kinetics_sim (str): The kinetic law as a string.
            species_in_kinetic_law (list): Species referenced by the kinetic law.
            parameters_in_kinetic_law_only (list): Parameters referenced only by
                the kinetic law.
            compartment_in_kinetic_law (list): Compartments referenced by the
                kinetic law.

        Returns:
            dict: Maps each classification name to True if the kinetic law
            matches that template under some reactant/product ordering and some
            choice of optional symbols (per ``util.check_equal``), else False.

        Raises:
            KeyError: If one of the keyword arguments above is missing.
        """
        reactant_list = kwargs["reactant_list"]
        product_list = kwargs["product_list"]
        kinetics_sim = kwargs["kinetics_sim"]
        species_in_kinetic_law = kwargs["species_in_kinetic_law"]
        parameters_in_kinetic_law_only = kwargs["parameters_in_kinetic_law_only"]
        compartment_in_kinetic_law = kwargs["compartment_in_kinetic_law"]

        reactants_in_kinetic_law = [species for species in species_in_kinetic_law if species in reactant_list]
        products_in_kinetic_law = [species for species in species_in_kinetic_law if species in product_list]
        # Species that are neither reactants nor products are substituted as 'enzyme'.
        enzyme_list = [species for species in species_in_kinetic_law if species not in reactant_list and species not in product_list]

        ret = {}
        if not self.custom_classifications:
            return ret

        replaced_kinetics_list = self.replace_occurrences(
            reactants_in_kinetic_law, products_in_kinetic_law, enzyme_list,
            compartment_in_kinetic_law, parameters_in_kinetic_law_only, kinetics_sim)
        # Parsing does not depend on the template, so do it once up front.
        parsed_kinetics = [sympy.sympify(replaced) for replaced in replaced_kinetics_list]

        for item in self.custom_classifications:
            kinetics_expression = item['expression'].replace("^", "**")
            # The template variants do not depend on the candidate kinetics;
            # build them once per template (sympy.simplify is expensive).
            all_expr = self.get_all_expr(kinetics_expression, item['optional_symbols'])
            power_limited_species = item['power_limited_species']
            matched = False
            for kinetics in parsed_kinetics:
                lowered = self.lower_powers(kinetics, power_limited_species)
                if any(util.check_equal(expr, lowered) for expr in all_expr):
                    matched = True
                    break
            ret[item['name']] = matched
        return ret

    def lower_powers(self, expr, keep=None):
        """Replace ``symbol**n`` (integer n > 1) by ``symbol`` unless the symbol is kept.

        Args:
            expr (sympy.Expr): The expression whose powers are lowered.
            keep (iterable, optional): Names (str, or sympy Symbols) of symbols
                whose powers are left unchanged. Defaults to none.

        Returns:
            sympy.Expr: The expression with the applicable powers lowered.
        """
        # Compare by name: a sympy Symbol never compares equal to a plain str,
        # so testing the Symbol itself against a list of names never matches.
        keep_names = {str(name) for name in (keep or ())}

        def replace_if_applicable(base, exp):
            """Return ``base`` for a lowerable power, otherwise rebuild ``base**exp``."""
            if exp.is_integer and exp > 1 and base.is_symbol and str(base) not in keep_names:
                return base
            return sympy.Pow(base, exp)

        return expr.replace(sympy.Pow, replace_if_applicable)

    def get_all_expr(self, expr, optional_symbols):
        """Generate every variant of ``expr`` with some optional symbols replaced by 1.

        For each subset of ``optional_symbols`` (from the empty subset up to all
        of them), the symbols outside the subset are replaced by ``1`` and the
        result is simplified.

        Args:
            expr (str): The template expression (already using ``**``).
            optional_symbols (list of str): Symbols that may be absent.

        Returns:
            list: One simplified sympy expression per subset.
        """
        subsets = chain.from_iterable(
            combinations(optional_symbols, size) for size in range(len(optional_symbols) + 1))
        all_expr = []
        for kept in subsets:
            temp_expr = expr
            for sym in optional_symbols:
                if sym not in kept:
                    temp_expr = temp_expr.replace(sym, "1")
            all_expr.append(sympy.simplify(temp_expr))
        return all_expr
