# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/05_dataclasses.ipynb.

# %% auto 0
__all__ = ['T', 'Accelerator', 'Note', 'Header', 'EnvDep', 'EnvDeps', 'BaseCategory', 'SubCategory', 'BaseCategories',
           'SubCategories', 'Category', 'Categories', 'EnvSpec']

# %% ../nbs/05_dataclasses.ipynb 3
import os, re, yaml, math, textwrap
from typing import List, Optional, Set, Union, Dict, Tuple, TypeVar, Type, Iterable
from collections import OrderedDict

from dataclasses import dataclass, field
from rich.repr import auto as rich_auto, Result as RichReprResult

from rich.prompt import Prompt
from rich.pretty import pprint, Pretty
from rich.panel import Panel
from rich import print
from rich.tree import Tree
from rich.table import Table

from questionary import Choice, checkbox, Style, select

# %% ../nbs/05_dataclasses.ipynb 4
from litenv.constants import (
    DEP_SEP, HEADER_LINE, HEADER_SEP, HEADER_LEN,
    NAME, NOTE, LEVEL, ALWAYS, PYTHON, CHECKED,
    VERSION, CHANNEL, CHANNELS, PIP_ONLY, 
    CATEGORY, PYPI_NAME, CUDA_ONLY, CATEGORIES,
    DEPENDENCIES, SUBCATEGORIES, NO_ARM_SUPPORT,
    DEFAULT_STYLE, DEFAULT_SPEC_FILE, ACCELERATORS,
    FUCHSIA, BLUE_NIGHT, GOLDEN_SHINE, GREEN_LEAF, RED_VELVET,
    ACCELERATORS, DEFAULT_ACCELERATORS
)

from litenv.themes import (
    THEMES, 
)

from .rich import console, cprint, swrap, clist, c_strs

# %% ../nbs/05_dataclasses.ipynb 5
@dataclass
@rich_auto(angular=True)
class Accelerator:
    """Accelerator selector, compared case-insensitively.

    ``is_cpu`` matches ``'cpu'``, ``is_gpu`` matches ``'cuda'`` and
    ``is_mps`` matches ``'mps'``.
    """
    accelerator: str

    @property
    def _accel(self) -> str:
        """The accelerator name, lower-cased for comparison."""
        return self.accelerator.lower()

    def _matches(self, wanted: str) -> bool:
        """True when the lower-cased accelerator name equals ``wanted``."""
        return wanted == self._accel

    @property
    def is_cpu(self) -> bool:
        return self._matches('cpu')

    @property
    def is_mps(self) -> bool:
        return self._matches('mps')

    @property
    def is_gpu(self) -> bool:
        return self._matches('cuda')

    def get_bools(self) -> List[bool]:
        """Return the flags in the fixed order ``[is_cpu, is_gpu, is_mps]``."""
        return [getattr(self, flag) for flag in ('is_cpu', 'is_gpu', 'is_mps')]


