#!python

import argparse
from argparse import Namespace

import requests
import yaml


def copy_args(args: Namespace, args_to_copy: list):
    """
    Extracts a subset of attributes from an argparse namespace.

    Names that are not set on the namespace are skipped; names that are set are copied with whatever value they
    hold (including ``None``), in the order given by ``args_to_copy``.

    :param args: the parsed arguments
    :param args_to_copy: names of the attributes to extract
    :return: a new dictionary mapping each present name to its value
    """
    available = vars(args)
    return {name: available[name] for name in args_to_copy if name in available}


class HttpClient:
    """
    Thin HTTP client for the management and prediction endpoints of a model server.

    Every request is sent with the ``timeout`` and ``proxies`` values held on the instance. The endpoint URLs are
    unset until ``load_config`` has successfully read them, which is signalled by ``initialized``.
    """

    def __init__(self):
        self.management_url = None
        self.prediction_url = None
        self.proxies = {}
        # HTTP timeout passed to every request; can be overridden by the config file
        self.timeout = 30
        # Set to True by load_config once both endpoint URLs are known
        self.initialized = False

    def load_config(self, file: str):
        """
        Loads the client settings from a YAML file.

        Reads the ``management_url``, ``prediction_url``, ``proxies`` and ``timeout`` keys. The client is only
        flagged as initialized when both URLs are present. A file that cannot be opened, or a document that is
        not a mapping (e.g. an empty file), leaves the client untouched instead of raising.

        :param file: path to the YAML config file
        """
        try:
            with open(file, 'r') as stream:
                config = yaml.safe_load(stream)
        except IOError:
            # Best effort: callers inspect `initialized` to detect a missing config
            return

        # safe_load yields None for an empty document and may yield scalars or lists;
        # only a mapping can carry the settings read below
        if not isinstance(config, dict):
            return

        self.management_url = config.get('management_url')
        self.prediction_url = config.get('prediction_url')
        self.proxies = config.get('proxies', self.proxies)
        self.timeout = config.get('timeout', self.timeout)

        if not self.management_url or not self.prediction_url:
            return

        self.initialized = True

    def deploy_model(self, args: Namespace):
        """
        Registers a model by POSTing to ``{management_url}/models``.

        :param args: parsed arguments; ``url`` is required, the remaining registration options are sent as
            query parameters when they are set on the namespace
        :return: the response returned by ``requests``
        """
        payload = copy_args(args,
                            ['model_name', 'handler', 'runtime', 'batch_size', 'max_batch_delay', 'initial_workers',
                             'synchronous', 'response_timeout'])
        payload['url'] = args.url
        return requests.post(f'{self.management_url}/models', params=payload, timeout=self.timeout,
                             proxies=self.proxies)

    def remove_model(self, args: Namespace):
        """
        Sends a DELETE request to ``{management_url}/models/{model_name}``.

        :param args: parsed arguments providing ``model_name``
        :return: the response returned by ``requests``
        """
        return requests.delete(f'{self.management_url}/models/{args.model_name}', timeout=self.timeout,
                               proxies=self.proxies)

    def list_models(self, args: Namespace):
        """
        Sends a GET request to ``{management_url}/models``.

        :param args: parsed arguments; ``limit`` and ``next_page_token`` are sent as query parameters when set
        :return: the response returned by ``requests``
        """
        payload = copy_args(args, ['limit', 'next_page_token'])
        return requests.get(f'{self.management_url}/models', params=payload, timeout=self.timeout, proxies=self.proxies)

    def get_status(self, args: Namespace):
        """
        Sends a GET request to ``{management_url}/models/{model_name}``.

        :param args: parsed arguments providing ``model_name``
        :return: the response returned by ``requests``
        """
        return requests.get(f'{self.management_url}/models/{args.model_name}', timeout=self.timeout,
                            proxies=self.proxies)

    def scale_workers(self, args: Namespace):
        """
        Sends a PUT request to ``{management_url}/models/{model_name}`` with the worker settings.

        Note that ``args.timeout`` is forwarded as a query parameter, while the HTTP timeout of the call itself
        is always ``self.timeout``.

        :param args: parsed arguments providing ``model_name`` and, optionally, the worker settings
        :return: the response returned by ``requests``
        """
        payload = copy_args(args, ['min_worker', 'max_worker', 'number_gpu', 'synchronous', 'timeout'])
        return requests.put(f'{self.management_url}/models/{args.model_name}', params=payload, timeout=self.timeout,
                            proxies=self.proxies)

    def translate(self, args: Namespace):
        """
        POSTs a JSON body with the text to ``{prediction_url}/predictions/{model_name}``.

        The ``constraints`` and ``avoid`` keys are only included when the matching argument is non-empty.

        :param args: parsed arguments providing ``model_name``, ``text``, ``constraint`` and ``avoid``
        :return: the response returned by ``requests``
        """
        body = {'text': args.text}
        if args.constraint:
            body['constraints'] = args.constraint
        if args.avoid:
            body['avoid'] = args.avoid
        return requests.post(f'{self.prediction_url}/predictions/{args.model_name}', json=body,
                             timeout=self.timeout, proxies=self.proxies)

    def upload(self, args: Namespace):
        """
        POSTs a file to ``{prediction_url}/predictions/{model_name}`` as the ``file`` upload field.

        :param args: parsed arguments providing ``model_name`` and the path ``file``
        :return: the response returned by ``requests``
        """
        # Keep the handle open only for the duration of the request so it is always closed
        with open(args.file, 'rb') as stream:
            return requests.post(f'{self.prediction_url}/predictions/{args.model_name}',
                                 files={'file': stream}, timeout=self.timeout, proxies=self.proxies)


