import json
from dataclasses import dataclass
from enum import Enum
from itertools import combinations
from pathlib import Path
from typing import Optional, Self

import numpy as np
import polars as pl
from numpy.typing import NDArray

from orca_parse.orca_output import OrcaOutput

# Element data stored next to this module. It is read once at import time and
# queried by Element.from_symbol / Element.from_atomic_number.
PERIODIC_TABLE: pl.DataFrame = pl.read_json(Path(__file__).with_name("periodic_table.json"))


@dataclass
class Element:
    symbol: str
    atomic_number: int
    name: str
    hex_color: str
    electron_configuration: str
    electronegativity: float
    atomic_radius_pm: int
    group_block: str

    @classmethod
    def from_symbol(cls, symbol: str) -> Self:
        element_df = PERIODIC_TABLE.filter(pl.col("symbol").str.to_lowercase() == symbol.lower())
        if element_df.is_empty():
            raise ValueError(f"No element with symbol '{symbol}' found!")
        element: dict = element_df.to_dicts()[0]
        return cls(**element)

    @classmethod
    def from_atomic_number(cls, atomic_number: int) -> Self:
        element_df = PERIODIC_TABLE.filter(pl.col("atomic_number") == atomic_number)
        if element_df.is_empty():
            raise ValueError(f"No element with atomic number '{atomic_number}' found!")
        element: dict = element_df.to_dicts()[0]
        return cls(**element)

    @property
    def is_metal(self) -> bool:
        metal_blocks = ["Alkali metal", "Alkaline earth metal", "Transition metal", "Post-transition metal"]
        return True if self.group_block in metal_blocks else False


@dataclass
class Atom:
    """A single atom: its element and Cartesian coordinates."""

    element: Element
    # x, y, z coordinates. Presumably in Angstrom, as in XYZ files; the unit is not enforced here.
    coords: NDArray

    def __repr__(self) -> str:
        """Render the atom as an XYZ line: symbol followed by x, y, z to 4 decimals."""
        return f"{self.element.symbol} {self.coords[0]:.4f} {self.coords[1]:.4f} {self.coords[2]:.4f}"

    @classmethod
    def from_str(cls, string: str) -> Self:
        """Parse one XYZ atom line: ``<symbol | atomic number> <x> <y> <z> [extra columns...]``.

        The element may be given as a symbol (case-insensitive) or as an atomic
        number. Any columns after the coordinates, as used by some extended XYZ
        variants, are ignored.

        Raises:
            ValueError: if the line has fewer than four whitespace-separated
                fields, a coordinate is not numeric, or the element is unknown.
        """
        fields = string.split()
        if len(fields) < 4:
            raise ValueError(f"Invalid XYZ atom line (expected '<element> <x> <y> <z>'): {string!r}")
        symbol_or_atomic_number, x, y, z = fields[:4]

        if symbol_or_atomic_number.isdigit():
            element = Element.from_atomic_number(int(symbol_or_atomic_number))
        else:
            element = Element.from_symbol(symbol_or_atomic_number)

        coords = np.array([x, y, z], dtype=float)
        return cls(element, coords)


@dataclass
class Geometry:
    """An ordered list of atoms plus the free-text comment line of an XYZ file."""

    atoms: list[Atom]
    comment: str = ""

    def __repr__(self) -> str:
        """Render the geometry as XYZ text: atom count, comment line, then one line per atom."""
        atom_lines = "\n".join(str(atom) for atom in self.atoms)
        return f"{len(self.atoms)}\n{self.comment}\n{atom_lines}"

    @classmethod
    def from_xyz_file(cls, xyz_file: Path | str) -> Self:
        """Read an XYZ file from disk and parse it with :meth:`from_xyz`."""
        xyz = Path(xyz_file).read_text()
        return cls.from_xyz(xyz)

    @classmethod
    def from_xyz(cls, xyz: str) -> Self:
        """Parse XYZ text: atom count on line 1, comment on line 2, one atom per following line.

        Blank lines in the atom section, such as trailing newlines at the end of a
        file, are skipped.

        Raises:
            ValueError: if the input is empty, the first line is not an integer,
                an atom line is malformed, or the number of atom lines does not
                match the declared count.
        """
        lines = xyz.splitlines()
        if not lines:
            raise ValueError("Invalid XYZ file: empty input")

        try:
            n_atoms = int(lines[0])
        except ValueError as err:
            raise ValueError(f"Invalid XYZ file: first line must be the atom count, got {lines[0]!r}") from err

        # Tolerate a missing comment line (e.g. a zero-atom file containing just "0").
        comment = lines[1] if len(lines) > 1 else ""
        atoms = [Atom.from_str(line) for line in lines[2:] if line.strip()]
        if n_atoms != len(atoms):
            raise ValueError(f"Invalid XYZ file: expected {n_atoms} atoms, found {len(atoms)}")
        return cls(atoms, comment)


@dataclass
class BondTypeSettings:
    """Parameters shared by every bond of a given BondType.

    Instances compare by value, because @dataclass generates ``__eq__``.
    """

    # Multiplier on the distance threshold used when detecting bonds of this type.
    range: float
    # Per-type radius. Not used in this module; presumably a display thickness -- confirm.
    radius: float
    # Colour as a six-digit hex string without a leading '#', e.g. "D3D3D3".
    hex_color: str