# %% ../nbs/05_dataclasses.ipynb 6
@dataclass
@rich_auto(angular=True)
class Note:
    """Free-text note rendered as ``# NOTE:`` comment lines.

    The first wrapped line is prefixed with ``_prefix`` and continuation lines
    with ``_second`` so the text lines up under the first line.
    """
    note: str
    _width: int = HEADER_LEN
    _tab_width: int = 8
    _prefix: str = '# NOTE: '
    _second: str = '#       '

    @property
    def width(self) -> int:
        """Wrap width for the note text: ``_width - _tab_width`` clamped to ``[0, 10_000]``."""
        diff = self._width - self._tab_width
        # Use an int upper bound so the property always returns an int (1e4 is a float).
        return min(10_000, max(diff, 0))

    @property
    def repr_width(self) -> int:
        """Width used when truncating the note for compact (rich repr) display."""
        return self.width // 4

    def wrap(self, **kwargs) -> List[str]:
        """Wrap the note into comment lines.

        ``kwargs`` are forwarded to :class:`textwrap.TextWrapper`. A missing or
        falsy ``width`` falls back to :attr:`width` (at least 1, since
        ``TextWrapper`` rejects widths < 1). ``break_long_words`` defaults to
        True when not given.

        Returns:
            The wrapped lines, the first prefixed with ``_prefix`` and the rest
            with ``_second``. Empty list when the note has no text.
        """
        if not kwargs.get('width'):
            # Assign (not setdefault): an explicit width=None/0 must also be replaced.
            kwargs['width'] = max(self.width, 1)
        kwargs.setdefault('break_long_words', True)

        lines = textwrap.TextWrapper(**kwargs).wrap(self.note)
        return [
            f'{self._prefix if i == 0 else self._second}{line}'
            for i, line in enumerate(lines)
        ]

    def truncate(self, **kwargs) -> str:
        """Return the note text shortened for compact display.

        ``width`` defaults to :attr:`repr_width` (at least 1) and
        ``break_long_words`` to False. Other ``kwargs`` go to
        :class:`textwrap.TextWrapper`.

        Returns:
            The full note when it already fits in ``width``. Otherwise the
            first wrapped line, without the ``# NOTE:`` prefix, clipped to
            ``width`` characters.
        """
        if not kwargs.get('width'):
            kwargs['width'] = max(self.repr_width, 1)
        kwargs.setdefault('break_long_words', False)

        width = kwargs['width']
        if len(self.note) <= width:
            # Return the text itself (previously returned the Note object).
            return self.note

        # Wrap the raw text so the prefix does not eat into the visible width.
        lines = textwrap.TextWrapper(**kwargs).wrap(self.note)
        # With break_long_words=False a single word may exceed width, hence the slice.
        return lines[0][:width] if lines else ''

    def __rich_repr__(self) -> RichReprResult:
        yield self.truncate()

# %% ../nbs/05_dataclasses.ipynb 7
@dataclass
@rich_auto(angular=True)
class Header:
    """Section header for dependency lists: a ``# name`` line plus a separator rule.

    Level 0 headers are flush-left. Level 1 headers are indented by a quarter
    of ``max_len`` (rounded up), with the rule shortened by that indent. Other
    levels add nothing. The pip list uses the shorter ``pip_len`` rule.
    """
    name: str

    max_len: int = HEADER_LEN
    pip_len: int = HEADER_LEN - 2
    tab_len: int = 4
    _prefix: str = '# '
    sepchar: str = HEADER_SEP

    @property
    def _4th_len(self) -> int:
        """``max_len / 4`` rounded up; the indent width for level-1 headers."""
        return math.ceil(self.max_len / 4)

    @property
    def _max_line(self) -> str:
        """Full-length separator rule for the conda list."""
        return f'{self._prefix}{self.sepchar * self.max_len}'

    @property
    def _pip_line(self) -> str:
        """Separator rule sized for the (more indented) pip list."""
        return f'{self._prefix}{self.sepchar * self.pip_len}'

    @property
    def _str(self) -> str:
        """The ``# name`` title line."""
        return f'{self._prefix}{self.name}'

    def add_to_list(self, lst:List[str], level:Optional[int]=0, is_pip:Optional[bool]=False) -> List[str]:
        """Append this header's title and rule lines to ``lst`` and return ``lst``."""
        if level == 0:
            rule = self._pip_line if is_pip else self._max_line
            lst.extend([self._str, rule])
        elif level == 1:
            indent = ' ' * self._4th_len
            rule_len = (self.pip_len if is_pip else self.max_len) - len(indent)
            lst.extend([
                f'{indent}{self._str}',
                f'{indent}{self._prefix}{self.sepchar * rule_len}',
            ])
        # Any other level contributes nothing.
        return lst

    def place_in_dependency_list(
        self, conda_deps: List[str], pip_deps: List[str], 
        level: Optional[int] = 0, n_conda: Optional[int] = 0, n_pip: Optional[int] = 0
    ) -> Tuple[List[str], List[str]]:
        """Add the header to each list that will receive at least one dependency."""
        conda_deps = self.add_to_list(conda_deps, level, is_pip=False) if n_conda > 0 else conda_deps
        pip_deps = self.add_to_list(pip_deps, level, is_pip=True) if n_pip > 0 else pip_deps
        return conda_deps, pip_deps
    

