from __future__ import annotations
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from collections import Counter
import math
from typing import TYPE_CHECKING, Iterable

from mido import MidiFile, Message

if TYPE_CHECKING:
    from argparse import _ArgumentGroup, Namespace
    from ._music_box import MusicBox


@dataclass()
class Sound:
    """A single note-on event read from a MIDI track.

    Left mutable on purpose so that ``time`` can be shifted after reading.
    """

    note: int  # MIDI note number of the event
    time: int  # position of the event, in ticks
    track: int  # index of the track the event came from


@dataclass(eq=True, frozen=True)
class Transposition:
    """An immutable candidate pitch shift together with its fit score."""

    shift: int  # amount added to every MIDI note number
    ratio: float  # fraction of the melody's sounds that fit, 0..1


@dataclass()
class Distance:
    """Running state for the smallest gap between repeats of one note.

    ``time`` holds when the note was last heard and ``diff`` the smallest
    gap seen so far (infinite until the note repeats).
    """

    time: int
    diff: float = math.inf


@dataclass
class Melody:
    """A MIDI melody to be arranged for a music box.

    Note-on events of the selected tracks are read lazily from ``path``
    (see ``sounds``); the other properties derive statistics from them and
    search for a transposition whose notes fit ``music_box``.
    """

    path: Path
    music_box: MusicBox
    # inclusive bounds of the transpositions (added to MIDI note numbers)
    transpose_lower: int = -100
    transpose_upper: int = 100
    # indices of the MIDI tracks to read; tracks with index >= 16 are
    # skipped unless requested explicitly
    tracks: frozenset = frozenset(range(16))

    @classmethod
    def init_parser(cls, parser: _ArgumentGroup) -> None:
        """Register the melody-related command-line options on ``parser``."""
        parser.add_argument(
            '--input', type=Path, required=True,
            help='path to the input MIDI (*.mid) file',
        )
        parser.add_argument(
            '--transpose-lower', type=int, default=cls.transpose_lower,
            help='the lowest transposition to try',
        )
        parser.add_argument(
            '--transpose-upper', type=int, default=cls.transpose_upper,
            help='the highest transposition to try',
        )
        parser.add_argument(
            '--tracks', nargs='*', type=int, default=[],
            help='numbers of sound tracks to include',
        )

    @classmethod
    def from_args(cls, args: Namespace, *, music_box: MusicBox) -> Melody:
        """Build a melody from parsed arguments.

        An empty ``--tracks`` list falls back to the default track set.
        """
        return cls(
            path=args.input,
            transpose_lower=args.transpose_lower,
            transpose_upper=args.transpose_upper,
            tracks=frozenset(args.tracks) or cls.tracks,
            music_box=music_box,
        )

    @cached_property
    def sounds(self) -> list[Sound]:
        """Note-on events of the selected tracks, sorted by time.

        Times are absolute ticks, shifted so that the earliest sound is at
        tick 0. Returns an empty list when the tracks contain no notes.
        """
        sounds: list[Sound] = []
        with MidiFile(str(self.path)) as midi_file:
            message: Message
            for i, track in enumerate(midi_file.tracks):
                if i not in self.tracks:
                    continue
                print(f'reading track #{i} "{track.name}"...')
                time = 0
                for message in track:
                    # message.time is a delta; accumulate it into an
                    # absolute tick (meta messages included, they take time)
                    time += message.time
                    # Meta/sysex messages never have type "note_on", so the
                    # velocity is only read on note messages. A note_on with
                    # velocity 0 is a note-off by MIDI convention.
                    if message.type != "note_on" or message.velocity == 0:
                        continue
                    sounds.append(Sound(note=message.note, time=time, track=i))

        # skip silence at the beginning (default covers an empty melody)
        min_time = min((sound.time for sound in sounds), default=0)
        for sound in sounds:
            sound.time -= min_time

        sounds.sort(key=lambda sound: sound.time)
        return sounds

    @cached_property
    def notes_use(self) -> dict[int, int]:
        """How many times each note appears in the melody.
        """
        return dict(Counter(sound.note for sound in self.sounds))

    @cached_property
    def sounds_count(self) -> int:
        """How many sounds in total there are in the melody.
        """
        return len(self.sounds)

    def count_available_sounds(self, trans: int) -> int:
        """How many sounds from the melody fit in the music box when every
        note is shifted by ``trans``.
        """
        contains = self.music_box.contains_note
        return sum(
            freq for note, freq in self.notes_use.items()
            if contains(note + trans)
        )

    @cached_property
    def max_time(self) -> int:
        """The tick when the last note plays (0 for an empty melody).
        """
        return max((sound.time for sound in self.sounds), default=0)

    @cached_property
    def best_transpose(self) -> Transposition:
        """Transposition that fits most of the notes.

        Whole-octave shifts are tried first, since they keep every note's
        pitch class. Only if none of them fits at least 90% of the sounds is
        every shift within the bounds considered.
        """
        best_transpose = self._get_best_transpose(self._octave_shifts())
        # Better to transpose with preserving most of the notes.
        # If full octave transposition doesn't fit just a bit, roll with it.
        if best_transpose.ratio >= .90:
            return best_transpose
        return self._get_best_transpose(self._all_shifts())

    def _octave_shifts(self) -> range:
        """Multiples of 12 within [transpose_lower, transpose_upper]."""
        # Round the lower bound *up* to an octave: int() truncates toward
        # zero, which would put the first shift below a positive bound.
        first = math.ceil(self.transpose_lower / 12) * 12
        return range(first, self.transpose_upper + 1, 12)

    def _all_shifts(self) -> range:
        """Every integer shift within [transpose_lower, transpose_upper]."""
        return range(self.transpose_lower, self.transpose_upper + 1)

    def _get_best_transpose(self, seq: Iterable[int]) -> Transposition:
        """Try all transpositions from the sequence and pick the best one.

        Returns the first shift that fits every sound. Otherwise returns the
        earliest shift with the highest ratio of fitting sounds, or
        ``Transposition(0, 0)`` when ``seq`` is empty or nothing fits.
        """
        total = self.sounds_count
        best_transpose = Transposition(0, 0)
        for shift in seq:
            # an empty melody trivially fits at any shift (avoids 0 / 0)
            ratio = self.count_available_sounds(shift) / total if total else 1.0
            if ratio == 1:
                return Transposition(shift, 1)
            if ratio > best_transpose.ratio:
                best_transpose = Transposition(shift, ratio)
        return best_transpose

    @cached_property
    def min_distance(self) -> float:
        """The shortest time (in ticks) between two consecutive sounds of
        the same note, or ``math.inf`` if no note repeats.

        Repeats of a note at the same tick are ignored.
        """
        last_seen: dict[int, Distance] = {}
        for sound in self.sounds:
            distance = last_seen.setdefault(sound.note, Distance(sound.time))
            diff = sound.time - distance.time
            if diff == 0:
                # first occurrence of the note, or a duplicate at that tick
                continue
            distance.diff = min(distance.diff, diff)
            distance.time = sound.time
        return min((d.diff for d in last_seen.values()), default=math.inf)
