from typing import Callable, Iterable, Protocol

from OTAnalytics.application.analysis.intersect import (
    RunIntersect,
    group_sections_by_offset,
)
from OTAnalytics.domain.event import (
    EventBuilder,
    EventDataset,
    PythonEventDataset,
    SectionEventBuilder,
)
from OTAnalytics.domain.geometry import (
    DirectionVector2D,
    RelativeOffsetCoordinate,
    calculate_direction_vector,
)
from OTAnalytics.domain.intersect import Intersector, IntersectParallelizationStrategy
from OTAnalytics.domain.section import Area, LineSection, Section
from OTAnalytics.domain.track_dataset.track_dataset import (
    IntersectionError,
    TrackDataset,
)
from OTAnalytics.domain.types import EventType


class IntersectByIntersectionPoints(Intersector):
    """Create events from the intersection points of tracks and sections.

    Sections are grouped by their ``SECTION_ENTER`` offset and each group is
    intersected with the track dataset on its own.

    This strategy is intended to be used with LineSections.
    """

    def __init__(
        self,
        calculate_direction_vector_: Callable[
            [float, float, float, float], DirectionVector2D
        ] = calculate_direction_vector,
    ) -> None:
        # Kept as an injectable dependency; the methods below do not call it.
        self._calculate_direction_vector = calculate_direction_vector_

    def intersect(
        self,
        track_dataset: TrackDataset,
        sections: Iterable[Section],
        event_builder: EventBuilder,
    ) -> EventDataset:
        """Intersect all tracks with the given sections.

        Args:
            track_dataset (TrackDataset): the tracks to intersect.
            sections (Iterable[Section]): the sections to intersect with.
            event_builder (EventBuilder): builder used to create the events.

        Returns:
            EventDataset: the events of all offset groups combined.
        """
        offset_groups = group_sections_by_offset(sections, EventType.SECTION_ENTER)
        collected_events = PythonEventDataset()
        for offset, grouped_sections in offset_groups.items():
            collected_events.extend(
                self.__do_intersect(
                    track_dataset, grouped_sections, offset, event_builder
                )
            )
        return collected_events

    def __do_intersect(
        self,
        track_dataset: TrackDataset,
        sections: list[Section],
        offset: RelativeOffsetCoordinate,
        event_builder: EventBuilder,
    ) -> EventDataset:
        """Create the events for one group of sections sharing ``offset``."""
        return track_dataset.intersection_points(sections, offset).create_events(
            offset, event_builder
        )