# %% ../nbs/05_dataclasses.ipynb 8
@dataclass
@rich_auto(angular=True)
class EnvDep:
    """A single environment dependency from the spec file.

    A dependency renders as a conda spec (``[channel<DEP_SEP>]name[version]``)
    or a pip spec (``pkg_name[version]``), preceded by its ``# NOTE:`` comment
    lines when a note is set. ``version`` is appended verbatim, so it carries
    its own operator.
    """
    name: str
    note: Optional[Union[str, Note]] = None
    always: Optional[bool] = False
    checked: Optional[bool] = False
    version: Optional[str] = None
    channel: Optional[str] = None
    pip_only: Optional[bool] = False
    pypi_name: Optional[str] = None
    cuda_only: Optional[bool] = False
    no_arm_support: Optional[bool] = False
    
    category: Optional[str] = None
    subcategory: Optional[str] = None

    @property
    def pkg_name(self) -> str:
        """Name used for pip: ``pypi_name`` when set, otherwise ``name``."""
        return self.pypi_name or self.name

    @property
    def channel_str(self) -> str:
        """``channel`` followed by ``DEP_SEP``, or ``''`` when no channel is set."""
        return f'{self.channel}{DEP_SEP}' if self.channel else ''

    @property
    def version_str(self) -> str:
        """``version`` as a string, or ``''`` when unset."""
        return f'{self.version}' if self.version else ''

    @property
    def pip_str(self) -> str:
        """Pip requirement string: ``pkg_name`` + version."""
        return f'{self.pkg_name}{self.version_str}'

    @property
    def conda_str(self) -> str:
        """Conda spec string: optional channel prefix + ``name`` + version."""
        return f'{self.channel_str}{self.name}{self.version_str}'

    def add_note_to_list(self, lst: List[str]) -> List[str]:
        """Append this dependency's wrapped ``# NOTE:`` lines (if any) to ``lst``."""
        if self.note:
            # `note` may already be a Note; previously that case raised
            # UnboundLocalError because `note` was only bound for strings.
            note = self.note if isinstance(self.note, Note) else Note(str(self.note))
            lst.extend(note.wrap())
        return lst

    def add_to_list(
        self, lst: List[str], accel:Accelerator, pip:Optional[bool]=False, conda:Optional[bool]=False) -> List[str]:
        """Append the note lines and then the pip (``pip=True``) or conda spec to ``lst``.

        ``accel`` and ``conda`` are not used here; they are kept for interface
        compatibility.
        """
        pkg_str = self.pip_str if pip else self.conda_str
        self.add_note_to_list(lst)
        lst.append(pkg_str)
        return lst

    def add_to_pip_list(self, lst: List[str], accel:Accelerator) -> List[str]:
        return self.add_to_list(lst, accel, pip=True)

    def add_to_conda_list(self, lst: List[str], accel:Accelerator) -> List[str]:
        return self.add_to_list(lst, accel, conda=True)

    def is_pip(self, accel:Accelerator) -> bool:
        """Whether this dependency goes to the pip list for ``accel``.

        True when ``pip_only`` is set, or when targeting MPS and the package is
        flagged ``no_arm_support``.
        """
        return bool(self.pip_only or (accel.is_mps and self.no_arm_support))

    def adjust_for_accelerator(self, accel:Accelerator) -> 'EnvDep':
        """Mutate in place to a pip-only dep (channel cleared) when :meth:`is_pip`; return self."""
        if self.is_pip(accel):
            self.channel = None
            self.pip_only = True
        return self

    def place_in_dependency_list(
        self, conda_deps: List[str], pip_deps: List[str], accel:Accelerator
    ) -> Tuple[List[str], List[str]]:
        """Append this dependency to the conda or pip list; skip ``cuda_only`` deps off-GPU."""
        if self.cuda_only and not accel.is_gpu:
            return conda_deps, pip_deps

        if self.is_pip(accel):
            self.add_to_pip_list(pip_deps, accel)
        else:
            self.add_to_conda_list(conda_deps, accel)
        return conda_deps, pip_deps
        

