import argparse
import collections
import logging
import zipfile
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Counter, Dict, Iterator, List, Optional, Set, Tuple

import pandas
from annofabapi.dataclass.annotation import SimpleAnnotation
from annofabapi.models import AdditionalDataDefinitionType, TaskPhase, TaskStatus
from annofabapi.parser import (
    SimpleAnnotationParser,
    SimpleAnnotationParserByTask,
    lazy_parse_simple_annotation_dir,
    lazy_parse_simple_annotation_dir_by_task,
    lazy_parse_simple_annotation_zip,
    lazy_parse_simple_annotation_zip_by_task,
)
from dataclasses_json import dataclass_json

import annofabcli
import annofabcli.common.cli
from annofabcli import AnnofabApiFacade
from annofabcli.common.cli import AbstractCommandLineInterface, ArgumentParser, build_annofabapi_resource_and_login
from annofabcli.common.visualize import AddProps, MessageLocale

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Column type of the CSV that holds the count per attribute value.
# Each column is identified by the tuple (label, attribute, choice).
AttributesColumn = Tuple[str, str, str]
"""
属性値ごとの個数を表したCSVの列の型
Tuple[Label, Attribute, Choice]
"""

# List of label names from which the label count columns are built.
LabelColumnList = List[str]

# List of (label, attribute, choice) columns of the attribute count CSV.
AttributesColumnList = List[AttributesColumn]


class GroupBy(Enum):
    """Unit by which the annotation counts are aggregated (value of ``--group_by``)."""

    # One output row per task.
    TASK_ID = "task_id"
    # One output row per input data.
    INPUT_DATA_ID = "input_data_id"


@dataclass_json
@dataclass(frozen=True)
class AnnotationCounterByTask:
    """
    Annotation counts aggregated over one task.

    Attributes:
        task_id: ID of the task.
        task_status: Status of the task.
        task_phase: Phase of the task.
        task_phase_stage: Stage number within the task phase.
        labels_count: Number of annotations per label name.
        attirbutes_count: Number of annotations per (label, attribute, value) tuple.
    """

    task_id: str
    task_status: TaskStatus
    task_phase: TaskPhase
    task_phase_stage: int
    labels_count: Counter[str]
    attirbutes_count: Counter[AttributesColumn]


@dataclass_json
@dataclass(frozen=True)
class AnnotationCounterByInputData:
    """
    Annotation counts for one input data.

    Attributes:
        task_id: ID of the task the input data belongs to.
        task_status: Status of that task.
        task_phase: Phase of that task.
        task_phase_stage: Stage number within the task phase.
        input_data_id: ID of the input data.
        input_data_name: Name of the input data.
        labels_count: Number of annotations per label name.
        attirbutes_count: Number of annotations per (label, attribute, value) tuple.
    """

    task_id: str
    task_status: TaskStatus
    task_phase: TaskPhase
    task_phase_stage: int

    input_data_id: str
    input_data_name: str
    labels_count: Counter[str]
    attirbutes_count: Counter[AttributesColumn]


