#!python
import os
import typer
from typing import Optional
from together_cli.src.core.instances import pprint_instances, shutdown_instance, fetch_logs
from together_cli.src.model import serve_model
from together_cli.src.system import check_binary_exists, check_folders, check_lockable_drive
from together_cli.src.core.models import pprint_models
from together_cli.src.core.config import read_config

# Typer application object; the @app.command() functions in this file
# register themselves on it as sub-commands.
app = typer.Typer()
# Default on-disk layout: <user home>/together/config.json
home_dir = os.path.expanduser("~")
default_together_home = os.path.join(home_dir, "together")
config_path = os.path.join(default_together_home, "config.json")
# Read once at import time. `serve` takes its option defaults from the
# 'home_dir', 'data_dir' and 'last_used_port' keys of this mapping.
config = read_config(config_path)

@app.command()
def check():
    """Print, for Slurm, Singularity and Docker, what check_binary_exists reports."""
    # (label printed, binary probed) -- kept in the original output order.
    probes = (
        ("Slurm: ", "sinfo"),
        ("Singularity: ", "singularity"),
        ("Docker: ", "docker"),
    )
    for label, binary in probes:
        print(label, check_binary_exists(binary))


@app.command()
def serve(
        # required arguments
        model: str = typer.Option(
            ..., prompt="What's the model name you want to serve?"
        ),
        home_dir: str = typer.Option(
            config['home_dir'], help="The home directory for Together? It cannot be on an NFS drive."
        ),
        data_dir: str = typer.Option(
            config['data_dir'], prompt="The directory you want to store model weights? It could be on an NFS drive."
        ),
        # optional, but suggested arguments
        gpus: str = typer.Option(
            None, help="GPU Specifiers (e.g., titanrtx:1), required if you are not using baremetal nodes"
        ),
        queue: str = typer.Option(None, help="Queue name - default is None"),
        singularity: bool = typer.Option(
            False, help="Use singularity to serve the model"),
        docker: Optional[bool] = typer.Option(
            False, help="Use docker to serve the model"),
        tags: Optional[str] = typer.Option("", help="tags"),
        account: str = typer.Option(
            None, help="Account name - default is None"),
        modules: str = typer.Option(
            None, help="Modules to load - default is None"),
        duration: str = typer.Option(
            "1:00:00", help="Duration of the job - default is '1:00:00'"),
        matchmaker_addr: str = typer.Option(
            "wss://api.together.xyz/websocket", help="Global Matchmaker address - leave it as it is in most cases"),
        port: int = typer.Option(
            config['last_used_port']+2, help="Port number - default is 8092-8093. In case of conflict, change it to a different number, increase by 2"),
        node_list: str = typer.Option(
            None, help="Node list - default is None"),
        cluster: str = typer.Option(
            "baremetal", help="Cluster Management System - default is 'baremetal'"),
        dry_run: bool = typer.Option(
            False, help="Only Generate submission scripts for review - default is False"),
        owner: str = typer.Option(
            "", help="Owner of the instance - default is None"),
    ):
    """Validate the options and directories, then serve MODEL via serve_model.

    Exactly one of --docker / --singularity must be given, and --gpus is
    required whenever --cluster is not 'baremetal'.
    """
    # NOTE: Typer shows this docstring as the command's --help text, so it is
    # kept short; per-option documentation lives in the help= strings above.

    # --- containerization: exactly one runtime must be selected ---
    if docker and singularity:
        print("[ERROR] You can only choose one of docker or singularity")
        return
    if docker:
        print("[INFO] Containerization: Docker")
    elif singularity:
        print("[INFO] Containerization: Singularity")
    else:
        print("[ERROR] You must choose one of docker or singularity")
        # Stop here like every other validation failure; previously execution
        # fell through and serve_model was called with no runtime selected.
        return

    if cluster != 'baremetal' and gpus is None:
        print("[ERROR] You must specify gpus if you are not using baremetal nodes")
        return

    # expanduser() returns paths that do not start with "~" unchanged, so no
    # startswith() guard is needed.
    home_dir = os.path.expanduser(home_dir)
    data_dir = os.path.expanduser(data_dir)

    check_folders(home_dir=home_dir, data_dir=data_dir)

    # The home directory must pass check_lockable_drive (the --home-dir help
    # text says it cannot be on an NFS drive).
    if not check_lockable_drive(home_dir):
        print(
            "[ERROR] Your home directory is not lockable. Please choose another directory.")
        return

    serve_model(
        model_name=model,
        queue_name=queue,
        home_dir=home_dir,
        data_dir=data_dir,
        matchmaker_addr=matchmaker_addr,
        tags=tags,
        use_docker=docker,
        use_singularity=singularity,
        gpus=gpus,
        account=account,
        modules=modules,
        node_list=node_list,
        port=port,
        duration=duration,
        cluster=cluster,
        dry_run=dry_run,
        owner=owner,
    )


@app.command()
def list(entity: str):
    """Print the listing for ENTITY: 'jobs' or 'models'."""
    # Map each supported entity to the function that prints it.
    printers = {
        "jobs": pprint_instances,
        "models": pprint_models,
    }
    printer = printers.get(entity)
    if printer is None:
        print("[ERROR] Unknown: ", entity)
        return
    printer()

@app.command()
def status():
    """Show instances by calling pprint_instances (same call as `list jobs`)."""
    pprint_instances()

@app.command()
def main():
    """Print the CLI name."""
    # Registered as an ordinary sub-command; it only prints a fixed string.
    print("Together CLI")

@app.command()
def logs(node_name: str):
    """Fetch logs for NODE_NAME by passing it to fetch_logs."""
    fetch_logs(node_name)

@app.command()
def stop(node_name: str):
    """Stop NODE_NAME by passing it to shutdown_instance."""
    shutdown_instance(node_name)

# Dispatch to the Typer app only when this file is executed as a script.
if __name__ == "__main__":
    app()