# %% ../nbs/05_dataclasses.ipynb 9
@dataclass
@rich_auto(angular=True)
class EnvDeps:
    """Dict-like collection of dependency key -> :class:`EnvDep`.

    Raw specs (dicts, or None for an empty entry) passed in ``dependencies``
    are converted to :class:`EnvDep` using the key as the default ``name``.
    """
    dependencies: Dict[str, EnvDep] = field(default_factory=dict)

    def __post_init__(self):
        if self.dependencies is None:
            # e.g. a `dependencies:` key with no value in the YAML spec.
            self.dependencies = {}
        if all(isinstance(dep, EnvDep) for dep in self.dependencies.values()):
            return
        # Convert per item so a mix of EnvDep instances and raw dicts is accepted.
        self.dependencies = {
            k: v if isinstance(v, EnvDep) else EnvDep(**{'name': k, **(v or {})})
            for k, v in self.dependencies.items()
        }

    def __getitem__(self, key:str) -> EnvDep:
        return self.dependencies[key]
    
    def __setitem__(self, key:str, value:EnvDep) -> None:
        self.dependencies[key] = value
    
    def __delitem__(self, key:str) -> None:
        del self.dependencies[key]
    
    def __iter__(self):
        return iter(self.dependencies)

    def __len__(self) -> int:
        return len(self.dependencies)

    def keys(self) -> Iterable[str]:
        return self.dependencies.keys()

    def values(self) -> Iterable[EnvDep]:
        return self.dependencies.values()

    def items(self) -> Iterable[Tuple[str, EnvDep]]:
        return self.dependencies.items()

    def n_conda(self, accel:Accelerator) -> int:
        """Number of dependencies not routed to pip for ``accel`` (``cuda_only`` is not considered)."""
        return sum(1 for dep in self.dependencies.values() if not dep.is_pip(accel))

    def n_pip(self, accel:Accelerator) -> int:
        """Number of dependencies routed to pip for ``accel`` (``cuda_only`` is not considered)."""
        return sum(1 for dep in self.dependencies.values() if dep.is_pip(accel))

    def calc_n_conda_n_pip(self, accel:Accelerator) -> Tuple[int, int]:
        return self.n_conda(accel), self.n_pip(accel)

    def determine_needed_channels(
        self, accel: Accelerator, possible_channels: Optional[List[str]] = None
    ) -> List[str]:
        """Return the conda channels needed by these dependencies for ``accel``.

        Channels listed in ``possible_channels`` are ordered by their position
        there. Any other channels follow in first-seen order, so the result is
        deterministic. Each channel appears once.
        """
        possible_channels = list(possible_channels or [])
        preferred_order = {chan: i for i, chan in enumerate(possible_channels)}
        fallback_rank = len(possible_channels)
        gpu, mps = accel.is_gpu, accel.is_mps
        gpu_markers = ('cuda', 'cudnn', 'nccl', 'gpu')

        channels: List[str] = []
        if gpu and 'nvidia' in possible_channels:
            channels.append('nvidia')

        for dep in self.dependencies.values():
            # Dependency only needed for GPUs and this environment is not for GPU
            if dep.cuda_only and not gpu:
                continue
            # Dependency does not support ARM (e.g. Apple Silicon) and this environment is for MPS
            if dep.no_arm_support and mps:
                continue
            # Dependency will be installed from pip, so it needs no conda channel
            if dep.is_pip(accel):
                continue

            if dep.channel:
                channels.append(dep.channel)

            # A dependency whose name is itself a known channel pulls in that channel.
            if dep.name in preferred_order:
                channels.append(dep.name)

            # NOTE(review): this adds *every* possible channel whenever a GPU-related
            # package is present (kept from the original logic) — confirm this is intended.
            if gpu and any(marker in dep.name for marker in gpu_markers):
                channels.extend(possible_channels)

        # dict.fromkeys dedupes while keeping insertion order; the stable sort then
        # keeps ties (unknown channels) in first-seen order instead of set/hash order.
        unique = list(dict.fromkeys(channels))
        unique.sort(key=lambda c: preferred_order.get(c, fallback_rank))
        return unique

    def get_names(self) -> List[str]:
        """Pip-facing names (``pkg_name``) of all dependencies, in insertion order."""
        return [dep.pkg_name for dep in self.dependencies.values()]
        

