/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms.http;

import com.amazonaws.ml.mms.archive.Manifest;
import com.amazonaws.ml.mms.archive.ModelArchive;
import com.amazonaws.ml.mms.archive.ModelException;
import com.amazonaws.ml.mms.archive.ModelNotFoundException;
import com.amazonaws.ml.mms.http.BadRequestException;
import com.amazonaws.ml.mms.http.DescribeModelResponse;
import com.amazonaws.ml.mms.http.HttpRequestHandlerChain;
import com.amazonaws.ml.mms.http.InternalServerException;
import com.amazonaws.ml.mms.http.ListModelsResponse;
import com.amazonaws.ml.mms.http.MethodNotAllowedException;
import com.amazonaws.ml.mms.http.ResourceNotFoundException;
import com.amazonaws.ml.mms.http.StatusResponse;
import com.amazonaws.ml.mms.http.messages.RegisterModelRequest;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.JsonUtils;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.amazonaws.ml.mms.wlm.Model;
import com.amazonaws.ml.mms.wlm.ModelManager;
import com.amazonaws.ml.mms.wlm.WorkerThread;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.util.CharsetUtil;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;

public class ManagementRequestHandler
extends HttpRequestHandlerChain {
    public ManagementRequestHandler(Map<String, ModelServerEndpoint> ep) {
        this.endpointMap = ep;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    protected void handleRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelException {
        if (this.isManagementReq(segments)) {
            if (this.endpointMap.getOrDefault(segments[1], null) != null) {
                this.handleCustomEndpoint(ctx, req, segments, decoder);
                return;
            } else {
                if (!"models".equals(segments[1])) {
                    throw new ResourceNotFoundException();
                }
                HttpMethod method = req.method();
                if (segments.length < 3) {
                    if (HttpMethod.GET.equals(method)) {
                        this.handleListModels(ctx, decoder);
                        return;
                    }
                    if (!HttpMethod.POST.equals(method)) throw new MethodNotAllowedException();
                    this.handleRegisterModel(ctx, decoder, req);
                    return;
                }
                if (HttpMethod.GET.equals(method)) {
                    this.handleDescribeModel(ctx, segments[2]);
                    return;
                } else if (HttpMethod.PUT.equals(method)) {
                    this.handleScaleModel(ctx, decoder, segments[2]);
                    return;
                } else {
                    if (!HttpMethod.DELETE.equals(method)) throw new MethodNotAllowedException();
                    this.handleUnregisterModel(ctx, segments[2]);
                }
            }
            return;
        } else {
            this.chain.handleRequest(ctx, req, decoder, segments);
        }
    }

    private boolean isManagementReq(String[] segments) {
        return segments.length == 0 || (segments.length == 2 || segments.length == 3) && segments[1].equals("models") || this.endpointMap.containsKey(segments[1]);
    }

    private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder decoder) {
        int limit = NettyUtils.getIntParameter(decoder, "limit", 100);
        int pageToken = NettyUtils.getIntParameter(decoder, "next_page_token", 0);
        if (limit > 100 || limit < 0) {
            limit = 100;
        }
        if (pageToken < 0) {
            pageToken = 0;
        }
        ModelManager modelManager = ModelManager.getInstance();
        Map<String, Model> models = modelManager.getModels();
        ArrayList<String> keys = new ArrayList<String>(models.keySet());
        Collections.sort(keys);
        ListModelsResponse list = new ListModelsResponse();
        int last = pageToken + limit;
        if (last > keys.size()) {
            last = keys.size();
        } else {
            list.setNextPageToken(String.valueOf(last));
        }
        for (int i = pageToken; i < last; ++i) {
            String modelName = (String)keys.get(i);
            Model model = models.get(modelName);
            list.addModel(modelName, model.getModelUrl());
        }
        NettyUtils.sendJsonResponse(ctx, list);
    }

    private void handleDescribeModel(ChannelHandlerContext ctx, String modelName) throws ModelNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        Model model = modelManager.getModels().get(modelName);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        DescribeModelResponse resp = new DescribeModelResponse();
        resp.setModelName(modelName);
        resp.setModelUrl(model.getModelUrl());
        resp.setBatchSize(model.getBatchSize());
        resp.setMaxBatchDelay(model.getMaxBatchDelay());
        resp.setMaxWorkers(model.getMaxWorkers());
        resp.setMinWorkers(model.getMinWorkers());
        resp.setLoadedAtStartup(modelManager.getStartupModels().contains(modelName));
        Manifest manifest = model.getModelArchive().getManifest();
        Manifest.Engine engine = manifest.getEngine();
        if (engine != null) {
            resp.setEngine(engine.getEngineName());
        }
        resp.setModelVersion(manifest.getModel().getModelVersion());
        resp.setRuntime(manifest.getRuntime().getValue());
        List<WorkerThread> workers = modelManager.getWorkers(modelName);
        for (WorkerThread worker : workers) {
            String workerId = worker.getWorkerId();
            long startTime = worker.getStartTime();
            boolean isRunning = worker.isRunning();
            int gpuId = worker.getGpuId();
            long memory = worker.getMemory();
            resp.addWorker(workerId, startTime, isRunning, gpuId, memory);
        }
        NettyUtils.sendJsonResponse(ctx, resp);
    }

    private void handleRegisterModel(ChannelHandlerContext ctx, QueryStringDecoder decoder, FullHttpRequest req) throws ModelException {
        ModelArchive archive;
        RegisterModelRequest registerModelRequest = this.parseRequest(req, decoder);
        String modelUrl = registerModelRequest.getModelUrl();
        if (modelUrl == null) {
            throw new BadRequestException("Parameter url is required.");
        }
        String modelName = registerModelRequest.getModelName();
        String runtime = registerModelRequest.getRuntime();
        String handler = registerModelRequest.getHandler();
        int batchSize = registerModelRequest.getBatchSize();
        int maxBatchDelay = registerModelRequest.getMaxBatchDelay();
        int initialWorkers = registerModelRequest.getInitialWorkers();
        boolean synchronous = registerModelRequest.getSynchronous();
        int responseTimeout = registerModelRequest.getResponseTimeout();
        if (responseTimeout == -1) {
            responseTimeout = ConfigManager.getInstance().getDefaultResponseTimeout();
        }
        Manifest.RuntimeType runtimeType = null;
        if (runtime != null) {
            try {
                runtimeType = Manifest.RuntimeType.fromValue(runtime);
            }
            catch (IllegalArgumentException e) {
                throw new BadRequestException(e);
            }
        }
        ModelManager modelManager = ModelManager.getInstance();
        try {
            archive = modelManager.registerModel(modelUrl, modelName, runtimeType, handler, batchSize, maxBatchDelay, responseTimeout, null);
        }
        catch (IOException e) {
            throw new InternalServerException("Failed to save model: " + modelUrl, e);
        }
        modelName = archive.getModelName();
        String msg = "Model \"" + modelName + "\" registered";
        if (initialWorkers <= 0) {
            NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
            return;
        }
        this.updateModelWorkers(ctx, modelName, initialWorkers, initialWorkers, synchronous, f -> {
            modelManager.unregisterModel(archive.getModelName());
            return null;
        });
    }

    private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName) throws ModelNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        if (!modelManager.unregisterModel(modelName)) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        String msg = "Model \"" + modelName + "\" unregistered";
        NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
    }

    private void handleScaleModel(ChannelHandlerContext ctx, QueryStringDecoder decoder, String modelName) throws ModelNotFoundException {
        int minWorkers = NettyUtils.getIntParameter(decoder, "min_worker", 1);
        int maxWorkers = NettyUtils.getIntParameter(decoder, "max_worker", minWorkers);
        if (maxWorkers < minWorkers) {
            throw new BadRequestException("max_worker cannot be less than min_worker.");
        }
        boolean synchronous = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", null));
        ModelManager modelManager = ModelManager.getInstance();
        if (!modelManager.getModels().containsKey(modelName)) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        this.updateModelWorkers(ctx, modelName, minWorkers, maxWorkers, synchronous, null);
    }

    private void updateModelWorkers(ChannelHandlerContext ctx, String modelName, int minWorkers, int maxWorkers, boolean synchronous, Function<Void, Void> onError) {
        ModelManager modelManager = ModelManager.getInstance();
        CompletableFuture<HttpResponseStatus> future = modelManager.updateModel(modelName, minWorkers, maxWorkers);
        if (!synchronous) {
            NettyUtils.sendJsonResponse(ctx, new StatusResponse("Processing worker updates..."), HttpResponseStatus.ACCEPTED);
            return;
        }
        ((CompletableFuture)future.thenApply(v -> {
            boolean status = modelManager.scaleRequestStatus(modelName);
            if (HttpResponseStatus.OK.equals(v)) {
                if (status) {
                    NettyUtils.sendJsonResponse(ctx, new StatusResponse("Workers scaled"), v);
                } else {
                    NettyUtils.sendJsonResponse(ctx, new StatusResponse("Workers scaling in progress..."), new HttpResponseStatus(210, "Partial Success"));
                }
            } else {
                NettyUtils.sendError(ctx, v, new InternalServerException("Failed to start workers"));
                if (onError != null) {
                    onError.apply(null);
                }
            }
            return v;
        })).exceptionally(e -> {
            if (onError != null) {
                onError.apply(null);
            }
            NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
            return null;
        });
    }

    private RegisterModelRequest parseRequest(FullHttpRequest req, QueryStringDecoder decoder) {
        CharSequence mime = HttpUtil.getMimeType(req);
        RegisterModelRequest in = HttpHeaderValues.APPLICATION_JSON.contentEqualsIgnoreCase(mime) ? JsonUtils.GSON.fromJson(req.content().toString(CharsetUtil.UTF_8), RegisterModelRequest.class) : new RegisterModelRequest(decoder);
        return in;
    }
}