class BondType(Enum):
    """Kinds of bond produced by the bond search; each member carries its BondTypeSettings.

    Enum members whose values compare equal become aliases of one another, so
    each member must keep distinct settings.
    """

    # Default bond: range 1.0 leaves the detection threshold unchanged.
    SINGLE = BondTypeSettings(range=1.0, radius=1.0, hex_color="D3D3D3")
    # Bond involving a metal: range 1.3 widens the detection threshold by 30%.
    COORDINATION = BondTypeSettings(range=1.3, radius=0.6, hex_color="C20CBE")


@dataclass
class Bond:
    """A bond between two atoms, tagged with the kind of bond it is."""

    # The two bonded atoms.
    atoms: tuple[Atom, Atom]
    # Kind of bond; an ordinary single bond unless stated otherwise.
    type: BondType = BondType.SINGLE


class Molecule:
    """A molecular geometry together with its total charge and spin multiplicity."""

    def __init__(self, charge: int, mult: int, geometry: Geometry) -> None:
        """Create a molecule.

        Args:
            charge: Total charge; coerced with ``int()``.
            mult: Spin multiplicity; coerced with ``int()``.
            geometry: Atoms (and comment) of the molecule.
        """
        self.charge = int(charge)
        self.mult = int(mult)
        self.geometry = geometry

    def __repr__(self) -> str:
        """Render the molecule as XYZ text.

        The comment line is a JSON object holding charge and mult, which is the
        format :meth:`from_xyz` reads back.
        """
        metadata = json.dumps({"charge": self.charge, "mult": self.mult})
        atom_lines = "\n".join(str(atom) for atom in self.geometry.atoms)
        return f"{len(self.geometry.atoms)}\n{metadata}\n{atom_lines}"

    @classmethod
    def from_xyz_file(cls, xyz_file: Path | str, charge: Optional[int] = None, mult: Optional[int] = None) -> Self:
        """Construct a Molecule from an XYZ file.

        If charge/mult are not specified by a JSON object in the comment line of
        the XYZ file, they must be supplied as arguments. Supplied charge/mult
        override the values in the JSON object.

        Raises:
            ValueError: if charge or mult cannot be determined, or the XYZ text is invalid.
        """
        xyz = Path(xyz_file).read_text()
        return cls.from_xyz(xyz, charge, mult)

    @classmethod
    def from_xyz(cls, xyz: str, charge: Optional[int] = None, mult: Optional[int] = None) -> Self:
        """Construct a Molecule from XYZ text.

        Any of charge/mult that is not passed explicitly is taken from the
        ``"charge"``/``"mult"`` keys of a JSON object in the XYZ comment line.

        Raises:
            ValueError: if charge or mult is neither passed nor present in the
                comment's JSON object, or the XYZ text is invalid.
        """
        geometry = Geometry.from_xyz(xyz)

        if charge is None or mult is None:
            metadata = cls._metadata_from_comment(geometry.comment)
            if charge is None:
                charge = metadata.get("charge")
            if mult is None:
                mult = metadata.get("mult")

        if charge is None or mult is None:
            raise ValueError("Charge and multiplicity must be provided if not in XYZ comment")

        return cls(charge, mult, geometry)

    @staticmethod
    def _metadata_from_comment(comment: str) -> dict:
        """Return the comment line parsed as a JSON object, or ``{}`` if it is not one.

        A comment that is not JSON, or is JSON but not an object (e.g. ``42`` or
        ``null``), yields ``{}`` rather than an exception.
        """
        try:
            metadata = json.loads(comment.strip())
        except json.JSONDecodeError:
            return {}
        return metadata if isinstance(metadata, dict) else {}

    @classmethod
    def from_output(cls, output_file: Path | str) -> Self:
        """Construct a Molecule from an ORCA output file.

        Uses the XYZ text, charge and multiplicity exposed by ``OrcaOutput``.
        """
        output = OrcaOutput(output_file)
        return cls.from_xyz(output.xyz, output.charge, output.mult)

    def get_bonds_by_radius_overlap(self, radius_scale: float = 0.5) -> list[Bond]:
        """Detect bonds between atoms from their atomic radii and pairwise distances.

        A pair of atoms is bonded when its distance is at most
        ``(r_a + r_b) / 100 * radius_scale * bond_type.value.range``, where
        ``r_a``/``r_b`` are ``Element.atomic_radius_pm``. Dividing by 100 converts
        pm to Angstrom, so the coordinates are presumably in Angstrom. A pair
        involving at least one metal is a COORDINATION bond. Coordination bonds
        are usually longer (as are hydrogen bonds etc.), so that type's ``range``
        widens the threshold. Every other pair is a SINGLE bond.

        Args:
            radius_scale: Factor applied to the summed atomic radii to obtain the
                bond threshold distance (default: 0.5).

        Returns:
            Detected Bond objects, one per bonded pair, in
            ``itertools.combinations`` order over ``self.geometry.atoms``.
        """
        # Look up metal status once per atom instead of once per pair.
        atoms_with_metal_flag = [(atom, atom.element.is_metal) for atom in self.geometry.atoms]

        bonds: list[Bond] = []
        for (atom_a, metal_a), (atom_b, metal_b) in combinations(atoms_with_metal_flag, 2):
            bond_type = BondType.COORDINATION if metal_a or metal_b else BondType.SINGLE
            radii_sum = (atom_a.element.atomic_radius_pm + atom_b.element.atomic_radius_pm) / 100
            # Left-to-right multiplication; SINGLE.range == 1.0 leaves the value unchanged.
            bond_threshold = radii_sum * radius_scale * bond_type.value.range
            distance = np.linalg.norm(atom_b.coords - atom_a.coords)
            if distance <= bond_threshold:
                bonds.append(Bond((atom_a, atom_b), bond_type))
        return bonds
