#!/usr/bin/env python3

"""Script that automatically updates OpenCPI Projects."""

from __future__ import annotations

import argparse
import pathlib
import sys
import typing
import xml.etree.ElementTree as ET


# Authoring models whose worker directories (``<name>.<model>``) this script
# processes.
MODELS = [
    "hdl",
    "rcc",
]


class Version:
    """Class representation of a dotted ``major.minor.patch`` version.

    Components are compared numerically, left to right, so ``2.10.0`` is
    newer than ``2.9.0``. Trailing zero components are ignored when comparing
    and hashing, so ``Version("2.4") == Version("2.4.0")``.

    Comparing a ``Version`` with any other type (including with ``==`` and
    ``!=``) raises ``Version.ComparisonWithNotAVersionError``.
    """

    class ComparisonWithNotAVersionError(TypeError):
        """Error when Version is compared to an invalid type."""

        def __init__(self, operator: str, other: object) -> None:
            """Construct.

            Args:
                operator: The comparison operator that was attempted.
                other: The operand that is not a ``Version``.
            """
            # Name ``Version`` explicitly: ``self.__class__`` here is this
            # error class, not the class whose instance was compared.
            super().__init__(f"`{operator}` operator not defined between "
                             f"`Version` and "
                             f"`{other.__class__.__name__}`")

    @classmethod
    def raise_if_compared_to_different_type(cls, operator: str, other: object) -> None:
        """Raise error if a comparison is performed with an invalid type.

        Raises:
            ComparisonWithNotAVersionError: If ``other`` is not an instance of
                this class.
        """
        if not isinstance(other, cls):
            raise cls.ComparisonWithNotAVersionError(operator, other)

    def __init__(self, version_str: str) -> None:
        """Construct from a dotted string such as ``"2.4.7"``.

        Raises:
            ValueError: If any dot-separated component is not an integer.
        """
        self._version_parts = [int(part) for part in version_str.split(".")]

    def _key(self) -> tuple[int, ...]:
        """Return the comparison/hash key: the parts minus trailing zeros."""
        parts = list(self._version_parts)
        while parts and parts[-1] == 0:
            parts.pop()
        return tuple(parts)

    def __repr__(self) -> str:
        """Dunder method: Convert to string evaluating to class constructor."""
        return f"Version(\"{'.'.join(map(str, self._version_parts))}\")"

    def __lt__(self, other: object) -> bool:
        """Dunder method: Less than."""
        self.raise_if_compared_to_different_type("<", other)
        return self._key() < other._key()

    def __le__(self, other: object) -> bool:
        """Dunder method: Less than or equal to."""
        self.raise_if_compared_to_different_type("<=", other)
        return self._key() <= other._key()

    def __eq__(self, other: object) -> bool:
        """Dunder method: Equal to."""
        self.raise_if_compared_to_different_type("==", other)
        return self._key() == other._key()

    def __gt__(self, other: object) -> bool:
        """Dunder method: Greater than."""
        self.raise_if_compared_to_different_type(">", other)
        return self._key() > other._key()

    def __ge__(self, other: object) -> bool:
        """Dunder method: Greater than or equal to."""
        self.raise_if_compared_to_different_type(">=", other)
        return self._key() >= other._key()

    def __ne__(self, other: object) -> bool:
        """Dunder method: Not equal to."""
        self.raise_if_compared_to_different_type("!=", other)
        return self._key() != other._key()

    def __hash__(self) -> int:
        """Dunder method: Hash (consistent with ``__eq__``)."""
        return hash(self._key())

    def __str__(self) -> str:
        """Dunder method: Convert to string."""
        return ".".join(map(str, self._version_parts))


# OpenCPI release that introduced the ``-<model>.xml`` / ``.comp`` file layout;
# the ``v2_4_7_*`` migrations compare ``--to-version`` against it.
V2_4_7 = Version(version_str="2.4.7")