class IntersectAreaByTrackPoints(Intersector):
    """Create enter and leave events for sections from track points.

    For every track and section, ``TrackDataset.contained_by_sections``
    supplies a mask that is indexed like the track's detections. An event is
    created whenever the mask value changes between two consecutive
    detections, and an additional enter event is created at the first
    detection if the mask already starts out truthy.

    Args:
        calculate_direction_vector_: callable receiving ``x1, y1, x2, y2`` and
            returning the direction vector attached to each event.
    """

    def __init__(
        self,
        calculate_direction_vector_: Callable[
            [float, float, float, float], DirectionVector2D
        ] = calculate_direction_vector,
    ) -> None:
        self._calculate_direction_vector = calculate_direction_vector_

    def intersect(
        self,
        track_dataset: TrackDataset,
        sections: Iterable[Section],
        event_builder: EventBuilder,
    ) -> EventDataset:
        """Intersect all tracks with the given sections.

        Sections are grouped by their ``SECTION_ENTER`` offset and each group
        is processed separately.

        Args:
            track_dataset (TrackDataset): the tracks to intersect.
            sections (Iterable[Section]): the sections to intersect with.
            event_builder (EventBuilder): builder used to create the events.

        Returns:
            EventDataset: the events of all offset groups combined.

        Raises:
            IntersectionError: if a track reported by the containment result
                cannot be retrieved from ``track_dataset``.
        """
        sections_grouped_by_offset = group_sections_by_offset(
            sections, EventType.SECTION_ENTER
        )
        event_dataset = PythonEventDataset()
        for offset, section_group in sections_grouped_by_offset.items():
            result = self.__do_intersect(
                track_dataset, section_group, offset, event_builder
            )
            event_dataset.extend(result)
        return event_dataset

    def __do_intersect(
        self,
        track_dataset: TrackDataset,
        sections: list[Section],
        offset: RelativeOffsetCoordinate,
        event_builder: EventBuilder,
    ) -> EventDataset:
        """Create the events for one group of sections sharing ``offset``.

        Raises:
            IntersectionError: if a track id from the containment result is
                unknown to ``track_dataset``.
        """
        contained_by_sections_result = track_dataset.contained_by_sections(
            sections, offset
        )

        events = []
        for (
            track_id,
            contained_by_sections_masks,
        ) in contained_by_sections_result.items():
            if not (track := track_dataset.get_for(track_id)):
                raise IntersectionError(
                    "Track not found. Unable to create intersection event "
                    f"for track {track_id}."
                )
            track_detections = track.detections
            for section_id, section_entered_mask in contained_by_sections_masks:
                # The builder keeps its state between create_event calls, so
                # these two values apply to all events of this track/section.
                event_builder.add_section_id(section_id)
                event_builder.add_road_user_type(track.classification)

                track_starts_inside_area = section_entered_mask[0]
                if track_starts_inside_area:
                    # NOTE(review): assumes the track has at least two
                    # detections; a single-detection track would raise
                    # IndexError here — confirm tracks guarantee this.
                    first_detection = track_detections[0]
                    first_coord = first_detection.get_coordinate(offset)
                    second_coord = track_detections[1].get_coordinate(offset)

                    event_builder.add_event_type(EventType.SECTION_ENTER)
                    event_builder.add_direction_vector(
                        self._calculate_direction_vector(
                            first_coord.x,
                            first_coord.y,
                            second_coord.x,
                            second_coord.y,
                        )
                    )
                    # Use the offset coordinate, exactly like the transition
                    # events below, instead of the raw detection position.
                    event_builder.add_event_coordinate(first_coord.x, first_coord.y)
                    # Set the interpolated values explicitly; otherwise the
                    # reused builder would carry over the values of whichever
                    # event was created before this one.
                    event_builder.add_interpolated_event_coordinate(
                        first_coord.x, first_coord.y
                    )
                    event_builder.add_interpolated_occurrence(
                        first_detection.occurrence
                    )
                    event = event_builder.create_event(first_detection)
                    events.append(event)

                section_currently_entered = track_starts_inside_area
                for current_index, current_detection in enumerate(
                    track_detections[1:], start=1
                ):
                    entered = section_entered_mask[current_index]
                    if section_currently_entered == entered:
                        # No change of state between the two detections.
                        continue

                    prev_coord = track_detections[current_index - 1].get_coordinate(
                        offset
                    )
                    current_coord = current_detection.get_coordinate(offset)

                    event_builder.add_direction_vector(
                        self._calculate_direction_vector(
                            prev_coord.x,
                            prev_coord.y,
                            current_coord.x,
                            current_coord.y,
                        )
                    )
                    event_builder.add_event_coordinate(current_coord.x, current_coord.y)
                    # Events are placed on the detection itself, so the
                    # interpolated values equal the detection's own values.
                    event_builder.add_interpolated_event_coordinate(
                        current_coord.x, current_coord.y
                    )
                    event_builder.add_interpolated_occurrence(
                        current_detection.occurrence
                    )

                    if entered:
                        event_builder.add_event_type(EventType.SECTION_ENTER)
                    else:
                        event_builder.add_event_type(EventType.SECTION_LEAVE)

                    event = event_builder.create_event(current_detection)
                    events.append(event)
                    section_currently_entered = entered

        return PythonEventDataset(events)


class RunCreateIntersectionEvents:
    """Create intersection events for line and area sections.

    The sections are split by type; line sections are handled by
    ``intersect_line_section`` and area sections by ``intersect_area_section``.

    Args:
        intersect_line_section (Intersector): intersector for line sections.
        intersect_area_section (Intersector): intersector for area sections.
        track_dataset (TrackDataset): the tracks to intersect.
        sections (Iterable[Section]): the sections to intersect with.
        event_builder (SectionEventBuilder): builder used to create the events.
    """

    def __init__(
        self,
        intersect_line_section: Intersector,
        intersect_area_section: Intersector,
        track_dataset: TrackDataset,
        sections: Iterable[Section],
        event_builder: SectionEventBuilder,
    ):
        self._intersect_line_section = intersect_line_section
        self._intersect_area_section = intersect_area_section
        self._track_dataset = track_dataset
        self._sections = sections
        self._event_builder = event_builder

    def create(self) -> EventDataset:
        """Run both intersectors and return their events in one dataset.

        Returns:
            EventDataset: events of the line sections followed by the events
                of the area sections.
        """
        combined_events = PythonEventDataset()
        line_sections, area_sections = separate_sections(self._sections)
        # Line sections are processed first, then area sections.
        work_items = (
            (self._intersect_line_section, line_sections),
            (self._intersect_area_section, area_sections),
        )
        for intersector, typed_sections in work_items:
            combined_events.extend(
                intersector.intersect(
                    self._track_dataset, typed_sections, self._event_builder
                )
            )
        return combined_events