# %% ../nbs/05_dataclasses.ipynb 10
@dataclass
@rich_auto(angular=True)
class BaseCategory:
    """A named group of dependencies from the spec file.

    Built from a key plus the raw spec mapping. Keys other than NAME, CHECKED,
    DEPENDENCIES and LEVEL are ignored.
    """
    key: str
    name: Optional[str] = None
    checked: Optional[bool] = False
    dependencies: Optional[EnvDeps] = field(default_factory=EnvDeps)
    level: Optional[int] = 0

    def __init__(self, key, **kwargs):
        # Here we ignore any kwargs that we don't know what to do with
        self.key = key
        self.name = kwargs.get(NAME, None)
        self.checked = kwargs.get(CHECKED, False)
        # `or {}` also covers a DEPENDENCIES key present with a null value.
        self.dependencies = EnvDeps(dependencies=kwargs.get(DEPENDENCIES) or {})
        self.level = kwargs.get(LEVEL, 0)

    @property
    def header(self) -> Header:
        """Header for this category, titled with :attr:`title` (never ``None``)."""
        return Header(self.title)
    
    @property
    def title(self) -> str:
        """Display name: ``name`` when set, otherwise ``key``."""
        return self.name or self.key

    def _deps_to_place(self, accel: Accelerator, only_always: bool) -> List[EnvDep]:
        """Dependencies that :meth:`place_in_dependency_list` will actually emit for ``accel``."""
        gpu = accel.is_gpu
        return [
            dep for dep in self.dependencies.values()
            # Mirrors EnvDep.place_in_dependency_list, which skips cuda_only deps off-GPU.
            if (dep.always or not only_always) and (gpu or not dep.cuda_only)
        ]

    def place_in_dependency_list(
        self, conda_deps: List[str], pip_deps: List[str], accel:Accelerator, only_always:bool=False
    ) -> Tuple[List[str], List[str]]:
        """Append this category's header and dependencies to the conda/pip lists.

        With ``only_always`` only dependencies flagged ``always`` are placed.
        A header goes into a list only when at least one dependency will follow
        it there, so filtered-out categories no longer leave empty headers.
        """
        deps = self._deps_to_place(accel, only_always)
        n_pip = sum(1 for dep in deps if dep.is_pip(accel))
        n_conda = len(deps) - n_pip

        conda_deps, pip_deps = self.header.place_in_dependency_list(
            conda_deps, pip_deps, self.level, n_conda, n_pip
        )
        for dep in deps:
            conda_deps, pip_deps = dep.place_in_dependency_list(conda_deps, pip_deps, accel)
        return conda_deps, pip_deps

    def to_choice(self) -> Choice:
        """Build a questionary Choice titled ``"<title> (<dep1>, <dep2>, ...)"`` with ``value=key``."""
        deps = ', '.join(map(str, self.dependencies.get_names()))
        return Choice(f'{self.title} ({deps})', checked=self.checked, value=self.key)

# %% ../nbs/05_dataclasses.ipynb 11
@dataclass
@rich_auto(angular=True)
class SubCategory(BaseCategory):
    """A nested category whose header defaults to level 1 (indented) unless LEVEL is given."""
    level: Optional[int] = 1

    # An explicit __init__ is needed because spec entries carry extra keys, e.g.
    # "TypeError: SubCategory.__init__() got an unexpected keyword argument 'packaging'".
    def __init__(self, key, **kwargs):
        # BaseCategory.__init__ falls back to level 0, which silently overrode the
        # level-1 default declared above; supply this class's default explicitly.
        kwargs.setdefault(LEVEL, 1)
        super().__init__(key, **kwargs)
        

# %% ../nbs/05_dataclasses.ipynb 12
# Element type of a BaseCategories collection (BaseCategory or a subclass).
T = TypeVar("T", bound=BaseCategory)