def yield_workers_from_library(
    library_directory: pathlib.Path,
) -> typing.Iterable[pathlib.Path]:
    """Yield each worker directory found directly inside a library.

    A worker directory is a sub-directory whose final suffix, without the
    leading dot, is one of ``MODELS`` (for example ``foo.hdl`` or
    ``bar.rcc``). Every other entry is skipped.
    """
    for candidate in library_directory.iterdir():
        # ``suffix`` is "" when the name has no dot, and ""[1:] is never a model.
        if candidate.is_dir() and candidate.suffix[1:] in MODELS:
            yield candidate


def yield_specs_from_library(
    library_directory: pathlib.Path,
) -> typing.Iterable[pathlib.Path]:
    """Yield the paths of a library's old-style spec files.

    A spec file is a regular file ``<library>/specs/<name>-spec.xml`` or
    ``<library>/specs/<name>_spec.xml`` with a non-empty ``<name>``. Nothing
    is yielded when ``<library>/specs`` is missing or is not a directory.
    Paths are yielded in sorted order.
    """
    specs_directory = library_directory / "specs"
    if not specs_directory.is_dir():
        return
    # Materialise (and sort) the listing before yielding. Callers move the
    # yielded files out of this directory while still iterating.
    for path in sorted(specs_directory.iterdir()):
        if not path.is_file() or path.suffix != ".xml":
            continue
        stem = path.stem
        # Require the ``[-_]spec`` separator and a non-empty name. Callers
        # strip the last five characters of the stem to get the component name.
        if len(stem) <= len("-spec") or not stem.endswith(("-spec", "_spec")):
            continue
        yield path


def recursive_findall(element: ET.Element, tag: str) -> list[ET.Element]:
    """Find all occurrences of a given XML tag at any depth in an XML tree.

    ``element`` itself is included when its tag matches. Matches are returned
    in document (depth-first, pre-order) order. ``tag`` is compared literally,
    so ``"*"`` only matches elements whose tag is exactly ``"*"``.
    """
    # ``Element.iter()`` (no tag filter) visits ``element`` and every
    # descendant in document order. The function no longer recurses itself,
    # and it no longer re-copies partial result lists at every tree level.
    # Filtering with ``==`` rather than ``iter(tag)`` keeps the literal
    # comparison, because ``iter("*")`` would return every element.
    return [node for node in element.iter() if node.tag == tag]


def v2_4_7_owd_rename(worker_directory: pathlib.Path, version_to: Version) -> bool:
    """
    Rename all OWD files to their v2.4.7 names.

    - Move all *.hdl/*.xml to *.hdl/*-hdl.xml
    - Move all *.rcc/*.xml to *.rcc/*-rcc.xml
        - This isn't done for RCC Workers that proxy one or more HDL Workers
          when moving to v2.4.7 or earlier.
        - See https://opencpi.dev/t/broken-hdl-worker-search-path-on-slave-attributes/105

    This function ignores OWDs that have already been migrated. It never
    overwrites an existing ``<name>-<model>.xml``; if both files exist, it
    prints a message and skips the worker.

    Args:
        worker_directory: A ``<name>.<model>`` worker directory.
        version_to: OpenCPI version being migrated to.

    Returns:
        ``True`` if the OWD was renamed, ``False`` otherwise.
    """
    if version_to < V2_4_7:
        return False
    name = worker_directory.stem
    model = worker_directory.suffix[1:]
    old_owd_file = worker_directory / f"{name}.xml"
    # Ignore already converted workers
    if not old_owd_file.exists():
        return False
    # Ignore RCC Workers that proxy HDL Workers in v2.4.7 and earlier
    if version_to <= V2_4_7 and model == "rcc":
        # NOTE(review): ``find`` is case-sensitive; an OWD spelling the element
        # differently (e.g. ``<Slaves>``) would not be detected. Confirm.
        slaves = ET.parse(old_owd_file).getroot().find("slaves")
        if slaves is not None:
            for instance in recursive_findall(slaves, "instance"):
                # ``worker`` may be absent on an instance; treat that as a
                # non-HDL slave instead of crashing on ``None.endswith``.
                if instance.attrib.get("worker", "").endswith("hdl"):
                    return False
    # Rename the file, refusing to clobber an existing new-style OWD
    # (``Path.rename`` silently replaces the destination on POSIX).
    new_owd_file = worker_directory / f"{name}-{model}.xml"
    if new_owd_file.exists():
        print(f"Skipped '{old_owd_file}': '{new_owd_file}' already exists")
        return False
    old_owd_file.rename(new_owd_file)
    print(f"Moved '{old_owd_file}' to '{new_owd_file}'")
    return True