class GetTracks(Protocol):
    """Structural interface for objects providing tracks as a dataset."""

    def as_dataset(self) -> TrackDataset:
        """Return the tracks as a TrackDataset."""


class BatchedTracksRunIntersect(RunIntersect):
    """Intersect tracks with sections in batches.

    The track dataset is split into ``num_processes`` batches and every batch
    is handed, together with the sections, to the parallelization strategy.

    Args:
        intersect_parallelizer (IntersectParallelizationStrategy): strategy
            executing the per-batch tasks.
        get_tracks (GetTracks): provider of the tracks to intersect.
    """

    def __init__(
        self,
        intersect_parallelizer: IntersectParallelizationStrategy,
        get_tracks: GetTracks,
    ) -> None:
        self._intersect_parallelizer = intersect_parallelizer
        self._get_tracks = get_tracks

    def __call__(self, sections: Iterable[Section]) -> EventDataset:
        """Create the intersection events of all tracks with ``sections``.

        Args:
            sections (Iterable[Section]): the sections to intersect with. Any
                iterable is accepted, including one-shot iterators.

        Returns:
            EventDataset: the result returned by the parallelization strategy.
        """
        # ``sections`` is traversed once to collect the offsets and is then
        # handed to every batch task. Materialize it so that a one-shot
        # iterator is not already exhausted when the tasks consume it.
        section_list = list(sections)

        filtered_tracks = self._get_tracks.as_dataset()
        filtered_tracks.calculate_geometries_for(
            {
                _section.get_offset(EventType.SECTION_ENTER)
                for _section in section_list
            }
        )

        batches = filtered_tracks.split(self._intersect_parallelizer.num_processes)

        tasks = [(batch, section_list) for batch in batches]
        return self._intersect_parallelizer.execute(_create_events, tasks)


def _create_events(tracks: TrackDataset, sections: Iterable[Section]) -> EventDataset:
    """Create the intersection events of ``tracks`` with ``sections``.

    Line sections are intersected with ``IntersectByIntersectionPoints`` and
    area sections with ``IntersectAreaByTrackPoints``, using a fresh
    ``SectionEventBuilder``.

    Args:
        tracks (TrackDataset): the tracks to intersect.
        sections (Iterable[Section]): the sections to intersect with.

    Returns:
        EventDataset: a new dataset holding all created events.
    """
    runner = RunCreateIntersectionEvents(
        intersect_line_section=IntersectByIntersectionPoints(),
        intersect_area_section=IntersectAreaByTrackPoints(),
        track_dataset=tracks,
        sections=sections,
        event_builder=SectionEventBuilder(),
    )
    # The runner's events are copied into a fresh dataset, as before.
    collected_events = PythonEventDataset()
    collected_events.extend(runner.create())
    return collected_events


def separate_sections(
    sections: Iterable[Section],
) -> tuple[Iterable[LineSection], Iterable[Area]]:
    """Split sections into line sections and area sections.

    The input order is preserved within each of the two results. A section
    that is an instance of ``LineSection`` is always sorted into the line
    sections, without checking for ``Area``.

    Args:
        sections (Iterable[Section]): the sections to separate.

    Returns:
        tuple[Iterable[LineSection], Iterable[Area]]: the line sections and
            the area sections, in that order.

    Raises:
        TypeError: if a section is neither a ``LineSection`` nor an ``Area``.
    """
    lines: list[LineSection] = []
    areas: list[Area] = []
    for candidate in sections:
        if isinstance(candidate, LineSection):
            lines.append(candidate)
            continue
        if isinstance(candidate, Area):
            areas.append(candidate)
            continue
        raise TypeError(
            "Unable to separate section. "
            f"Unknown section type for section {candidate.name} "
            f"with type {type(candidate)}"
        )

    return lines, areas