@dataclass
@rich_auto(angular=True)
class BaseCategories:
    """Dict-like collection of key -> category built from raw spec mappings.

    Each ``kwargs`` entry becomes ``category_type(key, **spec)``.
    """
    categories: Dict[str, T] = field(default_factory=dict)
    category_type: Type[T] = BaseCategory

    def __init__(self, category_type: Type[T], **kwargs):
        self.category_type = category_type
        # `v or {}` tolerates entries with a null body (e.g. a bare `key:` in YAML).
        self.categories = {k: category_type(k, **(v or {})) for k, v in kwargs.items()}

    def __getitem__(self, key:str) -> T:
        return self.categories[key]

    def __setitem__(self, key:str, value:T) -> None:
        self.categories[key] = value

    def __delitem__(self, key:str) -> None:
        del self.categories[key]

    def __iter__(self):
        return iter(self.categories)

    def __len__(self) -> int:
        return len(self.categories)

    def keys(self) -> Iterable[str]:
        return self.categories.keys()

    def values(self) -> Iterable[T]:
        return self.categories.values()

    def items(self) -> Iterable[Tuple[str, T]]:
        return self.categories.items()

    def place_in_dependency_list(
        self, conda_deps: List[str], pip_deps: List[str], accel:Accelerator, 
        include: Optional[List[str]] = None, include_always: Optional[bool] = False
    ) -> Tuple[List[str], List[str]]:
        """Place the categories whose keys are in ``include`` (in collection order).

        With ``include_always``, every other category then contributes only
        its ``always`` dependencies.
        """
        include = include or []
        for key, category in self.categories.items():
            if key in include:
                conda_deps, pip_deps = category.place_in_dependency_list(conda_deps, pip_deps, accel)
        if not include_always:
            return conda_deps, pip_deps

        for key, category in self.categories.items():
            if key not in include:
                conda_deps, pip_deps = category.place_in_dependency_list(
                    conda_deps, pip_deps, accel, only_always=True
                )
        return conda_deps, pip_deps

    def make_choices(self) -> List[Choice]:
        return [cat.to_choice() for cat in self.categories.values()]

    def get_selected_dependencies(self, selected: Optional[List[str]] = None) -> Dict[str, EnvDep]:
        """Merge the dependencies of the categories whose keys are in ``selected``."""
        selected = selected or []
        deps: Dict[str, EnvDep] = {}
        for key, category in self.categories.items():
            if key in selected:
                deps.update(category.dependencies.items())
        return deps

    # The earlier generator-based get_all_dependencies was shadowed by this
    # definition (later class attributes win) and has been removed as dead code.
    def get_all_dependencies(self) -> Dict[str, EnvDep]:
        """Merge the dependencies of every category; later keys overwrite earlier ones."""
        all_deps: Dict[str, EnvDep] = {}
        for category in self.categories.values():
            all_deps.update(category.dependencies.items())
        return all_deps

# %% ../nbs/05_dataclasses.ipynb 13
@dataclass
@rich_auto(angular=True)
class SubCategories(BaseCategories):
    """Collection of :class:`SubCategory` objects keyed by subcategory key."""

    def __init__(self, *args, **kwargs):
        # Positional args are accepted but ignored; entries always become SubCategory.
        super().__init__(SubCategory, **kwargs)

    def get_subcategory(self, subkey:str) -> dict:
        """Return the subcategory stored under ``subkey``, or ``{}`` when there is none."""
        return self.categories[subkey] if subkey in self.categories else {}

# %% ../nbs/05_dataclasses.ipynb 14
@dataclass
@rich_auto(angular=True)
class Category(BaseCategory):
    """Top-level category: its own dependencies plus optional :class:`SubCategories`."""
    subcategories: Optional[SubCategories] = field(default_factory=SubCategories)

    def __init__(self, key, **kwargs):
        super().__init__(key, **kwargs)
        # `or {}` also covers a SUBCATEGORIES key present with a null value.
        self.subcategories = SubCategories(**(kwargs.pop(SUBCATEGORIES, None) or {}))

    def add_to_dependency_lists(
        self, conda_deps: List[str], pip_deps: List[str], accel:Accelerator,
        subcategories_to_include: Optional[List[str]] = None, include_always: Optional[bool] = False
    ) -> Tuple[List[str], List[str]]:
        """Place all of this category's dependencies, then its chosen subcategories.

        ``include_always`` only affects the *unselected* subcategories (they
        then contribute their ``always`` deps). It was previously passed
        positionally as ``only_always``, which would have dropped this
        category's non-``always`` deps even though the category is selected.
        """
        conda_deps, pip_deps = self.place_in_dependency_list(conda_deps, pip_deps, accel)
        conda_deps, pip_deps = self.subcategories.place_in_dependency_list(
            conda_deps, pip_deps, accel, subcategories_to_include or [], include_always
        )
        return conda_deps, pip_deps

    def get_subcategory(self, subkey:str) -> dict:
        """Return subcategory ``subkey``, or ``{}`` when absent."""
        # SubCategories.get_subcategory takes only the key; passing a default raised TypeError.
        return self.subcategories.get_subcategory(subkey)

    def get_selected_dependencies(self, selected: Optional[List[str]] = None) -> Dict[str, EnvDep]:
        """This category's own deps merged with those of the subcategories in ``selected``."""
        deps = {**self.dependencies.dependencies}
        deps.update(self.subcategories.get_selected_dependencies(selected or []))
        return deps
    

