/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.util;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.IOException;
import java.nio.file.FileAlreadyExistsException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.InternalServerException;
import org.pytorch.serve.http.InvalidModelVersionException;
import org.pytorch.serve.http.RequestTimeoutException;
import org.pytorch.serve.http.ServiceUnavailableException;
import org.pytorch.serve.http.StatusResponse;
import org.pytorch.serve.http.messages.DescribeModelResponse;
import org.pytorch.serve.http.messages.ListModelsResponse;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.job.RestJob;
import org.pytorch.serve.snapshot.SnapshotManager;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.ModelVersionedRefs;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.pytorch.serve.wlm.WorkerState;
import org.pytorch.serve.wlm.WorkerThread;

/**
 * Static helper methods backing model-management and inference API operations. Most methods
 * delegate to {@link ModelManager}; state-changing operations also persist a snapshot through
 * {@link SnapshotManager}.
 */
public final class ApiUtils {
    private ApiUtils() {
    }

    /**
     * Returns one page of default model versions, ordered by model name.
     *
     * @param limit maximum number of entries in the page; values greater than 100 or negative
     *     are replaced with 100
     * @param pageToken index (in the sorted name list) of the first entry to return; negative
     *     values are treated as 0
     * @return the page. A next-page token (the index following the last returned entry) is set
     *     only when {@code pageToken + limit} does not exceed the number of models.
     */
    public static ListModelsResponse getModelList(int limit, int pageToken) {
        if (limit > 100 || limit < 0) {
            limit = 100;
        }
        if (pageToken < 0) {
            pageToken = 0;
        }
        Map<String, Model> models = ModelManager.getInstance().getDefaultModels(true);
        List<String> keys = new ArrayList<>(models.keySet());
        Collections.sort(keys);
        ListModelsResponse list = new ListModelsResponse();
        // pageToken has no upper bound here, so pageToken + limit can overflow int. In int
        // arithmetic that produced a negative end index and a negative next-page token.
        long end = (long) pageToken + limit;
        int last;
        if (end > keys.size()) {
            last = keys.size();
        } else {
            last = (int) end;
            list.setNextPageToken(String.valueOf(last));
        }
        for (int i = pageToken; i < last; ++i) {
            String modelName = keys.get(i);
            Model model = models.get(modelName);
            list.addModel(modelName, model.getModelUrl());
        }
        return list;
    }