class ListAnnotationCount(AbstractCommandLineInterface):
    """
    アノテーション数情報を出力する。
    """

    CSV_FORMAT = {"encoding": "utf_8_sig", "index": False}

    @staticmethod
    def lazy_parse_simple_annotation_by_input_data(annotation_path: Path) -> Iterator[SimpleAnnotationParser]:
        """
        Create a lazy iterator of per-input-data parsers for a Simple annotation directory or zip.

        Args:
            annotation_path: Path given with '--annotation'; either a directory or a zip file.

        Returns:
            The iterator returned by ``lazy_parse_simple_annotation_dir`` (directory)
            or ``lazy_parse_simple_annotation_zip`` (zip file).

        Raises:
            RuntimeError: If the path does not exist, or is neither a directory nor a zip file.
        """
        if not annotation_path.exists():
            raise RuntimeError("'--annotation'で指定したディレクトリまたはファイルが存在しません。")

        # Directory is checked before zip, then everything else is rejected.
        if annotation_path.is_dir():
            return lazy_parse_simple_annotation_dir(annotation_path)

        if zipfile.is_zipfile(str(annotation_path)):
            return lazy_parse_simple_annotation_zip(annotation_path)

        raise RuntimeError(f"'--annotation'で指定した'{annotation_path}'は、zipファイルまたはディレクトリではありませんでした。")

    @staticmethod
    def lazy_parse_simple_annotation_by_task(annotation_path: Path) -> Iterator[SimpleAnnotationParserByTask]:
        """
        Create a lazy iterator of per-task parsers for a Simple annotation directory or zip.

        Args:
            annotation_path: Path given with '--annotation'; either a directory or a zip file.

        Returns:
            The iterator returned by ``lazy_parse_simple_annotation_dir_by_task`` (directory)
            or ``lazy_parse_simple_annotation_zip_by_task`` (zip file).

        Raises:
            RuntimeError: If the path does not exist, or is neither a directory nor a zip file.
        """
        if not annotation_path.exists():
            raise RuntimeError("'--annotation'で指定したディレクトリまたはファイルが存在しません。")

        # Directory is checked before zip, then everything else is rejected.
        if annotation_path.is_dir():
            return lazy_parse_simple_annotation_dir_by_task(annotation_path)

        if zipfile.is_zipfile(str(annotation_path)):
            return lazy_parse_simple_annotation_zip_by_task(annotation_path)

        raise RuntimeError(f"'--annotation'で指定した'{annotation_path}'は、zipファイルまたはディレクトリではありませんでした。")

    @staticmethod
    def count_for_input_data(
        simple_annotation: SimpleAnnotation, target_attributes: Optional[Set[Tuple[str, str]]] = None
    ) -> AnnotationCounterByInputData:
        """
        Count the annotations of a single input data.

        Args:
            simple_annotation: Content of one Simple annotation JSON file.
            target_attributes: Set of (label, attribute) pairs whose values are counted.
                When None, no attribute value is counted; only labels are tallied.

        Returns:
            The per-label and per-(label, attribute, value) counts, together with the
            task and input data fields copied from ``simple_annotation``.
        """
        details = simple_annotation.details
        labels_count = collections.Counter(detail.label for detail in details)

        attirbutes_count: Counter[Tuple[str, str, str]] = collections.Counter()
        if target_attributes is not None:
            for detail in details:
                label_name = detail.label
                for attribute_name, attribute_value in detail.attributes.items():
                    if (label_name, attribute_name) in target_attributes:
                        # Values are stringified so that e.g. booleans become "True" / "False".
                        attirbutes_count[(label_name, attribute_name, str(attribute_value))] += 1

        return AnnotationCounterByInputData(
            task_id=simple_annotation.task_id,
            task_status=simple_annotation.task_status,
            task_phase=simple_annotation.task_phase,
            task_phase_stage=simple_annotation.task_phase_stage,
            input_data_id=simple_annotation.input_data_id,
            input_data_name=simple_annotation.input_data_name,
            labels_count=labels_count,
            attirbutes_count=attirbutes_count,
        )

    @staticmethod
    def count_for_task(
        task_parser: SimpleAnnotationParserByTask, target_attributes: Optional[Set[Tuple[str, str]]] = None
    ) -> AnnotationCounterByTask:
        """
        Count the annotations of a single task by summing the counts of its input data.

        Args:
            task_parser: Parser that yields one parser per JSON file of the task.
            target_attributes: Set of (label, attribute) pairs whose values are counted;
                passed through to ``count_for_input_data``.

        Returns:
            The summed counts. The task fields (id, status, phase, stage) are taken
            from the last JSON file that was parsed.

        Raises:
            RuntimeError: If the task contains no JSON file.
        """
        total_labels: Counter[str] = collections.Counter()
        total_attributes: Counter[Tuple[str, str, str]] = collections.Counter()

        latest_annotation = None
        for input_data_parser in task_parser.lazy_parse():
            latest_annotation = input_data_parser.parse()
            per_input_data = ListAnnotationCount.count_for_input_data(latest_annotation, target_attributes)
            total_labels += per_input_data.labels_count
            total_attributes += per_input_data.attirbutes_count

        # Without at least one JSON file there is nothing to read the task fields from.
        if latest_annotation is None:
            raise RuntimeError(f"{task_parser.task_id} ディレクトリにはjsonファイルが１つも含まれていません。")

        return AnnotationCounterByTask(
            task_id=latest_annotation.task_id,
            task_status=latest_annotation.task_status,
            task_phase=latest_annotation.task_phase,
            task_phase_stage=latest_annotation.task_phase_stage,
            labels_count=total_labels,
            attirbutes_count=total_attributes,
        )

    def print_labels_count_for_task(
        self, task_counter_list: List[AnnotationCounterByTask], label_columns: List[str], output_dir: Path
    ):
        """
        Write the per-task label counts to ``labels_count.csv`` in ``output_dir``.

        Args:
            task_counter_list: Counts of each task; one CSV row per element.
            label_columns: Label names; each becomes a ``label_<name>`` column.
            output_dir: Directory in which the CSV file is written.
        """
        columns = ["task_id", "task_status", "task_phase", "task_phase_stage"] + [
            f"label_{label_name}" for label_name in label_columns
        ]

        rows: List[Dict[str, Any]] = []
        for counter in task_counter_list:
            row: Dict[str, Any] = {
                "task_id": counter.task_id,
                "task_status": counter.task_status.value,
                "task_phase": counter.task_phase.value,
                "task_phase_stage": counter.task_phase_stage,
            }
            # Labels absent from the counter are reported as 0.
            for label_name in label_columns:
                row[f"label_{label_name}"] = counter.labels_count[label_name]
            rows.append(row)

        df = pandas.DataFrame(rows, columns=columns)
        annofabcli.utils.print_csv(
            df, output=str(output_dir / "labels_count.csv"), to_csv_kwargs=self.CSV_FORMAT
        )

    def print_attirbutes_count_for_task(
        self,
        task_counter_list: List[AnnotationCounterByTask],
        attribute_columns: List[Tuple[str, str, str]],
        output_dir: Path,
    ):
        """
        Write the per-task attribute value counts to ``attirbutes_count.csv`` in ``output_dir``.

        Args:
            task_counter_list: Counts of each task; one CSV row per element.
            attribute_columns: (label, attribute, choice) tuples; each becomes one column.
            output_dir: Directory in which the CSV file is written.
        """
        # The task columns use 3-element keys so they fit the same MultiIndex as the attribute columns.
        task_id_key = ("", "", "task_id")
        task_status_key = ("", "", "task_status")
        task_phase_key = ("", "", "task_phase")
        task_phase_stage_key = ("", "", "task_phase_stage")
        columns = [task_id_key, task_status_key, task_phase_key, task_phase_stage_key] + list(attribute_columns)

        rows = []
        for counter in task_counter_list:
            row: Dict[AttributesColumn, Any] = {
                task_id_key: counter.task_id,
                task_status_key: counter.task_status.value,
                task_phase_key: counter.task_phase.value,
                task_phase_stage_key: counter.task_phase_stage,
            }
            # Attribute values absent from the counter are reported as 0.
            for column in attribute_columns:
                row[column] = counter.attirbutes_count[column]
            rows.append(row)

        df = pandas.DataFrame(rows, columns=pandas.MultiIndex.from_tuples(columns))
        annofabcli.utils.print_csv(
            df, output=str(output_dir / "attirbutes_count.csv"), to_csv_kwargs=self.CSV_FORMAT
        )

    def print_labels_count_for_input_data(
        self, input_data_counter_list: List[AnnotationCounterByInputData], label_columns: List[str], output_dir: Path
    ):
        """
        Write the per-input-data label counts to ``labels_count.csv`` in ``output_dir``.

        Args:
            input_data_counter_list: Counts of each input data; one CSV row per element.
            label_columns: Label names; each becomes a ``label_<name>`` column.
            output_dir: Directory in which the CSV file is written.
        """

        def to_dict(c: AnnotationCounterByInputData) -> Dict[str, Any]:
            d = {
                "input_data_id": c.input_data_id,
                "input_data_name": c.input_data_name,
                "task_id": c.task_id,
                "task_status": c.task_status.value,
                "task_phase": c.task_phase.value,
                "task_phase_stage": c.task_phase_stage,
            }
            # Labels absent from the counter are reported as 0.
            d.update({f"label_{label}": c.labels_count[label] for label in label_columns})
            return d

        columns = ["input_data_id", "input_data_name", "task_id", "task_status", "task_phase", "task_phase_stage"]
        columns.extend([f"label_{e}" for e in label_columns])

        # Pass the columns explicitly, as the per-task variant does, so that the header
        # is still written when ``input_data_counter_list`` is empty.
        df = pandas.DataFrame([to_dict(e) for e in input_data_counter_list], columns=columns)
        output_file = str(output_dir / "labels_count.csv")
        annofabcli.utils.print_csv(df, output=output_file, to_csv_kwargs=self.CSV_FORMAT)

    def print_attirbutes_count_for_input_data(
        self,
        input_data_counter_list: List[AnnotationCounterByInputData],
        attribute_columns: List[Tuple[str, str, str]],
        output_dir: Path,
    ):
        """
        Write the per-input-data attribute value counts to ``attirbutes_count.csv`` in ``output_dir``.

        Args:
            input_data_counter_list: Counts of each input data; one CSV row per element.
            attribute_columns: (label, attribute, choice) tuples; each becomes one column.
            output_dir: Directory in which the CSV file is written.
        """
        # The fixed columns use 3-element keys so they fit the same MultiIndex as the attribute columns.
        input_data_id_key = ("", "", "input_data_id")
        input_data_name_key = ("", "", "input_data_name")
        task_id_key = ("", "", "task_id")
        task_status_key = ("", "", "task_status")
        task_phase_key = ("", "", "task_phase")
        task_phase_stage_key = ("", "", "task_phase_stage")
        columns = [
            input_data_id_key,
            input_data_name_key,
            task_id_key,
            task_status_key,
            task_phase_key,
            task_phase_stage_key,
        ] + list(attribute_columns)

        rows = []
        for counter in input_data_counter_list:
            row: Dict[AttributesColumn, Any] = {
                input_data_id_key: counter.input_data_id,
                input_data_name_key: counter.input_data_name,
                task_id_key: counter.task_id,
                task_status_key: counter.task_status.value,
                task_phase_key: counter.task_phase.value,
                task_phase_stage_key: counter.task_phase_stage,
            }
            # Attribute values absent from the counter are reported as 0.
            for column in attribute_columns:
                row[column] = counter.attirbutes_count[column]
            rows.append(row)

        df = pandas.DataFrame(rows, columns=pandas.MultiIndex.from_tuples(columns))
        annofabcli.utils.print_csv(
            df, output=str(output_dir / "attirbutes_count.csv"), to_csv_kwargs=self.CSV_FORMAT
        )

    @staticmethod
    def get_target_attributes_columns(annotation_specs_labels: List[Dict[str, Any]]) -> List[AttributesColumn]:
        """
        Build the (label, attribute, choice) columns to output from the annotation specs.

        Choice/select attributes contribute one column per choice, flag attributes
        contribute a "True" and a "False" column, and other attribute types are skipped.

        Args:
            annotation_specs_labels: The ``labels`` list of the annotation specs.

        Returns:
            List of (label name, attribute name, choice name) tuples using English names;
            a missing English name is replaced by an empty string.
        """

        def english_name(message: Any) -> str:
            name = AddProps.get_message(message, MessageLocale.EN)
            return "" if name is None else name

        choice_types = (AdditionalDataDefinitionType.CHOICE, AdditionalDataDefinitionType.SELECT)

        result: List[AttributesColumn] = []
        for label in annotation_specs_labels:
            label_name = english_name(label["label_name"])

            for attribute in label["additional_data_definitions"]:
                attribute_name = english_name(attribute["name"])
                attribute_type = AdditionalDataDefinitionType(attribute["type"])

                if attribute_type in choice_types:
                    for choice in attribute["choices"]:
                        result.append((label_name, attribute_name, english_name(choice["name"])))
                elif attribute_type == AdditionalDataDefinitionType.FLAG:
                    # Matches the stringified boolean values produced when counting.
                    result.append((label_name, attribute_name, "True"))
                    result.append((label_name, attribute_name, "False"))

        return result

    @staticmethod
    def get_target_label_columns(annotation_specs_labels: List[Dict[str, Any]]) -> List[str]:
        """
        Build the list of label names to output from the annotation specs.

        Args:
            annotation_specs_labels: The ``labels`` list of the annotation specs.

        Returns:
            English label names in the order of ``annotation_specs_labels``;
            a missing English name is replaced by an empty string.
        """
        label_names: List[str] = []
        for label in annotation_specs_labels:
            name_en = AddProps.get_message(label["label_name"], MessageLocale.EN)
            label_names.append("" if name_en is None else name_en)
        return label_names

    def get_target_columns(self, project_id: str) -> Tuple[LabelColumnList, AttributesColumnList]:
        """
        Fetch the annotation specs of the project and derive the output columns from them.

        Args:
            project_id: ID of the project whose annotation specs are read.

        Returns:
            Tuple of (label columns, attribute columns).
        """
        specs, _ = self.service.api.get_annotation_specs(project_id)
        labels = specs["labels"]
        return self.get_target_label_columns(labels), self.get_target_attributes_columns(labels)

    def list_annotation_count_by_task(self, project_id: str, annotation_path: Path, output_dir: Path) -> None:
        """
        Count annotations per task and write the label and attribute CSV files.

        Args:
            project_id: ID of the project; used to fetch the annotation specs.
            annotation_path: Simple annotation zip file or extracted directory.
            output_dir: Directory in which the CSV files are written.
        """
        task_parsers = self.lazy_parse_simple_annotation_by_task(annotation_path)
        label_columns, attribute_columns = self.get_target_columns(project_id)

        # Only the (label, attribute) part is needed to decide which attributes are counted.
        target_attributes = {(label, attribute) for label, attribute, _ in attribute_columns}

        logger.info("アノテーションzip/ディレクトリを読み込み中")
        task_counters = []
        for task_number, task_parser in enumerate(task_parsers, start=1):
            if task_number % 1000 == 0:
                logger.debug(f"{task_number}  件目を読み込み中")
            task_counters.append(self.count_for_task(task_parser, target_attributes=target_attributes))

        self.print_labels_count_for_task(task_counters, label_columns=label_columns, output_dir=output_dir)
        self.print_attirbutes_count_for_task(
            task_counters, output_dir=output_dir, attribute_columns=attribute_columns,
        )

    def list_annotation_count_by_input_data(self, project_id: str, annotation_path: Path, output_dir: Path) -> None:
        """
        Count annotations per input data and write the label and attribute CSV files.

        Args:
            project_id: ID of the project; used to fetch the annotation specs.
            annotation_path: Simple annotation zip file or extracted directory.
            output_dir: Directory in which the CSV files are written.
        """
        input_data_counter_list = []
        iter_parser = self.lazy_parse_simple_annotation_by_input_data(annotation_path)
        target_label_columns, target_attributes_columns = self.get_target_columns(project_id)

        # Only the (label, attribute) part is needed to decide which attributes are counted.
        target_attributes = {(e[0], e[1]) for e in target_attributes_columns}
        logger.info("アノテーションzip/ディレクトリを読み込み中")
        # Count from 1, like the per-task variant, so the progress message is not emitted
        # for the very first item ("0 件目") and reports the real number of items read.
        for index, parser in enumerate(iter_parser, start=1):
            if index % 1000 == 0:
                logger.debug(f"{index}  件目を読み込み中")

            simple_annotation = parser.parse()
            input_data_counter = self.count_for_input_data(simple_annotation, target_attributes=target_attributes)
            input_data_counter_list.append(input_data_counter)

        self.print_labels_count_for_input_data(
            input_data_counter_list, label_columns=target_label_columns, output_dir=output_dir
        )

        self.print_attirbutes_count_for_input_data(
            input_data_counter_list, output_dir=output_dir, attribute_columns=target_attributes_columns,
        )

    def main(self):
        """
        Run the subcommand: locate or download the Simple annotation, then output the counts.

        When '--annotation' is given, that zip file or directory is used. Otherwise the
        annotation archive is downloaded to ``annotation.zip`` in the cache directory.
        The counts are then aggregated per task or per input data according to '--group_by'.
        """
        args = self.args

        project_id = args.project_id
        super().validate_project(project_id, project_member_roles=None)

        if args.annotation is not None:
            # The option is registered as '--annotation', so argparse stores it as
            # ``args.annotation``; ``args.annotation_dir`` does not exist.
            annotation_path = Path(args.annotation)
        else:
            cache_dir = annofabcli.utils.get_cache_dir()
            annotation_path = cache_dir / "annotation.zip"
            logger.info(f"Simpleアノテーションzipをダウンロード中: {annotation_path}")
            self.service.wrapper.download_annotation_archive(project_id, str(annotation_path), v2=True)

        output_dir = Path(args.output_dir)
        group_by = GroupBy(args.group_by)
        if group_by == GroupBy.TASK_ID:
            self.list_annotation_count_by_task(project_id, annotation_path=annotation_path, output_dir=output_dir)
        elif group_by == GroupBy.INPUT_DATA_ID:
            self.list_annotation_count_by_input_data(
                project_id, annotation_path=annotation_path, output_dir=output_dir
            )


