"""Organizations"""
{% raw -%}
from pydantic import BaseModel, Field, ConfigDict
from typing import Literal, Union, Optional
from typing_extensions import Annotated
from aind_data_schema_models.registries import Registry
{% endraw %}

class OrganizationModel(BaseModel):
    """Base model for organizations"""

    # Generated organization instances are immutable.
    model_config = ConfigDict(frozen=True)

    # Full organization name; generated subclasses pin it to a Literal and use
    # it as the discriminator of Organization.ONE_OF.
    name: str
    # Short name; generated subclasses may narrow this to Literal[None].
    abbreviation: str
    # Registry the organization is listed in, when known.
    registry: Optional[Registry] = None
    # The organization's identifier inside ``registry``, when known.
    registry_identifier: Optional[str] = None

{# Render one OrganizationModel subclass per row of the `data` frame. -#}
{% for _, row in data.iterrows() %}
class {{ row['name'] | to_class_name_underscored }}(OrganizationModel):
    """Model {{row['name']}}"""
    name: Literal["{{ row['name'] }}"] = "{{ row['name'] }}"
    {% set registry_id = row['registry_identifier'] -%}
    {% set registry_abbr = row['registry_abbreviation'] -%}
    {% set abbr = row['abbreviation'] -%}
    {# `x == x` is false only for NaN, which presumably marks an empty cell in the pandas `data` frame. -#}
    {% if abbr == abbr -%}
    abbreviation: Literal["{{ abbr }}"] = "{{ abbr }}"
    {%- else -%}
    abbreviation: Literal[None] = None
    {%- endif %}
    registry: Optional[Registry] = Field(default={% if registry_abbr == registry_abbr %}Registry.{{ registry_abbr }}{% else %}None{% endif %})
    registry_identifier: Optional[str] = Field(default={% if registry_id == registry_id %}"{{ registry_id }}"{% else %}None{% endif %})

{% endfor %}
class Organization:
    """Namespace holding one instance of every generated organization model.

    Each constant is named after the organization's abbreviation (or its full
    name when the abbreviation is missing), passed through ``to_class_name``
    and upper-cased. ``ALL`` holds the OrganizationModel subclasses,
    ``ONE_OF`` is a union over them discriminated on ``name``, and
    ``from_abbreviation`` / ``from_name`` look instances up.
    """
{% for _, row in data.iterrows() -%}
    {% set abbreviation = row['abbreviation'] %}
    {{ (abbreviation if abbreviation == abbreviation else row['name']) | to_class_name | upper }} = {{ row['name'] | to_class_name_underscored }}()
{%- endfor %}

    ALL = tuple(OrganizationModel.__subclasses__())

    ONE_OF = Annotated[Union[{% for _, row in data.iterrows() %}{{ row['name'] | to_class_name_underscored }}{{ ", " if not loop.last else "" }}{% endfor %}], Field(discriminator="name")]

    # Build one instance per model class and reuse it for both key and value.
    # The inner generator's iterable (ALL) is evaluated in class scope, which
    # is why it is written as the outermost iterable of each comprehension.
    # Organizations without an abbreviation are left out of abbreviation_map.
    abbreviation_map = {
        org.abbreviation: org
        for org in (model() for model in ALL)
        if org.abbreviation is not None
    }

    name_map = {org.name: org for org in (model() for model in ALL)}

    @classmethod
    def from_abbreviation(cls, abbreviation: str) -> Optional[OrganizationModel]:
        """Get organization from abbreviation.

        Returns None when no organization has that abbreviation.
        """
        return cls.abbreviation_map.get(abbreviation)

    @classmethod
    def from_name(cls, name: str) -> Optional[OrganizationModel]:
        """Get organization from name.

        Returns None when no organization has that name.
        """
        return cls.name_map.get(name)
