# Proprietary Changes made for Trainy under the Trainy Software License
# Original source: skypilot: https://github.com/skypilot-org/skypilot
# which is Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
On module import, we attempt to parse the config file whose path is given by
the KONDUKTOR_CONFIG env var (default: ~/.konduktor/config.yaml). Callers can
then use

  >> konduktor_config.loaded()

to check if the config is successfully loaded.

To read a nested-key config:

  >> konduktor_config.get_nested(('auth', 'some_auth_config'), default_value)

The config can be overridden by the configs in task YAMLs. Callers are
responsible for providing the override_configs. If the nested key is part of
OVERRIDEABLE_CONFIG_KEYS, override_configs must be provided (can be empty);
otherwise it must not be provided:

  >> konduktor_config.get_nested(('kubernetes', 'pod_config'), default_value,
                        override_configs={'kubernetes': {'pod_config': {}}})

To set a value in the nested-key config:

  >> config_dict = konduktor_config.set_nested(('auth', 'some_key'), value)

This operation returns a deep-copied dict (the loaded config itself is left
unchanged), and is safe in that a key that is not found is created rather than
raising an error.

Example usage:

Consider the following config contents:

    a:
        nested: 1
    b: 2

then:

    # Assuming ~/.konduktor/config.yaml exists and can be loaded:
    konduktor_config.loaded()  # ==> True

    konduktor_config.get_nested(('a', 'nested'), None)    # ==> 1
    konduktor_config.get_nested(('a', 'nonexist'), None)  # ==> None
    konduktor_config.get_nested(('a',), None)             # ==> {'nested': 1}

    # If ~/.konduktor/config.yaml doesn't exist, fails to load, or is empty:
    konduktor_config.loaded()  # ==> False
    konduktor_config.get_nested(('a', 'nested'), None)    # ==> None
    konduktor_config.get_nested(('a', 'nonexist'), None)  # ==> None
    konduktor_config.get_nested(('a',), None)             # ==> None
"""

import copy
import os
import pprint
from typing import Any, Dict, List, Optional, Tuple

import yaml  # type: ignore

from konduktor import logging
from konduktor.utils import common_utils, schemas, ux_utils

logger = logging.get_logger(__name__)

# Nested key paths whose values may be overridden by configs in task YAMLs.
# The module-level get_nested() requires `override_configs` for exactly these
# key paths and forbids it for all others.
OVERRIDEABLE_CONFIG_KEYS: List[Tuple[str, ...]] = [
    ('kubernetes', 'pod_config'),
]

# The config path is discovered in this order:
#
# (1) (Used internally) If env var {ENV_VAR_CONFIG} exists, use its path. If
#     that path does not exist, a FileNotFoundError is raised;
# (2) Otherwise, if file {CONFIG_PATH} exists, use this file.
#
# If the path discovered by (1) fails to load, we do not attempt to go to step
# 2 in the list.

# (Used internally) An env var holding the path to the local config file. This
# is only used by jobs controller tasks to ensure recoveries of the same job
# use the same config file.
# NOTE(review): the sentence above was inherited from SkyPilot; confirm which
# Konduktor components actually set this env var.
ENV_VAR_CONFIG = 'KONDUKTOR_CONFIG'

# Default path to the local config file, used when ENV_VAR_CONFIG is unset.
# `~` is expanded when the config is loaded.
CONFIG_PATH = '~/.konduktor/config.yaml'


class Config(Dict[str, Any]):
    """Konduktor config that supports setting/getting values with nested keys."""

    def get_nested(
        self,
        keys: Tuple[str, ...],
        default_value: Any,
        override_configs: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Gets a nested key.

        If any key is not found, or any intermediate key does not point to a
        dict value, returns 'default_value'.

        The lookup runs on a deep copy of this config, so applying
        'override_configs' never modifies this config, and mutating the
        returned value does not modify it either.

        Args:
            keys: A tuple of strings representing the nested keys.
            default_value: The default value to return if the key is not found.
            override_configs: A dict of override configs with the same schema as
                the config file, but only containing the keys to override.

        Returns:
            The value of the nested key, or 'default_value' if not found.
        """
        config = copy.deepcopy(self)
        if override_configs is not None:
            config = _recursive_update(config, override_configs)
        return _get_nested(config, keys, default_value)

    def set_nested(self, keys: Tuple[str, ...], value: Any) -> None:
        """In-place sets a nested key to value.

        Like get_nested(), if any key is not found, this will not raise an
        error: missing intermediate keys are created. An empty 'keys' is a
        no-op.

        Note: the update goes through `_recursive_update`, which delegates the
        top-level 'kubernetes' section to `merge_k8s_configs`; values under
        that section are therefore merged (lists extended/merged) rather than
        replaced.
        """
        if not keys:
            return
        # Build {k0: {k1: {... {k_last: value}}}} from the innermost key out.
        override: Dict[str, Any] = {keys[-1]: value}
        for key in reversed(keys[:-1]):
            override = {key: override}
        _recursive_update(self, override)

    @classmethod
    def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config':
        """Builds a Config from a plain mapping; None yields an empty Config.

        The mapping is passed positionally (`cls(config)`) rather than as
        keyword arguments (`cls(**config)`): keyword expansion raises
        `TypeError: keywords must be strings` for non-string keys, which YAML
        allows (e.g. `1: foo`).
        """
        if config is None:
            return cls()
        return cls(config)


