# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.messages import PytorchEngineConfig

from ..check_env.adapter import AdapterChecker
from ..check_env.base import BaseChecker
from ..check_env.model import ModelChecker
from ..check_env.torch import TorchChecker
from ..check_env.transformers import TransformersChecker


class EngineChecker(BaseChecker):
    """Check the engine config and register the checkers it depends on.

    On construction the checkers for torch, triton (``cuda`` devices) or
    deeplink (every other device type), transformers, the model and every
    configured adapter are created and registered as required checkers.
    ``check`` itself only validates fields of the engine config.

    Args:
        model_path (str): Path of the model, forwarded to ``ModelChecker``.
        engine_config (PytorchEngineConfig): Engine config. Its ``dtype``,
            ``device_type`` and ``adapters`` fields are read here;
            ``thread_safe`` and ``max_batch_size`` are read in ``check``.
        trust_remote_code (bool): Forwarded to ``ModelChecker``.
        logger: Optional logger handed to ``BaseChecker``. The logger
            returned by ``get_logger`` is shared with every sub-checker.
    """

    def __init__(self,
                 model_path: str,
                 engine_config: PytorchEngineConfig,
                 trust_remote_code: bool = True,
                 logger=None):
        super().__init__(logger)
        # Reuse the logger resolved by the base class for all sub-checkers.
        logger = self.get_logger()

        self.engine_config = engine_config

        dtype = engine_config.dtype
        device_type = engine_config.device_type

        # pytorch
        torch_checker = TorchChecker(logger=logger)

        if device_type == 'cuda':
            # triton, imported only on the cuda path
            from ..check_env.triton import TritonChecker
            triton_checker = TritonChecker(logger=logger)
            triton_checker.register_required_checker(torch_checker)
            self.register_required_checker(triton_checker)
        else:
            # deeplink, imported only on the non-cuda path
            from ..check_env.deeplink import DeeplinkChecker
            dl_checker = DeeplinkChecker(device_type, logger=logger)
            self.register_required_checker(dl_checker)
            self.register_required_checker(torch_checker)

        # transformers
        # Pass the shared logger like every other sub-checker does, so a
        # custom logger is not dropped for this check.
        trans_checker = TransformersChecker(logger=logger)

        # model
        model_checker = ModelChecker(model_path=model_path,
                                     trust_remote_code=trust_remote_code,
                                     dtype=dtype,
                                     device_type=device_type,
                                     logger=logger)
        model_checker.register_required_checker(torch_checker)
        model_checker.register_required_checker(trans_checker)
        self.register_required_checker(model_checker)

        # adapters: one checker per configured adapter path
        adapters = engine_config.adapters
        if adapters is not None:
            for adapter in adapters.values():
                adapter_checker = AdapterChecker(adapter, logger=logger)
                self.register_required_checker(adapter_checker)

    def check(self):
        """Validate the engine config.

        Logs a deprecation warning when ``thread_safe`` is enabled and calls
        ``log_and_exit`` when ``max_batch_size`` is not greater than 0.
        """
        engine_config = self.engine_config
        logger = self.get_logger()

        if engine_config.thread_safe:
            logger.warning('thread safe mode has been deprecated and'
                           ' it would be removed in the future.')

        if engine_config.max_batch_size <= 0:
            self.log_and_exit(
                mod_name='Engine',
                message='max_batch_size should be'
                f' greater than 0, but got {engine_config.max_batch_size}')