    /**
     * Builds descriptions for a model.
     *
     * @param modelName name of the model to describe
     * @param modelVersion {@code "all"} to describe every registered version, otherwise the
     *     version passed to {@link ModelManager#getModel(String, String)}
     * @return one description per matching model version
     * @throws ModelNotFoundException if the requested model version cannot be found
     * @throws ModelVersionNotFoundException if propagated from the model lookup
     */
    public static ArrayList<DescribeModelResponse> getModelDescription(String modelName, String modelVersion) throws ModelNotFoundException, ModelVersionNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        ArrayList<DescribeModelResponse> resp = new ArrayList<>();
        if ("all".equals(modelVersion)) {
            for (Map.Entry<String, Model> entry : modelManager.getAllModelVersions(modelName)) {
                resp.add(ApiUtils.createModelResponse(modelManager, modelName, entry.getValue()));
            }
        } else {
            Model model = modelManager.getModel(modelName, modelVersion);
            if (model == null) {
                throw new ModelNotFoundException("Model not found: " + modelName);
            }
            resp.add(ApiUtils.createModelResponse(modelManager, modelName, model));
        }
        return resp;
    }

    /**
     * Makes {@code newModelVersion} the default version of {@code modelName} and saves a
     * snapshot.
     *
     * @return a human-readable confirmation message
     * @throws ModelNotFoundException if propagated from {@link ModelManager#setDefaultVersion}
     * @throws ModelVersionNotFoundException if propagated from
     *     {@link ModelManager#setDefaultVersion}
     */
    public static String setDefault(String modelName, String newModelVersion) throws ModelNotFoundException, ModelVersionNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        modelManager.setDefaultVersion(modelName, newModelVersion);
        // NOTE(review): the message contains typos ("vesion", "succsesfully"). It is kept
        // byte-identical in case clients or tests match on the exact text; confirm before fixing.
        String msg = "Default vesion succsesfully updated for model \"" + modelName + "\" to \"" + newModelVersion + "\"";
        SnapshotManager.getInstance().saveSnapshot();
        return msg;
    }

    /**
     * Validates a register request and delegates to {@link #handleRegister}.
     *
     * <p>A response timeout of {@code -1} is replaced by the configured default. The model is
     * registered as a non-workflow model.
     *
     * @throws BadRequestException if the url is missing or the runtime name is not recognised
     */
    public static StatusResponse registerModel(RegisterModelRequest registerModelRequest) throws ModelException, InternalServerException, ExecutionException, InterruptedException, DownloadArchiveException, WorkerInitializationException {
        String modelUrl = registerModelRequest.getModelUrl();
        if (modelUrl == null) {
            throw new BadRequestException("Parameter url is required.");
        }
        String modelName = registerModelRequest.getModelName();
        String runtime = registerModelRequest.getRuntime();
        String handler = registerModelRequest.getHandler();
        int batchSize = registerModelRequest.getBatchSize();
        int maxBatchDelay = registerModelRequest.getMaxBatchDelay();
        int initialWorkers = registerModelRequest.getInitialWorkers();
        int responseTimeout = registerModelRequest.getResponseTimeout();
        boolean s3SseKms = registerModelRequest.getS3SseKms();
        if (responseTimeout == -1) {
            responseTimeout = ConfigManager.getInstance().getDefaultResponseTimeout();
        }
        Manifest.RuntimeType runtimeType = null;
        if (runtime != null) {
            try {
                runtimeType = Manifest.RuntimeType.fromValue(runtime);
            } catch (IllegalArgumentException e) {
                throw new BadRequestException(e);
            }
        }
        return ApiUtils.handleRegister(modelUrl, modelName, runtimeType, handler, batchSize, maxBatchDelay, responseTimeout, initialWorkers, registerModelRequest.getSynchronous(), false, s3SseKms);
    }

    /**
     * Registers a model archive and then scales its workers.
     *
     * <p>If the archive's model config specifies {@code minWorkers > 0} and
     * {@code maxWorkers >= minWorkers}, those values take precedence over
     * {@code initialWorkers}. If no workers are requested, the method returns 200 immediately;
     * in that case a snapshot is saved unless this is a workflow model. If worker scaling
     * fails, the newly registered model version is unregistered.
     *
     * @throws InternalServerException if the archive cannot be saved or already exists
     */
    public static StatusResponse handleRegister(String modelUrl, String modelName, Manifest.RuntimeType runtimeType, String handler, int batchSize, int maxBatchDelay, int responseTimeout, int initialWorkers, boolean isSync, boolean isWorkflowModel, boolean s3SseKms) throws ModelException, ExecutionException, InterruptedException, DownloadArchiveException, WorkerInitializationException {
        ModelArchive archive;
        ModelManager modelManager = ModelManager.getInstance();
        try {
            archive = modelManager.registerModel(modelUrl, modelName, runtimeType, handler, batchSize, maxBatchDelay, responseTimeout, null, false, isWorkflowModel, s3SseKms);
        } catch (FileAlreadyExistsException e) {
            throw new InternalServerException("Model file already exists " + ArchiveUtils.getFilenameFromUrl(modelUrl), e);
        } catch (IOException | InterruptedException e) {
            throw new InternalServerException("Failed to save model: " + modelUrl, e);
        }
        modelName = archive.getModelName();
        int minWorkers = 0;
        int maxWorkers = 0;
        if (archive.getModelConfig() != null) {
            int marMinWorkers = archive.getModelConfig().getMinWorkers();
            int marMaxWorkers = archive.getModelConfig().getMaxWorkers();
            // Only accept the archive's worker bounds when they form a valid, non-empty range.
            if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) {
                minWorkers = marMinWorkers;
                maxWorkers = marMaxWorkers;
            }
        }
        if (initialWorkers <= 0 && minWorkers == 0) {
            String msg = "Model \"" + modelName + "\" Version: " + archive.getModelVersion() + " registered with 0 initial workers. Use scale workers API to add workers for the model.";
            if (!isWorkflowModel) {
                SnapshotManager.getInstance().saveSnapshot();
            }
            return new StatusResponse(msg, 200);
        }
        minWorkers = minWorkers > 0 ? minWorkers : initialWorkers;
        maxWorkers = maxWorkers > 0 ? maxWorkers : initialWorkers;
        // Roll back the registration if the workers cannot be started.
        return ApiUtils.updateModelWorkers(modelName, archive.getModelVersion(), minWorkers, maxWorkers, isSync, true, f -> {
            modelManager.unregisterModel(archive.getModelName(), archive.getModelVersion());
            return null;
        });
    }

    /**
     * Requests that a model version be scaled to {@code [minWorkers, maxWorkers]} workers.
     *
     * <p>When {@code synchronous} is false, the method returns 202 as soon as the request has
     * been submitted. Otherwise it blocks until the scaling future completes and reports one of:
     * <ul>
     *   <li>200 with a success message;</li>
     *   <li>206 when the future returned 200 but
     *       {@link ModelManager#scaleRequestStatus} returned false;</li>
     *   <li>the future's code, or 500 if the future completed exceptionally, with an error
     *       attached.</li>
     * </ul>
     * {@code onError} (if non-null) is invoked on the failure paths.
     *
     * @throws BadRequestException if {@code maxWorkers < minWorkers}
     * @throws ModelNotFoundException if {@code modelName} has no default model
     */
    public static StatusResponse updateModelWorkers(String modelName, String modelVersion, int minWorkers, int maxWorkers, boolean synchronous, boolean isInit, Function<Void, Void> onError) throws ModelVersionNotFoundException, ModelNotFoundException, ExecutionException, InterruptedException, WorkerInitializationException {
        ModelManager modelManager = ModelManager.getInstance();
        if (maxWorkers < minWorkers) {
            throw new BadRequestException("max_worker cannot be less than min_worker.");
        }
        if (!modelManager.getDefaultModels().containsKey(modelName)) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        CompletableFuture<Integer> future = modelManager.updateModel(modelName, modelVersion, minWorkers, maxWorkers);
        if (!synchronous) {
            return new StatusResponse("Processing worker updates...", 202);
        }
        StatusResponse statusResponse = new StatusResponse();
        CompletableFuture<StatusResponse> statusFuture = future.thenApply(v -> {
            boolean status = modelManager.scaleRequestStatus(modelName, modelVersion);
            int code = v;
            if (200 == code) {
                if (status) {
                    String msg = "Workers scaled to " + minWorkers + " for model: " + modelName;
                    if (modelVersion != null) {
                        msg = msg + ", version: " + modelVersion;
                    }
                    if (isInit) {
                        msg = "Model \"" + modelName + "\" Version: " + modelVersion + " registered with " + minWorkers + " initial workers";
                    }
                    statusResponse.setStatus(msg);
                    statusResponse.setHttpResponseCode(code);
                } else {
                    statusResponse.setStatus("Workers scaling in progress...");
                    statusResponse.setHttpResponseCode(206);
                }
            } else {
                statusResponse.setHttpResponseCode(code);
                String msg = "Failed to start workers for model " + modelName + " version: " + modelVersion;
                statusResponse.setStatus(msg);
                statusResponse.setE(new InternalServerException(msg));
                if (onError != null) {
                    onError.apply(null);
                }
            }
            return statusResponse;
        }).exceptionally(e -> {
            if (onError != null) {
                onError.apply(null);
            }
            statusResponse.setStatus(e.getMessage());
            statusResponse.setHttpResponseCode(500);
            statusResponse.setE(e);
            return statusResponse;
        });
        return statusFuture.get();
    }

    /**
     * Unregisters a model version, translating the status code returned by
     * {@link ModelManager#unregisterModel} into an exception.
     *
     * @throws ModelNotFoundException on 404
     * @throws ModelVersionNotFoundException on 400
     * @throws InternalServerException on 500
     * @throws RequestTimeoutException on 408
     * @throws InvalidModelVersionException on 403 (attempt to remove the default version)
     */
    public static void unregisterModel(String modelName, String modelVersion) throws ModelNotFoundException, ModelVersionNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        int httpResponseStatus = modelManager.unregisterModel(modelName, modelVersion);
        if (httpResponseStatus == HttpResponseStatus.NOT_FOUND.code()) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        if (httpResponseStatus == HttpResponseStatus.BAD_REQUEST.code()) {
            throw new ModelVersionNotFoundException(String.format("Model version: %s does not exist for model: %s", modelVersion, modelName));
        }
        if (httpResponseStatus == HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) {
            throw new InternalServerException("Interrupted while cleaning resources: " + modelName);
        }
        if (httpResponseStatus == HttpResponseStatus.REQUEST_TIMEOUT.code()) {
            throw new RequestTimeoutException("Timed out while cleaning resources: " + modelName);
        }
        if (httpResponseStatus == HttpResponseStatus.FORBIDDEN.code()) {
            throw new InvalidModelVersionException("Cannot remove default version for model " + modelName);
        }
    }

    /** Submits {@code r} through {@link ModelManager#submitTask}. */
    public static void getTorchServeHealth(Runnable r) {
        ModelManager modelManager = ModelManager.getInstance();
        modelManager.submitTask(r);
    }

    /**
     * Summarises worker health across the default version of every model.
     *
     * @return {@code "Healthy"}, {@code "Partial Healthy"} (some, but fewer running workers than
     *     the summed min-worker counts), or {@code "Unhealthy"} (no running workers although at
     *     least one is expected)
     */
    public static String getWorkerStatus() {
        ModelManager modelManager = ModelManager.getInstance();
        String response = "Healthy";
        int numWorking = 0;
        int numScaled = 0;
        for (Map.Entry<String, ModelVersionedRefs> entry : modelManager.getAllModels()) {
            Model defaultModel = entry.getValue().getDefaultModel();
            numScaled += defaultModel.getMinWorkers();
            numWorking += modelManager.getNumRunningWorkers(defaultModel.getModelVersionName());
        }
        if (numWorking > 0 && numWorking < numScaled) {
            response = "Partial Healthy";
        } else if (numWorking == 0 && numScaled > 0) {
            response = "Unhealthy";
        }
        return response;
    }

    /**
     * Returns true if every model's default version has at least its configured minimum number
     * of healthy workers.
     */
    public static boolean isModelHealthy() {
        ModelManager modelManager = ModelManager.getInstance();
        for (Map.Entry<String, ModelVersionedRefs> entry : modelManager.getAllModels()) {
            Model defaultModel = entry.getValue().getDefaultModel();
            int numScaled = defaultModel.getMinWorkers();
            int numHealthy = modelManager.getNumHealthyWorkers(defaultModel.getModelVersionName());
            if (numHealthy < numScaled) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the description of a single model version, including its workers and job-queue
     * status. A worker is reported as running only if it is running and in state
     * {@link WorkerState#WORKER_MODEL_LOADED}.
     */
    private static DescribeModelResponse createModelResponse(ModelManager modelManager, String modelName, Model model) {
        DescribeModelResponse resp = new DescribeModelResponse();
        resp.setModelName(modelName);
        resp.setModelUrl(model.getModelUrl());
        resp.setBatchSize(model.getBatchSize());
        resp.setMaxBatchDelay(model.getMaxBatchDelay());
        resp.setMaxWorkers(model.getMaxWorkers());
        resp.setMinWorkers(model.getMinWorkers());
        resp.setLoadedAtStartup(modelManager.getStartupModels().contains(modelName));
        Manifest manifest = model.getModelArchive().getManifest();
        resp.setModelVersion(manifest.getModel().getModelVersion());
        resp.setRuntime(manifest.getRuntime().getValue());
        List<WorkerThread> workers = modelManager.getWorkers(model.getModelVersionName());
        for (WorkerThread worker : workers) {
            String workerId = worker.getWorkerId();
            long startTime = worker.getStartTime();
            boolean isRunning = worker.isRunning() && worker.getState() == WorkerState.WORKER_MODEL_LOADED;
            int gpuId = worker.getGpuId();
            long memory = worker.getMemory();
            int pid = worker.getPid();
            String gpuUsage = worker.getGpuUsage();
            resp.addWorker(workerId, startTime, isRunning, gpuId, memory, pid, gpuUsage);
        }
        DescribeModelResponse.JobQueueStatus jobQueueStatus = new DescribeModelResponse.JobQueueStatus();
        jobQueueStatus.setRemainingCapacity(model.getJobQueueRemainingCapacity());
        jobQueueStatus.setPendingRequests(model.getPendingRequestsInJobQueue());
        resp.setJobQueueStatus(jobQueueStatus);
        return resp;
    }

    /**
     * Creates a PREDICT {@link RestJob} for the given model version and hands it to the model
     * manager.
     *
     * @return the enqueued job
     * @throws ServiceUnavailableException if {@link ModelManager#addJob} rejects the job
     */
    public static RestJob addRESTInferenceJob(ChannelHandlerContext ctx, String modelName, String version, RequestInput input) throws ModelNotFoundException, ModelVersionNotFoundException {
        RestJob job = new RestJob(ctx, modelName, version, WorkerCommands.PREDICT, input);
        if (!ModelManager.getInstance().addJob(job)) {
            String responseMessage = ApiUtils.getStreamingInferenceErrorResponseMessage(modelName, version);
            throw new ServiceUnavailableException(responseMessage);
        }
        return job;
    }

    /**
     * Builds the error message used when an inference job cannot be queued. The version is
     * included only when non-null.
     */
    public static String getStreamingInferenceErrorResponseMessage(String modelName, String modelVersion) {
        StringBuilder responseMessage = new StringBuilder().append("Model \"").append(modelName);
        if (modelVersion != null) {
            responseMessage.append("\" Version ").append(modelVersion);
        }
        responseMessage.append("\" has no worker to serve inference request. Please use scale workers API to add workers. If this is a sequence inference, please check if it is closed, or expired; or exceeds maxSequenceJobQueueSize");
        return responseMessage.toString();
    }

    /** Builds the error message used when a describe request finds no worker for the model. */
    public static String getDescribeErrorResponseMessage(String modelName) {
        return "Model \"" + modelName + "\" has no worker to serve describe request. Please use scale workers API to add workers.";
    }
}