def main():
    """
    Command-line entry point.

    Builds the argument parser with one sub-command per client operation, loads the client configuration from
    ``--config`` (or ``sockeye-client.yml`` relative to the working directory), dispatches to the selected
    operation and prints the response body. Without a sub-command the usage help is printed instead.

    :raises SystemExit: with status 1 when the configuration is missing or incomplete, or when the HTTP request
        could not be completed
    """
    cli = HttpClient()

    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--config', help='path to the config file')
    parser.add_argument('-v', '--verbose', help='add more output', action='store_true')
    # `request` holds the client method chosen by the sub-command; None means no sub-command was given
    parser.set_defaults(request=None)
    subparsers = parser.add_subparsers(help='available sub-commands')

    deploy_parser = subparsers.add_parser('deploy', help='load a model onto the server')
    deploy_parser.add_argument('url', help='model URL; can be a relative path to a MAR file or directory')
    deploy_parser.add_argument('-m', '--model-name', help='model name')
    deploy_parser.add_argument('-x', '--handler', help='inference handler entry point')
    deploy_parser.add_argument('-r', '--runtime', help='runtime for custom service code')
    deploy_parser.add_argument('-b', '--batch-size', help='inference batch size', type=int)
    deploy_parser.add_argument('-d', '--max-batch-delay', help='max delay for batch aggregation', type=int)
    deploy_parser.add_argument('-i', '--initial-workers', help='number of initial workers to create', type=int,
                               default=1)
    deploy_parser.add_argument('-s', '--synchronous', action='store_true',
                               help='whether or not the request is synchronous')
    deploy_parser.add_argument('-t', '--timeout', dest='response_timeout', help='timeout for inference', type=int)
    deploy_parser.set_defaults(request=cli.deploy_model)

    remove_parser = subparsers.add_parser('remove', help='remove a model from the server')
    remove_parser.add_argument('model_name', help='model name')
    remove_parser.set_defaults(request=cli.remove_model)

    list_parser = subparsers.add_parser('list', help='list the registered models')
    list_parser.add_argument('-l', '--limit', help='maximum number of items to return', type=int)
    list_parser.add_argument('-t', '--next-page-token', help='token to query for next page', type=int)
    list_parser.set_defaults(request=cli.list_models)

    status_parser = subparsers.add_parser('status', help='check the status of a model')
    status_parser.add_argument('model_name', help='model name')
    status_parser.set_defaults(request=cli.get_status)

    scale_parser = subparsers.add_parser('scale', help='scale the number of workers for a model')
    scale_parser.add_argument('model_name', help='model name')
    scale_parser.add_argument('-a', '--min-worker', help='minimum number of workers', type=int)
    scale_parser.add_argument('-b', '--max-worker', help='maximum number of workers', type=int)
    scale_parser.add_argument('-n', '--gpus', dest='number_gpu', help='number of GPU workers to create', type=int)
    scale_parser.add_argument('-s', '--synchronous', action='store_true', help='whether or not the call is synchronous')
    scale_parser.add_argument('-t', '--timeout', help='wait time for a worker to complete all pending requests',
                              type=int)
    scale_parser.set_defaults(request=cli.scale_workers)

    translate_parser = subparsers.add_parser('translate', help='translate text')
    translate_parser.add_argument('model_name', help='model name')
    translate_parser.add_argument('text', help='text to translate')
    translate_parser.add_argument('-c', '--constraint', action='append',
                                  help='a single lexical constraint - multiple constraints can be provided')
    translate_parser.add_argument('-a', '--avoid', action='append',
                                  help='a single negative lexical constraint - multiple constraints can be provided')
    translate_parser.set_defaults(request=cli.translate)

    upload_parser = subparsers.add_parser('upload', help='upload a file for translation')
    upload_parser.add_argument('model_name', help='model name')
    upload_parser.add_argument('file', help='file to upload')
    upload_parser.set_defaults(request=cli.upload)

    args = parser.parse_args()

    # Fall back to the default config file name when --config is absent or empty
    cli.load_config(args.config or 'sockeye-client.yml')

    if not cli.initialized:
        print('Config file missing or incomplete. Use --config to specify a location.')
        # SystemExit instead of the `exit` helper, which only exists when the site module is loaded
        raise SystemExit(1)

    if not args.request:
        parser.print_help()
        return

    try:
        response = args.request(args)
    except requests.RequestException as err:
        # Connection errors, timeouts etc.: report briefly and exit with status 1 instead of a traceback
        raise SystemExit(f'Request failed: {err}') from err

    if args.verbose:
        print(response.url)
        print(response.headers)
    print(response.text)


# Run the command-line interface only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()
