from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator


class IndicesMode(BaseModel):
    """Choose which probe positions (diffraction-pattern indices) enter the reconstruction.

    ``mode`` is one of ``"full"``, ``"sub"`` or ``"center"``. ``subscan_slow`` and
    ``subscan_fast`` are optional positive counts of scan positions along the slow and
    fast directions; when left as ``None`` this model does not fill them in, so any
    default is chosen by the consuming code. Unknown keys are rejected.
    """

    model_config = dict(extra="forbid")

    mode: Literal["full", "sub", "center"] = Field(
        "full",
        description="Indices mode for selecting probe positions",
    )
    subscan_slow: Union[int, None] = Field(
        None,
        ge=1,
        description="Number of scan positions along slow direction",
    )
    subscan_fast: Union[int, None] = Field(
        None,
        ge=1,
        description="Number of scan positions along fast direction",
    )

class BatchSize(BaseModel):
    """Mini-batch settings for gradient updates.

    ``size`` diffraction patterns form one mini-batch. Gradients from
    ``grad_accumulation`` mini-batches are accumulated before a single update, so the
    effective batch size is ``size * grad_accumulation``. Both values must be >= 1, and
    unknown keys are rejected.
    """

    model_config = dict(extra="forbid")

    size: int = Field(
        32,
        ge=1,
        description="Number of diffraction patterns in a mini-batch",
    )
    grad_accumulation: int = Field(
        1,
        ge=1,
        description="Number of mini-batches for gradient accumulation",
    )

class ResultModes(BaseModel):
    """Select which object/probe output variants are written when results are saved.

    Every field is a non-empty list of options:

    * ``obj_dim`` -- integers in the closed range [2, 4] (2D to 4D object output).
    * ``FOV`` -- ``"full"`` and/or ``"crop"``.
    * ``bit`` -- any of ``"raw"``, ``"32"``, ``"16"``, ``"8"``.

    Unknown keys are rejected.
    """

    model_config = {"extra": "forbid"}

    # `min_length` is the Pydantic v2 spelling; the v1 keyword `min_items` is still
    # accepted but deprecated and emits a DeprecationWarning at class creation.
    obj_dim: List[int] = Field(default=[2, 3, 4], description="Object dimensions to save", min_length=1)
    FOV: List[str] = Field(default=["crop"], description="Field of view options", min_length=1)
    bit: List[str] = Field(default=["8"], description="Bit depth options", min_length=1)

    @field_validator("obj_dim")
    @classmethod
    def validate_obj_dim(cls, v: List[int]) -> List[int]:
        """Ensure every obj_dim value is between 2 and 4 (inclusive).

        Raises:
            ValueError: if any entry is outside [2, 4].
        """
        if not all(2 <= x <= 4 for x in v):
            raise ValueError("obj_dim must contain integers between 2 and 4")
        return v

    @field_validator("FOV")
    @classmethod
    def validate_fov(cls, v: List[str]) -> List[str]:
        """Ensure every FOV value is 'full' or 'crop'.

        Raises:
            ValueError: if any entry is not one of the allowed strings.
        """
        valid = {"full", "crop"}
        if not all(x in valid for x in v):
            raise ValueError("FOV must contain 'full' or 'crop'")
        return v

    @field_validator("bit")
    @classmethod
    def validate_bit(cls, v: List[str]) -> List[str]:
        """Ensure every bit value is 'raw', '32', '16', or '8'.

        Raises:
            ValueError: if any entry is not one of the allowed strings.
        """
        valid = {"raw", "32", "16", "8"}
        if not all(x in valid for x in v):
            raise ValueError("bit must contain 'raw', '32', '16', or '8'")
        return v

class CompilerConfigs(BaseModel):
    """Settings mirroring keyword arguments of ``torch.compile``.

    Compilation is opt-in: ``enable`` defaults to ``False``. The remaining fields use
    the names of ``torch.compile`` parameters (presumably forwarded by the caller --
    confirm at the call site). Unknown keys are rejected.
    """

    model_config = dict(extra="forbid")

    # torch.compile itself exposes the inverted flag `disable` (default False);
    # this model uses `enable` instead.
    enable: bool = Field(False, description="Boolean flag to turn on/off torch.compile")
    fullgraph: bool = False
    dynamic: Union[bool, None] = None
    backend: Literal["inductor", "cudagraphs", "ipex", "onnxrt"] = "inductor"
    mode: Literal[
        "default",
        "reduce-overhead",
        "max-autotune",
        "max-autotune-no-cudagraphs",
    ] = "default"
    options: Union[dict[str, Union[str, int, bool]], None] = None

