"""Modalities"""
{% raw -%}
from pydantic import Field, ConfigDict
from typing import Literal, Union
from typing_extensions import Annotated
from aind_data_schema_models.pid_names import BaseName
{% endraw %}

class ModalityModel(BaseName):
    """Base model for modality"""
    # frozen=True asks pydantic to reject attribute assignment on instances,
    # so a modality cannot be mutated after construction.
    model_config = ConfigDict(frozen=True)
    # Both fields are required on the base model; the generated subclasses
    # below narrow each one to a single Literal value with a matching default.
    name: str
    abbreviation: str

{# Emit one ModalityModel subclass per row of `data`.
   The class name is derived from the row's abbreviation through the
   `to_class_name_underscored` filter (presumably registered by the code
   generator; it is not defined in this template). Each subclass pins `name`
   and `abbreviation` to that row's values with a Literal type and an equal
   default, so the class can be instantiated with no arguments.
   This comment is closed with a whitespace-stripping marker so the rendered
   output is unchanged. -#}
{% for _, row in data.iterrows() %}
class {{ row['abbreviation'] | to_class_name_underscored }}(ModalityModel):
    """Model {{row['abbreviation']}}"""
    name: Literal["{{ row['name'] }}"] = "{{ row['name'] }}"
    abbreviation: Literal["{{ row['abbreviation'] }}"] = "{{ row['abbreviation'] }}"

{% endfor %}
class Modality:
    """Namespace exposing every modality plus lookup helpers.

    One class attribute is generated per row of the source data; each holds a
    default-constructed instance of that row's model class.
    """
{% for _, row in data.iterrows() %}
    {{ row['abbreviation'] | to_class_name | upper }} = {{ row['abbreviation'] | to_class_name_underscored }}()
{%- endfor %}

    # Snapshot of the direct subclasses of ModalityModel that exist when this
    # class body runs. These are the model classes, not instances.
    ALL = tuple(ModalityModel.__subclasses__())

    # Union over the classes in ALL, discriminated on the "abbreviation" field.
    # Built from ALL so both attributes share a single snapshot.
    ONE_OF = Annotated[Union[ALL], Field(discriminator="abbreviation")]

    # Maps abbreviation -> model instance. Each class is instantiated once and
    # that same instance supplies both the key and the value.
    abbreviation_map = {model.abbreviation: model for model in (model_cls() for model_cls in ALL)}

    @classmethod
    def from_abbreviation(cls, abbreviation: str) -> Union[ModalityModel, None]:
        """Get modality from abbreviation.

        The abbreviation is matched exactly against the keys of
        ``abbreviation_map``. Returns the matching modality instance, or
        ``None`` when no modality has that abbreviation.
        """
        return cls.abbreviation_map.get(abbreviation)