# %% ../nbs/05_dataclasses.ipynb 15
@dataclass
@rich_auto(angular=True)
class Categories(BaseCategories):
    def __init__(self, *args, **kwargs):
        super().__init__(Category, **kwargs)

    def add_to_dependency_lists(
        self, conda_deps: List[str], pip_deps: List[str], accel:Accelerator,
        include_specification: Dict[str, List[str]] = {}, include_always: Optional[bool] = False
    ) -> Tuple[List[str], List[str]]:        
        for key, category in self.categories.items():
            if key in include_specification:        
                subcategories_to_include = include_specification[key] or []      
                conda_deps, pip_deps = category.add_to_dependency_lists(
                    conda_deps, pip_deps, accel, subcategories_to_include, include_always
                )
        return conda_deps, pip_deps

    def get_category(self, catkey:str) -> dict:
        return self.categories[catkey]

    def get_subcategories(self, catkey:str) -> dict:
        return self.get_category(catkey).subcategories
    
    def get_subcategory(self, catkey:str, subkey:str) -> dict:
        return self.get_subcategories(catkey).get_subcategory(subkey)

    
    def get_all_dependencies(self) -> Dict[str, EnvDep]:
        all_deps = super().get_all_dependencies()
        for category in self.categories.values():
            all_deps.update(category.subcategories.get_all_dependencies())
        return all_deps

    def get_selected_dependencies(self, selected: Dict[str, List[str]] = {}, ) -> Dict[str, EnvDep]:
        deps = {}
        for key, category in self.categories.items():
            if key in selected:
                subselect = selected[key]       
                cat_deps = category.get_selected_dependencies(subselect)
                deps.update(cat_deps)                
        return deps