def v2_4_7_move_spec_to_comp(spec_file: pathlib.Path, version_to: Version) -> bool:
    """Move ``specs/<name>[-_]spec.xml`` to ``<name>.comp/<name>-comp.xml``.

    The ``<name>.comp`` directory is created next to the ``specs`` directory
    if it does not already exist.

    Args:
        spec_file: The spec file; its stem must end in ``-spec`` or ``_spec``.
        version_to: OpenCPI version being migrated to.

    Returns:
        ``True`` if the file was moved. ``False`` if nothing was done: the
        target version is older than 2.4.7, the file name has no
        ``[-_]spec`` suffix, or the destination file already exists.
    """
    if version_to < V2_4_7:
        return False
    stem = spec_file.stem
    if len(stem) <= len("-spec") or not stem.endswith(("-spec", "_spec")):
        print(f"Skipped '{spec_file}': name does not end in '-spec' or '_spec'")
        return False
    # Strip the five-character "-spec"/"_spec" suffix to get the component name.
    spec_file_name = stem[:-5]
    comp_dir = spec_file.parent.parent / f"{spec_file_name}.comp"
    new_comp_file = comp_dir / f"{spec_file_name}-comp.xml"
    # Never overwrite an existing component file (``Path.rename`` silently
    # replaces the destination on POSIX).
    if new_comp_file.exists():
        print(f"Skipped '{spec_file}': '{new_comp_file}' already exists")
        return False
    # Make comp dir, reusing it if it is already there.
    if not comp_dir.is_dir():
        comp_dir.mkdir()
        print(f"Created '{comp_dir}'")
    # Move file to new location
    spec_file.rename(new_comp_file)
    print(f"Moved '{spec_file}' to '{new_comp_file}'")
    return True


# Replace references in a worker OWD to any spec file that has been moved
def v2_4_7_replace_renamed_specs(
    worker_xml: pathlib.Path,
    spec_files: list[pathlib.Path],
    version_to: Version,
) -> bool:
    """Replace spec references that point at moved spec files.

    For every ``<name>[-_]spec.xml`` in ``spec_files``, occurrences of
    ``<name>[-_]spec.xml`` or ``<name>[-_]spec`` in ``worker_xml`` are
    replaced with ``<name>``. An occurrence matches only when it is not part
    of a longer name. It may not be preceded or followed by a word character
    or ``-``, so replacing ``foo-spec`` leaves ``myfoo-spec`` alone, while
    ``ocpi.core.foo-spec`` and ``../specs/foo-spec.xml`` are still rewritten.

    Args:
        worker_xml: Path of the worker's OWD XML file.
        spec_files: Original paths of the spec files that were moved.
        version_to: OpenCPI version being migrated to.

    Returns:
        ``True`` if the file was rewritten, ``False`` otherwise.
    """
    import re

    if version_to < V2_4_7:
        return False
    print(f"Scanning '{worker_xml}' ... ")
    # Build one pattern per spec reference form. The full file name is tried
    # before the stem, so ``foo-spec.xml`` becomes ``foo`` and not ``foo.xml``.
    replacements = []
    for spec_file in spec_files:
        name = spec_file.stem[:-5]
        for case in (spec_file.name, spec_file.stem):
            pattern = re.compile(rf"(?<![\w-]){re.escape(case)}(?![\w-])")
            replacements.append((case, name, pattern))
    # newline="" keeps the file's original line endings when it is rewritten.
    with worker_xml.open("r", newline="") as file:
        lines = file.readlines()
    changed_something = False
    for index, line in enumerate(lines):
        # Apply each replacement to the already-updated line so that several
        # replacements on one line all take effect.
        for case, name, pattern in replacements:
            # A callable replacement inserts ``name`` literally (no escapes).
            new_line, count = pattern.subn(lambda _match, new=name: new, line)
            if count:
                print(f"Replaced '{case}' with '{name}' on line {index + 1} of "
                      f"'{worker_xml}'")
                line = new_line
                changed_something = True
        lines[index] = line
    if changed_something:
        with worker_xml.open("w", newline="") as file:
            file.writelines(lines)
    return changed_something