# The loaded config. Rebound by _try_load_config() when a config file is
# parsed successfully; otherwise it stays an empty Config.
_dict: Config = Config()
# Path of the config file that was successfully parsed, or None if none was.
_loaded_config_path: Optional[str] = None


def get_nested(
    keys: Tuple[str, ...],
    default_value: Any,
    override_configs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Looks up a nested key in the loaded config.

    Returns 'default_value' if any key is missing, or if an intermediate key
    maps to something other than a dict.

    The 'override_configs' argument is tied to OVERRIDEABLE_CONFIG_KEYS:
    it is required (possibly empty) when 'keys' is one of those paths, and
    forbidden for every other path. Violations trip an assertion.

    Args:
        keys: A tuple of strings representing the nested keys.
        default_value: The value returned when the key is not found.
        override_configs: A dict of override configs with the same schema as
            the config file, but only containing the keys to override.

    Returns:
        The value found at 'keys', or 'default_value'.
    """
    is_overrideable = keys in OVERRIDEABLE_CONFIG_KEYS
    if is_overrideable:
        assert override_configs is not None, (
            f'Override configs must be provided when keys {keys} is within '
            'OVERRIDEABLE_CONFIG_KEYS: '
            f'{OVERRIDEABLE_CONFIG_KEYS}'
        )
    else:
        assert override_configs is None, (
            f'Override configs must not be provided when keys {keys} is not within '
            'OVERRIDEABLE_CONFIG_KEYS: '
            f'{OVERRIDEABLE_CONFIG_KEYS}'
        )
    return _dict.get_nested(keys, default_value, override_configs)


def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]:
    """Returns a deep-copied config with the nested key set to value.

    The module-level loaded config is not modified. Like get_nested(), if any
    key is not found, this will not raise an error; missing keys are created.
    """
    copied_dict = copy.deepcopy(_dict)
    copied_dict.set_nested(keys, value)
    # Convert with dict(mapping) rather than dict(**mapping): keyword
    # expansion raises TypeError for non-string keys and is an unnecessary
    # round-trip through keyword arguments.
    return dict(copied_dict)


def to_dict() -> Config:
    """Returns an independent deep copy of the loaded config.

    Mutating the returned Config leaves the module-level config untouched.
    """
    snapshot = copy.deepcopy(_dict)
    return snapshot


def _try_load_config() -> None:
    """Loads the user config file into the module-level `_dict`.

    The path is taken from the ENV_VAR_CONFIG env var when it is set (a
    missing file there is a hard error), otherwise from CONFIG_PATH (a missing
    file there is silently skipped). On a successful parse, rebinds the module
    globals `_dict` and `_loaded_config_path`. A YAML parse error is logged
    rather than raised. If `_dict` is non-empty afterwards, it is validated
    against `schemas.get_config_schema()`.

    Raises:
        FileNotFoundError: if ENV_VAR_CONFIG is set to a path that does not
            exist.
    """
    global _dict, _loaded_config_path
    config_path_via_env_var = os.environ.get(ENV_VAR_CONFIG)
    if config_path_via_env_var is not None:
        config_path = os.path.expanduser(config_path_via_env_var)
        if not os.path.exists(config_path):
            with ux_utils.print_exception_no_traceback():
                raise FileNotFoundError(
                    'Config file specified by env var '
                    f'{ENV_VAR_CONFIG} ({config_path!r}) does not '
                    'exist. Please double check the path or unset the env var: '
                    f'unset {ENV_VAR_CONFIG}'
                )
    else:
        config_path = os.path.expanduser(CONFIG_PATH)

    if not os.path.exists(config_path):
        return

    logger.debug(f'Using config path: {config_path}')
    try:
        config = common_utils.read_yaml(config_path)
        _dict = Config.from_dict(config)
        _loaded_config_path = config_path
        logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}')
    except yaml.YAMLError as e:
        # Best-effort: report the malformed file but keep going. The error is
        # interpolated into the message. Passing it as an extra positional
        # argument to a message without a %s placeholder makes (stdlib)
        # logging fail to format the record, so the details would be lost.
        logger.error(f'Error in loading config file ({config_path}): {e}')

    if _dict:
        common_utils.validate_schema(
            _dict,
            schemas.get_config_schema(),
            f'Invalid config YAML ({config_path}). See: '
            'https://konduktor.readthedocs.io/en/latest/reference/config.html. '  # pylint: disable=line-too-long
            'Error: ',
            skip_none=False,
        )
        # Only report success when a check actually ran; previously this was
        # logged even after a YAML parse error.
        logger.debug('Config syntax check passed.')


def _check_allowed_and_disallowed_override_keys(
    key: str,
    allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
    disallowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
) -> Tuple[Optional[List[Tuple[str, ...]]], Optional[List[Tuple[str, ...]]]]:
    allowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = []
    disallowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = []
    if allowed_override_keys is not None:
        for nested_key in allowed_override_keys:
            if key == nested_key[0]:
                if len(nested_key) == 1:
                    # Allowed key is fully matched, no need to check further.
                    allowed_keys_with_matched_prefix = None
                    break
                assert allowed_keys_with_matched_prefix is not None
                allowed_keys_with_matched_prefix.append(nested_key[1:])
        if (
            allowed_keys_with_matched_prefix is not None
            and not allowed_keys_with_matched_prefix
        ):
            raise ValueError(
                f'Key {key} is not in allowed override keys: '
                f'{allowed_override_keys}'
            )
    else:
        allowed_keys_with_matched_prefix = None

    if disallowed_override_keys is not None:
        for nested_key in disallowed_override_keys:
            if key == nested_key[0]:
                if len(nested_key) == 1:
                    raise ValueError(
                        f'Key {key} is in disallowed override keys: '
                        f'{disallowed_override_keys}'
                    )
                assert disallowed_keys_with_matched_prefix is not None
                disallowed_keys_with_matched_prefix.append(nested_key[1:])
    else:
        disallowed_keys_with_matched_prefix = None
    return allowed_keys_with_matched_prefix, disallowed_keys_with_matched_prefix


def _recursive_update(
    base_config: Config,
    override_config: Dict[str, Any],
    allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
    disallowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
) -> Config:
    """Recursively merges override_config into base_config, in place.

    Dict values present on both sides are merged key by key; anything else
    replaces the base value. A 'kubernetes' dict present on both sides is
    delegated to merge_k8s_configs(), which merges lists (containers,
    volumes, ...) instead of replacing them. If either side of 'kubernetes'
    is not a dict, the override simply replaces the base value, because
    merging a non-dict would crash.

    Args:
        base_config: Config (or nested dict) to update; mutated in place.
        override_config: Values to merge on top of base_config.
        allowed_override_keys: If given, only these nested key paths may be
            overridden.
        disallowed_override_keys: Nested key paths that must not be
            overridden.

    Returns:
        base_config, after the update.

    Raises:
        ValueError: if a key violates the allowed/disallowed restrictions.
    """
    for key, value in override_config.items():
        (next_allowed_override_keys, next_disallowed_override_keys) = (
            _check_allowed_and_disallowed_override_keys(
                key, allowed_override_keys, disallowed_override_keys
            )
        )
        both_dicts = (
            key in base_config
            and isinstance(value, dict)
            and isinstance(base_config[key], dict)
        )
        if both_dicts and key == 'kubernetes':
            merge_k8s_configs(
                base_config[key],
                value,
                next_allowed_override_keys,
                next_disallowed_override_keys,
            )
        elif both_dicts:
            _recursive_update(
                base_config[key],
                value,
                next_allowed_override_keys,
                next_disallowed_override_keys,
            )
        else:
            base_config[key] = value
    return base_config


def _get_nested(
    configs: Optional[Dict[str, Any]],
    keys: Tuple[str, ...],
    default_value: Any,
    pop: bool = False,
) -> Any:
    if configs is None:
        return default_value
    curr = configs
    for i, key in enumerate(keys):
        if isinstance(curr, dict) and key in curr:
            value = curr[key]
            if i == len(keys) - 1:
                if pop:
                    curr.pop(key, default_value)
            curr = value
        else:
            return default_value
    logger.debug(f'User config: {".".join(keys)} -> {curr}')
    return curr


def merge_k8s_configs(
    base_config: Dict[Any, Any],
    override_config: Dict[Any, Any],
    allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
    disallowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
) -> None:
    """Merges override_config into base_config in place, Kubernetes-aware.

    Dict values present on both sides are merged recursively instead of
    replaced. Lists present on both sides are combined as follows:

    - 'containers' / 'imagePullSecrets': the override must hold exactly one
      item, which is merged into the FIRST item of the base list (only one
      container per pod is supported).
    - 'volumes' / 'volumeMounts': each override item with a 'name' is merged
      into the base item with the same name, or appended if none exists.
    - any other list: the override items are appended to the base list.

    Any other override value (new key, scalar, or a dict over a non-dict base
    value) replaces the base value.

    The allowed/disallowed override key restrictions are applied to every
    recursive merge, including merges of same-named volumes/volumeMounts.

    Args:
        base_config: Dict to update; mutated in place.
        override_config: Values to merge on top of base_config.
        allowed_override_keys: If given, only these nested key paths may be
            overridden.
        disallowed_override_keys: Nested key paths that must not be
            overridden.

    Raises:
        ValueError: if a merged key violates the override restrictions.
        AssertionError: if a list override meets a non-list base value, or a
            'containers'/'imagePullSecrets' override does not have exactly
            one item.
    """
    for key, value in override_config.items():
        (next_allowed_override_keys, next_disallowed_override_keys) = (
            _check_allowed_and_disallowed_override_keys(
                key, allowed_override_keys, disallowed_override_keys
            )
        )
        if (
            isinstance(value, dict)
            and key in base_config
            and isinstance(base_config[key], dict)
        ):
            merge_k8s_configs(
                base_config[key],
                value,
                next_allowed_override_keys,
                next_disallowed_override_keys,
            )
        elif isinstance(value, list) and key in base_config:
            assert isinstance(
                base_config[key], list
            ), f'Expected {key} to be a list, found {base_config[key]}'
            if key in ['containers', 'imagePullSecrets']:
                # We take the first and only container/secret in the list and
                # merge it, as we only support one container per pod.
                assert len(value) == 1, f'Expected only one item in {key}, found {value}'
                merge_k8s_configs(
                    base_config[key][0],
                    value[0],
                    next_allowed_override_keys,
                    next_disallowed_override_keys,
                )
            elif key in ['volumes', 'volumeMounts']:
                # Merge each override item into the base item with the same
                # name; append it if there is no such item.
                for new_volume in value:
                    new_volume_name = new_volume.get('name')
                    if new_volume_name is None:
                        # NOTE(review): unnamed entries are skipped (behavior
                        # kept from the original) — confirm this is intended.
                        continue
                    destination_volume = next(
                        (
                            v
                            for v in base_config[key]
                            if v.get('name') == new_volume_name
                        ),
                        None,
                    )
                    if destination_volume is not None:
                        # Propagate the override restrictions. Previously
                        # they were dropped here, so restricted keys could be
                        # overridden through a same-named volume.
                        merge_k8s_configs(
                            destination_volume,
                            new_volume,
                            next_allowed_override_keys,
                            next_disallowed_override_keys,
                        )
                    else:
                        base_config[key].append(new_volume)
            else:
                base_config[key].extend(value)
        else:
            base_config[key] = value


def loaded_config_path() -> Optional[str]:
    """Returns the path of the config file that was parsed successfully.

    None if no config file has been parsed (missing file or YAML error).
    """
    path = _loaded_config_path
    return path


# Load on import, so module-level readers such as get_nested() and loaded()
# see the user's config. If ENV_VAR_CONFIG points at a missing file,
# _try_load_config() raises FileNotFoundError (inside
# ux_utils.print_exception_no_traceback()), which presumably propagates and
# fails the import itself.
_try_load_config()


def loaded() -> bool:
    """Tells whether a non-empty user configuration is loaded.

    Also False when the config file exists but parses to an empty config.
    """
    return len(_dict) > 0