# %% ../nbs/05_dataclasses.ipynb 16
@dataclass
@rich_auto(angular=True)
class EnvSpec:
    """Environment specification loaded from a YAML spec file.

    Holds the parsed python version, categories, channels and accelerators.
    Offers questionary prompts to pick an accelerator, categories and
    subcategories, and turns a selection into an environment dict / YAML string.
    """
    filename: str = DEFAULT_SPEC_FILE
    theme: str = BLUE_NIGHT
    accelerators: List[str] = field(default_factory=list)

    def __post_init__(self):
        with open(self.filename, 'r', encoding='utf-8') as file:
            # safe_load returns None for an empty document; treat that as an empty spec.
            self.data = yaml.safe_load(file) or {}

        self.python_version = self.data.get(PYTHON, None)
        # `or` fallbacks also cover keys that are present but null.
        self.categories = Categories(**(self.data.get(CATEGORIES) or {}))
        self.channels = self.data.get(CHANNELS) or []
        # Accelerators passed to the constructor take precedence (they were previously
        # overwritten unconditionally); then the spec file; then the defaults.
        self.accelerators = (
            self.accelerators or self.data.get(ACCELERATORS) or DEFAULT_ACCELERATORS
        )

    def get_style(self):
        """Return the ``checkbox`` style of the configured theme, or None if the theme is unknown."""
        style = None
        if self.theme in THEMES:
            theme = THEMES[self.theme]
            style = theme.checkbox
        return style

    def select_accelerator(self) -> Optional[str]:
        """Prompt for one of ``self.accelerators`` (first one preselected).

        Returns whatever questionary's ``ask()`` returns (None if cancelled).
        """
        style = self.get_style()
        accel = select(
            'Select accelerator:',
            choices=self.accelerators, style=style,
            default=self.accelerators[0] if self.accelerators else None,
            use_jk_keys=False, use_shortcuts=False, 
            use_arrow_keys=True, show_selected=True
        ).ask()
        return accel

    def select_categories(self) -> Optional[List[str]]:
        """Prompt for categories; returns the selected keys, or None if the prompt was cancelled."""
        style = self.get_style()
        choices = self.categories.make_choices()
        return checkbox(
            'Select categories:',
            choices=choices, style=style,
            use_jk_keys=False, use_arrow_keys=True,
        ).ask()

    def select_subcategories(self, catkey:str) -> Optional[List[str]]:
        """Prompt for subcategories of ``catkey``; ``[]`` when it has none, None if cancelled."""
        style = self.get_style()

        subchoices = self.categories.get_subcategories(catkey).make_choices()
        if len(subchoices) == 0:
            return []
        
        selected_subkeys = checkbox(
            f'Select additional subcategories for {catkey} (use space to select multiple):',
            choices=subchoices, style=style,
            use_jk_keys=False, use_arrow_keys=True,
        ).ask()
        return selected_subkeys

    def to_dict(
        self, 
        name:str,
        selected:Dict[str, List[str]],
        accelerator:Optional[str]='cpu',
    ) -> dict:
        """Build ``{'name', 'channels', 'dependencies'}`` for the given selection.

        ``selected`` maps category key -> subcategory keys. Pip dependencies,
        if any, are appended as ``'pip'`` followed by ``{'pip': [...]}``.
        """
        accel = Accelerator(accelerator)
        # Omit the version suffix when the spec does not pin python (was "pythonNone").
        py_version = '' if self.python_version is None else self.python_version
        cnd_deps = [f'{PYTHON}{py_version}']
        pip_deps = []

        self.categories.add_to_dependency_lists(cnd_deps, pip_deps, accel, selected)

        if pip_deps:
            cnd_deps.append('pip')
            cnd_deps.append(dict(pip=pip_deps))

        selected_deps = EnvDeps(dependencies=self.categories.get_selected_dependencies(selected))
        channels = selected_deps.determine_needed_channels(accel, possible_channels=self.channels)

        return dict(name=name, channels=channels, dependencies=cnd_deps)

    def dict_to_str(self, env_dict:dict) -> str:
        """Dump ``env_dict`` to YAML and tidy it for humans.

        Lines containing '#' (header/note entries dumped as quoted list items)
        have their quotes and ``- `` markers removed. Blank lines are inserted
        before top-level keys, comment groups and ``- pip``. A blank line
        directly before a NOTE line is removed.
        """
        env_str = yaml.dump(env_dict, width=float('inf'), sort_keys=False)

        env_lines = []
        added_space = False
        for line in env_str.split('\n'):
            prev = env_lines[-1] if env_lines else ''

            if '#' in line:
                prev_not_dep = '::' not in prev and ':' in prev
                # Use `prev` rather than env_lines[-1]: identical when lines exist, and
                # no IndexError when the very first line contains '#'.
                if not added_space and prev != '' and (not prev_not_dep and '#' not in prev):
                    env_lines.append('')
                    added_space = True

                line = line.replace("'", '')
                line = line.replace('- ', '')
            else:
                added_space = False

            # space before top level keys
            if '::' not in line and ':' in line and env_lines and env_lines[-1] != '':
                if '- pip' not in env_lines[-1]:
                    env_lines.append('')

            if line == '- pip' and not added_space:
                env_lines.append('')

            # Keep NOTE comments attached to what precedes them.
            if 'NOTE' in line and env_lines and env_lines[-1] == '':
                env_lines.pop()
            env_lines.append(line)

        return '\n'.join(env_lines)

    def print_selected(self, selected:Dict[str, List[str]],):
        """Print a rich Tree of the selected categories, their subcategories and dependency names."""
        tree = Tree('Selected Categories')
        for cat, subs in selected.items():
            category = self.categories[cat]
            tcat = tree.add(category.title)
            for dep in category.dependencies.values():
                tcat.add(dep.name)

            for sub in subs:
                subcategory = category.subcategories[sub]
                tsub = tcat.add(subcategory.title)
                for dep in subcategory.dependencies.values():
                    tsub.add(dep.name)
        
        print(tree)