class MissingArgumentError(Exception):
    """Raised when a required, repeatable command-line option was never given."""

    def __init__(self, argument: str) -> None:
        """Build the error message naming the missing ``argument`` option."""
        message = f"{argument} must be provided at least once"
        super().__init__(message)


def main() -> None:
    """Run the script.

    The script runs in three steps:

    1. Parse the command line.
    2. Migrate every library given with ``--library``: rename OWDs and move
       ``[-_]spec`` files into ``.comp`` directories.
    3. Rewrite references to the moved specs in every worker OWD found under
       the ``--project`` directories.

    Exits with status 1 on unrecognised arguments or on a processing error
    (the full traceback is shown only with ``--verbose``).

    Raises:
        MissingArgumentError: If ``--project`` or ``--library`` is not given.
    """
    # Argument parsing
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--project",
        action="append",
        type=pathlib.Path,
        help="The projects to search when modifying `spec` attributes",
    )
    argparser.add_argument(
        "--library",
        action="append",
        type=pathlib.Path,
        help="The libraries to search when moving `[-_]spec` files",
    )
    argparser.add_argument(
        "--to-version",
        type=Version,
        help="The OpenCPI version to migrate to (2.4.7 [default] or newer)",
        default=V2_4_7,
    )
    argparser.add_argument(
        "--verbose",
        action="store_true",
        help="Re-raise errors with a full traceback instead of a one-line message",
    )
    args, unknown_args = argparser.parse_known_args()
    if unknown_args:
        print(f"Extra arguments not recognised: {unknown_args}")
        sys.exit(1)

    # Validate arguments
    if args.project is None:
        raise MissingArgumentError("--project")
    if args.library is None:
        raise MissingArgumentError("--library")

    # Start of processing
    try:
        print(f"Running over projects '{args.project}' "
              f"and libraries '{args.library}' ...")
        files_moved: list[pathlib.Path] = []
        for library in args.library:
            for worker in yield_workers_from_library(library):
                v2_4_7_owd_rename(worker, version_to=args.to_version)
            for spec_file in yield_specs_from_library(library):
                # Only specs that were actually moved need their references updated.
                if v2_4_7_move_spec_to_comp(spec_file, version_to=args.to_version):
                    files_moved.append(spec_file)
        # Edit any worker that referenced a moved spec
        for project in args.project:
            for model in MODELS:
                for worker_directory in project.rglob(f"*.{model}"):
                    if not worker_directory.is_dir():
                        continue
                    # ``stem`` drops the ``.<model>`` suffix, matching the OWD
                    # naming used by ``v2_4_7_owd_rename``.
                    name = worker_directory.stem
                    worker_xml = worker_directory / f"{name}-{model}.xml"
                    if not worker_xml.exists():
                        worker_xml = worker_directory / f"{name}.xml"
                        if not worker_xml.exists():
                            continue
                    v2_4_7_replace_renamed_specs(
                        worker_xml, files_moved, version_to=args.to_version,
                    )
    except Exception as err:
        if args.verbose:
            raise
        print(f"ERROR: {err}")
        sys.exit(1)


if __name__ == "__main__":
    # main() returns None, so this exits with status 0 unless main exits first.
    sys.exit(main())