class ReconParams(BaseModel):
    """
    "recon_params" determines the overall reconstruction behavior including iterations, grouping/batching,
    and saving configurations for both reconstruction and hypertune modes.

    The PtyRAD results are organized into folder structures with 2 (reconstruction) or 3 (hypertune) levels, not including the 'output/'.
    The main (1st) output directory is specified by 'output_dir'; this is usually separated by material systems or projects,
    and presumably you'll have multiple reconstructions for this material system / project.
    Each PtyRAD reconstruction is saved into a "reconstruction folder" that is automatically generated by PtyRAD if 'SAVE_ITERS' is not null.
    For reconstruction mode, the folder structure might look like 'output/<MATERIALS>/<RECONSTRUCTION>'.
    Note that 'recon_dir_affixes', 'prefix_time', 'prefix', and 'postfix' all operate on the reconstruction folder,
    and have no effect if 'SAVE_ITERS' is null because the reconstruction folder would not be generated in the first place.
    The date, prefix, and postfix are automatically joined by '_', and the '/' after 'output_dir' is also added automatically.
    For hypertune mode, a hypertune folder is automatically inserted as the 2nd level between 'output_dir' and the (optional) reconstruction folders,
    for example, 'output_dir/<MATERIALS>/<HYPERTUNE>/<RECONSTRUCTION>'.
    This way the hypertune folders are organized under the <MATERIALS> folder just like other reconstruction folders.
    Note that in hypertune mode, 'prefix_time', 'prefix', and 'postfix' are applied to this hypertune folder instead (i.e., hijacked by it) and have no effect on the reconstruction folder name.
    In other words, 'prefix_time', 'prefix', and 'postfix' are always applied to the folders directly under 'output_dir' for both reconstruction and hypertune modes.
    """

    model_config = {"extra": "forbid"}

    NITER: int = Field(default=200, ge=1, description="Total number of reconstruction iterations")
    """
    Total number of reconstruction iterations.
    1 iteration means a full pass of all selected diffraction patterns.
    Usually 20-50 iterations can get 90% of the work done with a proper learning rate between 1e-4 and 1e-3.
    For faster trials in hypertune mode, set 'NITER' to a smaller number than your typical reconstruction to save time.
    Usually 10-20 iterations are enough for the hypertune parameters to show their relative performance.
    """

    INDICES_MODE: IndicesMode = Field(default_factory=IndicesMode, description="Indices mode configuration")
    """
    Indices mode determines which diffraction patterns (probe positions) are used for the reconstruction.
    Each probe position (or each diffraction pattern) has a unique index;
    by selecting a subset of the indices, we can conveniently change the effective reconstruction area ('center'),
    or the effective scan step size / real space overlap ('sub').
    You may choose 'full' for reconstructing with all probe positions,
    'sub' for subsampling by selecting only every few probe positions with the full FOV
    (i.e., this will increase the effective scan step size and is a good way to test whether we can reduce the real space overlap),
    or 'center' for only using the center rectangular region in real space with an effectively reduced FOV but a full-sized object canvas.
    'subscan_slow' and 'subscan_fast' determine the number of scan positions chosen for 'sub' and 'center', and have no effect on 'full'.
    If 'subscan_slow' and 'subscan_fast' are not provided (or null), they'll be set to half of 'N_scan_slow' and half of 'N_scan_fast' by default.
    Typically we can start from 'INDICES_MODE': {'mode': 'sub', 'subscan_slow': null, 'subscan_fast': null} to get a quick idea of the entire object
    by reconstructing the entire object FOV but using only every other diffraction pattern along the fast and slow scan directions
    (so only 1/4 of the diffraction patterns are used, hence a 4x speedup in the iteration time).
    Similarly, we can use 'center' to keep the effective scan step size but reconstruct a smaller FOV.
    Once 'sub' or 'center' shows reasonable results,
    we can then switch to 'full' to further refine it without starting from scratch, because there's no object dimension mismatch between the 'INDICES_MODE's.
    """

    BATCH_SIZE: BatchSize = Field(default_factory=BatchSize, description="Batch size configuration")
    """
    Batch size is the number of diffraction patterns processed simultaneously to get the gradient update.
    'size' is the number of diffraction patterns in a sub-batch, and 'grad_accumulation' is how many sub-batches' gradients are accumulated before applying the update.
    The effective batch size (for 1 update) is 'size' * 'grad_accumulation'.
    Gradient accumulation is an ML technique that allows a large effective batch size by trading iteration time for memory,
    so if you can fit the entire batch inside your memory, you should always set 'grad_accumulation': 1 for performance.
    "Batch size" is commonly used in the machine learning community, while it's called "grouping" in PtychoShelves.
    Batch size has an effect on both convergence speed and final quality; usually a smaller batch size leads to better final quality for iterative gradient descent,
    but a smaller batch size also leads to longer computation time per iteration because the GPU isn't as utilized as with large batch sizes (due to less GPU parallelism).
    On the other hand, a large batch size is known to be more robust (noise-resilient) but converges slower.
    Generally a batch size of 32 to 128 is used, although certain algorithms (like DM) prefer a large batch size that is equal to the dataset size for robustness.
    For an extremely large object (or one with a lot of object modes), you'll need to reduce the batch size to save GPU memory,
    or use `grad_accumulation` to split a batch into multiple sub-batches for 1 update.
    """

    GROUP_MODE: Literal["random", "sparse", "compact"] = Field(
        default="random", description="Spatial distribution of probe positions in a batch"
    )
    """
    Group mode determines the spatial distribution of the selected probe positions within a batch (group);
    this is similar to 'MLs' for 'sparse' and 'MLc' for 'compact' in PtychoShelves.
    Available options are 'random', 'sparse', and 'compact'.
    Usually 'random' is good enough with small batch sizes and is the suggested option for most cases.
    'compact' is believed to provide the best final quality, although it converges much slower.
    'sparse' gives the most uniform coverage on the object so it converges the fastest,
    and is also preferred for reconstructions with few scan positions to prevent any locally biased update.
    However, 'sparse' for a 256x256 scan could take more than 10 mins on CPU just to compute the grouping,
    hence PtychoShelves automatically switches to 'random' for Nscans > 1e3.
    The grouping in PtyRAD is fixed during optimization, but the order between the groups is shuffled every iteration.
    """

    SAVE_ITERS: Optional[int] = Field(default=10, ge=1, description="Iterations before saving results")
    """
    Number of completed iterations before saving the current reconstruction results (model, probe, object) and summary figures.
    If 'SAVE_ITERS' is 50, it'll create an output reconstruction folder and save the results and figures into it every 50 iterations.
    If null, the output reconstruction folder will not be created and no reconstruction results or summary figures will be saved.
    If 'SAVE_ITERS' > 'NITER', it'll create the output reconstruction folder but no results / figs will be saved.
    Typically we set 'SAVE_ITERS' to 50 for reconstruction mode with 'NITER' around 200 to 500.
    For hypertune mode, it's suggested to set 'SAVE_ITERS' to null and set 'collate_results' to true to save disk space,
    while also providing a convenient way to check the hypertune performance via the collated results.
    """

    output_dir: str = Field(default="output/", description="Main output directory path")
    """
    Path and name of the main output directory.
    Ideally the 'output_dir' keeps a series of reconstructions of the same material system or project.
    The PtyRAD results and figs will be saved into a reconstruction-specific folder under 'output_dir'.
    The 'output_dir' folder will be automatically created if it doesn't exist.
    """

    recon_dir_affixes: List[
        Literal[
            "minimal",
            "default",
            "all",
            "indices",
            "meas",
            "batch",
            "pmode",
            "omode",
            "nlayer",
            "lr",
            "optimizer",
            "start_iter",
            "model",
            "constraint",
            "loss",
            "illumination",
            "dx",
            "tilt",
            "affine",
        ]
    ] = Field(default=["default"], description="Affixes for reconstruction folder name")
    """
    This list specifies the optional affixes to the reconstruction folder name for file management.
    The order of strings has NO effect on the output folder name.
    PtyRAD provides high-level presets including 'minimal', 'default', and 'all', each of which corresponds to a subset of all available options.
    There are currently 16 available options, including 'indices', 'meas', 'batch', 'pmode', 'omode', 'nlayer', 'lr', 'optimizer', 'start_iter', 'model',
    'constraint', 'loss', 'illumination', 'dx', 'tilt', and 'affine'.
    Each option corresponds to specific fields in the params file.
    These individual tags can be combined with the presets, e.g. ['minimal', 'tilt'].
    A typical output folder name for 'default' looks like:
    'ptyrad/demo/output/tBL_WSe2/20250607_full_N16384_dp128_flipT100_random32_p6_1obj_6slice_dz2_plr1e-4_oalr5e-4_oplr5e-4_slr1e-4_orblur0.5_ozblur1_mamp0.03_4_oathr0.98_oposc_sng1.0_spr0.1'.
    Note that certain trivial values might not be shown even if they're specified, e.g. a tilt of [0,0] mrad, the slice thickness for single-slice ptychography, or start_iter = 1, etc.
    It's recommended to use 'default' for 'reconstruction' mode and adjust if needed.
    For 'hypertune' mode, you can set it to [] (empty list) or ['minimal'] if you're considering saving intermediate results,
    because a unique identifier (trial number) is appended and detailed information is fully stored in the sqlite database,
    and the hypertuned params are appended to the collated result anyway if you have 'collate_results': true.
    """

    prefix_time: Optional[Union[bool, str]] = Field(default="date", description="Prefix time format for folder name")
    """
    Set to true to prepend a date str like '20240903_' in front of the reconstruction folder name,
    so that reconstructions with the same 'recon_dir_affixes' setting won't get incorrectly saved to the same output folder.
    Available options are None, True, False, 'date', 'time', 'datetime', and time format strings like '%Y%m%d_%H%M%S'.
    None is accepted and treated as False.
    Suggested value is 'date' for both 'reconstruction' and 'hypertune' modes.
    In hypertune mode, the date string is applied to the hypertune folder instead of the reconstruction folder.
    Also note that if you're using hypertune mode on multiple GPUs, you should set prefix_time to 'date' or False, and handle any additional identifier using 'prefix',
    otherwise different workers launched at different times would each generate their own output folder with different time strings despite using the same sqlite database file.
    """

    @field_validator("prefix_time")
    @classmethod
    def validate_prefix_time(cls, v: Optional[Union[bool, str]]) -> Union[bool, str]:
        """Validate 'prefix_time' and normalize None to False.

        ``None`` is listed as an accepted option for this field, so it is admitted by the
        annotation and converted to ``False`` here; downstream code therefore only ever
        receives a bool or a str. A string must be one of the named presets or a time
        format string starting with ``%`` (e.g. '%Y%m%d_%H%M%S').

        Raises:
            ValueError: if a string is neither a preset nor starts with '%'.
        """
        if v is None:
            return False
        valid_formats = {"date", "time", "datetime"}
        if isinstance(v, str) and v not in valid_formats and not v.startswith("%"):
            # sorted() keeps the message stable; formatting the set directly would
            # vary between runs under string hash randomization.
            raise ValueError(
                f"prefix_time must be a boolean, one of {sorted(valid_formats)}, "
                "or a valid time format string"
            )
        return v

    prefix: str = Field(default="", description="Prefix for reconstruction folder name")
    """
    Prefix this string to the reconstruction folder name.
    Note that "_" will be automatically added, and the prefix string is placed after the time string if 'prefix_time' is enabled.
    In hypertune mode, the prefix string is applied to the hypertune folder instead of the reconstruction folder.
    """

    postfix: str = Field(default="", description="Postfix for reconstruction folder name")
    """
    Postfix this string to the reconstruction folder name.
    Note that "_" will be automatically added.
    In hypertune mode, the postfix string is applied to the hypertune folder instead of the reconstruction folder.
    """

    save_result: List[Literal["model", "obja", "objp", "probe", "probe_prop", "optim_state"]] = (
        Field(default=["model", "objp"], description="Results to save")
    )
    """
    This list specifies the results to save every 'SAVE_ITERS' iterations, so the intermediate progress is kept.
    Available options are 'model', 'obja', 'objp', 'probe', 'probe_prop', and 'optim_state'.
    'model' is a nested dict that is later stored as an hdf5 file.
    'model' contains optimizable tensors and metadata so that you can always refine from it and load whatever optimizable tensors (object, probe, positions, tilts)
    if you want to continue the reconstruction.
    It's similar to the NiterXXX.mat from PtychoShelves. The object ('obja', 'objp') and 'probe' options output the reconstructed object and probe as '.tif'.
    If you don't want to save anything, set 'SAVE_ITERS' to null.
    Suggested setting is to save the main results (i.e., ['model', 'obja', 'objp', 'probe']).
    For hypertune mode, you can set 'collate_results' to true and set 'SAVE_ITERS' to null to disable result saving.
    """

    result_modes: ResultModes = Field(default_factory=ResultModes, description="Object output configurations")
    """
    This dict specifies which object outputs are saved by their final dimension ('obj_dim'),
    whether to save the full or cropped FOV ('FOV') of the object, and whether to save the raw or normalized bit depth version of the object and probe.
    A comprehensive (but probably redundant) saving option looks like {'obj_dim': [2,3,4], 'FOV': ['full', 'crop'], 'bit': ['raw', '32', '16', '8']}.
    'obj_dim' takes a list of ints ranging from 2 to 4, corresponding to 2D to 4D object output.
    Set 'obj_dim': [2] if you only want the zsum from multislice ptychography.
    Suggested value is [2,3,4] to save all possible outputs.
    'FOV' takes a list of strings; the available strings are 'full' and 'crop'.
    Suggested value is 'crop' so the lateral padded region of the object is not saved.
    'bit' takes a list of strings; the available strings are 'raw', '32', '16', and '8'.
    'raw' is the original value range, while '32' normalizes the value from 0 to 1. '16' and '8' normalize the value from 0 to 65535 and 0 to 255, respectively.
    Default is '8' to save only the normalized 8-bit result for quick visualization.
    You can set it to ['raw', '8'] if you want to keep the original float32 results along with the normalized 8-bit results.
    This postprocessing appends the corresponding labels to the result file names.
    """

    selected_figs: List[
        Literal[
            "loss", "forward", "probe_r_amp", "probe_k_amp", "probe_k_phase", "pos", "tilt", "tilt_avg", "slice_thickness", "all"
        ]
    ] = Field(default=["loss", "forward", "probe_r_amp", "pos"], description="Figures to plot/save")
    """
    This list specifies the selected figures that will be plotted/saved.
    The available strings are 'loss', 'forward', 'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos', 'tilt', 'tilt_avg', 'slice_thickness', and 'all'.
    The suggested value is ['loss', 'forward', 'probe_r_amp', 'pos'].
    """

    copy_params: bool = Field(default=True, description="Copy params file to output folder")
    """
    Set to true if you want to copy the .yml params file to the hypertune folder (hypertune mode) or individual reconstruction folders (reconstruction mode).
    Suggested value is true for better record keeping, although most information is saved in model.pt and can be loaded by ckpt = torch.load('model.pt'), params = ckpt['params'].
    """

    if_quiet: bool = Field(default=False, description="Reduce printed information during reconstruction")
    """
    Set to true if you want to reduce the amount of printed information during PtyRAD reconstruction.
    Suggested value is false for more information, but if you're running hypertune mode you should consider setting it to true.
    """

    compiler_configs: Optional[CompilerConfigs] = Field(default_factory=CompilerConfigs, description="PyTorch compiler configurations")
    """
    This dict specifies the PyTorch JIT compiler configurations.
    Set to {'enable': true} to enable PyTorch JIT compilation for a 1.3-1.9x speedup on supported hardware.
    See https://docs.pytorch.org/docs/stable/generated/torch.compile.html for more details.
    """

    @model_validator(mode="after")
    def fill_compiler_configs_defaults(self):
        """Replace an explicit null 'compiler_configs' with a default CompilerConfigs().

        The field is Optional so a params file may write ``compiler_configs: null``;
        this normalizes that case so downstream code always gets a CompilerConfigs.
        """
        if self.compiler_configs is None:
            self.compiler_configs = CompilerConfigs()
        return self