def parse_args(parser: argparse.ArgumentParser):
    """
    Register the options of the ``list_annotation_count`` subcommand on ``parser``.

    Adds the project id option, '--annotation', '-o/--output_dir' (required) and
    '--group_by', and sets the module-level ``main`` as the subcommand function.
    """
    ArgumentParser(parser).add_project_id()

    annotation_help = "Simpleアノテーションzip、またはzipを展開したディレクトリを指定します。" "指定しない場合はAnnoFabからダウンロードします。"
    parser.add_argument("--annotation", type=str, help=annotation_help)

    parser.add_argument("-o", "--output_dir", type=str, required=True, help="出力ディレクトリのパス")

    group_by_help = "アノテーションの個数をどの単位で集約するかを指定してます。デフォルトは'task_id'です。"
    parser.add_argument(
        "--group_by",
        type=str,
        choices=[member.value for member in GroupBy],
        default=GroupBy.TASK_ID.value,
        help=group_by_help,
    )

    parser.set_defaults(subcommand_func=main)


def main(args):
    """
    Subcommand function: build the API objects from ``args`` and run ``ListAnnotationCount``.

    Args:
        args: Parsed command line arguments.
    """
    api_service = build_annofabapi_resource_and_login(args)
    command = ListAnnotationCount(api_service, AnnofabApiFacade(api_service), args)
    command.main()


def add_parser(subparsers: argparse._SubParsersAction):
    """
    Add the ``list_annotation_count`` subcommand to ``subparsers`` and register its options.

    Args:
        subparsers: Sub-parsers action to which the subcommand parser is added.
    """
    parse_args(
        annofabcli.common.cli.add_parser(
            subparsers,
            "list_annotation_count",
            "各ラベル、各属性値のアノテーション数を出力します。",
            description="各ラベル、各属性値のアノテーション数を、タスクごと/入力データごとに出力します。",
        )
    )
