/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.workflow;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.archive.workflow.InvalidWorkflowException;
import org.pytorch.serve.archive.workflow.WorkflowArchive;
import org.pytorch.serve.archive.workflow.WorkflowException;
import org.pytorch.serve.archive.workflow.WorkflowNotFoundException;
import org.pytorch.serve.ensemble.DagExecutor;
import org.pytorch.serve.ensemble.InvalidDAGException;
import org.pytorch.serve.ensemble.Node;
import org.pytorch.serve.ensemble.NodeOutput;
import org.pytorch.serve.ensemble.WorkFlow;
import org.pytorch.serve.ensemble.WorkflowModel;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.ConflictStatusException;
import org.pytorch.serve.http.InternalServerException;
import org.pytorch.serve.http.StatusResponse;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.workflow.messages.ModelRegistrationResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide registry of workflows.
 *
 * <p>Registers a workflow archive together with every model referenced by its DAG nodes, rolls
 * back partially registered workflows, unregisters workflows, and runs inference requests through
 * a workflow's DAG, writing the result back to the Netty channel.
 *
 * <p>{@link #init(ConfigManager)} must be called before {@link #getInstance()}; until then
 * {@code getInstance()} returns {@code null}.
 */
public final class WorkflowManager {
    private static final Logger logger = LoggerFactory.getLogger(WorkflowManager.class);
    /** Size of the short-lived pools used to (un)register a workflow's node models in parallel. */
    private static final int NODE_REGISTRATION_THREADS = 4;
    /** Error text used when inference yields no output or the failure carries no usable message. */
    private static final String INFERENCE_FAILED_MESSAGE = "Workflow inference request failed!";
    private static WorkflowManager workflowManager;

    private final ThreadFactory namedThreadFactory =
            new ThreadFactoryBuilder().setNameFormat("wf-manager-thread-%d").build();
    /** Runs the response-writing stage of {@link #predict}. */
    private final ExecutorService inferenceExecutorService =
            Executors.newFixedThreadPool(
                    Runtime.getRuntime().availableProcessors(), this.namedThreadFactory);
    private final ConfigManager configManager;
    private final ConcurrentHashMap<String, WorkFlow> workflowMap;

    private WorkflowManager(ConfigManager configManager) {
        this.configManager = configManager;
        this.workflowMap = new ConcurrentHashMap<>();
    }

    /**
     * Creates the singleton instance.
     *
     * <p>Synchronized so that the write pairs with the synchronized {@link #getInstance()} and the
     * new instance is safely published to other threads.
     *
     * @param configManager server configuration used for workflow store / allowed URLs
     */
    public static synchronized void init(ConfigManager configManager) {
        workflowManager = new WorkflowManager(configManager);
    }

    /** Returns the singleton created by {@link #init(ConfigManager)}, or {@code null} if not yet initialized. */
    public static synchronized WorkflowManager getInstance() {
        return workflowManager;
    }

    /**
     * Downloads a workflow archive into the workflow store, optionally renames it, and validates it.
     *
     * @param workflowName optional override for the workflow name; ignored when null or empty
     * @param url location of the workflow archive
     * @param s3SseKmsEnabled whether to download from S3 with SSE-KMS
     */
    private WorkflowArchive createWorkflowArchive(
            String workflowName, String url, boolean s3SseKmsEnabled)
            throws DownloadArchiveException, IOException, WorkflowException {
        WorkflowArchive archive =
                WorkflowArchive.downloadWorkflow(
                        this.configManager.getAllowedUrls(),
                        this.configManager.getWorkflowStore(),
                        url,
                        s3SseKmsEnabled);
        if (workflowName != null && !workflowName.isEmpty()) {
            archive.getManifest().getWorkflow().setWorkflowName(workflowName);
        }
        archive.validate();
        return archive;
    }

    private WorkFlow createWorkflow(WorkflowArchive archive)
            throws IOException, InvalidDAGException, InvalidWorkflowException {
        return new WorkFlow(archive);
    }

    /** Same as {@link #registerWorkflow(String, String, int, boolean, boolean)} without SSE-KMS. */
    public StatusResponse registerWorkflow(
            String workflowName, String url, int responseTimeout, boolean synchronous)
            throws WorkflowException {
        return this.registerWorkflow(workflowName, url, responseTimeout, synchronous, false);
    }

    /**
     * Downloads the workflow archive at {@code url}, registers every model referenced by its DAG
     * nodes in parallel and, if all succeed, records the workflow. If any node fails, the nodes
     * that did register are rolled back and a 500 status is returned.
     *
     * @param workflowName optional workflow name override
     * @param url workflow archive location (required)
     * @param responseTimeout response timeout passed to each model registration
     * @param synchronous whether model registration waits for workers to scale
     * @param s3SseKms whether to download the archive from S3 using SSE-KMS
     * @return status with code 200 on success, 400 for download/DAG errors, 500 otherwise
     * @throws BadRequestException if {@code url} is null
     * @throws ConflictStatusException if a workflow with the resolved name is already registered
     * @throws WorkflowException if the archive is invalid
     */
    public StatusResponse registerWorkflow(
            String workflowName,
            String url,
            int responseTimeout,
            boolean synchronous,
            boolean s3SseKms)
            throws WorkflowException {
        if (url == null) {
            throw new BadRequestException("Parameter url is required.");
        }
        StatusResponse status = new StatusResponse();
        ExecutorService executorService = Executors.newFixedThreadPool(NODE_REGISTRATION_THREADS);
        try {
            // Previously s3SseKms was silently dropped here (always downloaded with false).
            WorkflowArchive archive = this.createWorkflowArchive(workflowName, url, s3SseKms);
            WorkFlow workflow = this.createWorkflow(archive);
            String registeredName = workflow.getWorkflowArchive().getWorkflowName();
            // NOTE(review): check-then-act; two concurrent registrations of the same name can both
            // pass this check. The final putIfAbsent keeps only the first workflow object.
            if (this.workflowMap.get(registeredName) != null) {
                throw new ConflictStatusException(
                        "Workflow " + registeredName + " is already registered.");
            }
            ArrayList<String> failedMessages = new ArrayList<>();
            ArrayList<String> successNodes = new ArrayList<>();
            this.registerNodes(
                    workflow,
                    responseTimeout,
                    synchronous,
                    executorService,
                    failedMessages,
                    successNodes);
            if (failedMessages.isEmpty()) {
                status.setHttpResponseCode(200);
                status.setStatus(
                        String.format(
                                "Workflow %s has been registered and scaled successfully.",
                                registeredName));
                this.workflowMap.putIfAbsent(registeredName, workflow);
            } else {
                this.rollbackFailedRegistration(
                        status, registeredName, workflow, successNodes, failedMessages);
            }
        } catch (DownloadArchiveException e) {
            status.setHttpResponseCode(400);
            status.setStatus("Failed to download workflow archive file");
            status.setE(e);
        } catch (InvalidDAGException e) {
            status.setHttpResponseCode(400);
            status.setStatus("Invalid workflow specification");
            status.setE(e);
        } catch (IOException | InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            status.setHttpResponseCode(500);
            status.setStatus("Failed to register workflow.");
            status.setE(e);
        } finally {
            executorService.shutdown();
        }
        return status;
    }

    /**
     * Registers every DAG node model of {@code workflow} on {@code executorService} and waits for
     * all of them. The names of nodes that registered go to {@code successNodes}; a message for
     * each failed node goes to {@code failedMessages}.
     */
    private void registerNodes(
            WorkFlow workflow,
            int responseTimeout,
            boolean synchronous,
            ExecutorService executorService,
            ArrayList<String> failedMessages,
            ArrayList<String> successNodes)
            throws InterruptedException, ExecutionException {
        ExecutorCompletionService<ModelRegistrationResult> completionService =
                new ExecutorCompletionService<>(executorService);
        int submitted = 0;
        for (Node node : workflow.getDag().getNodes().values()) {
            WorkflowModel wfm = node.getWorkflowModel();
            completionService.submit(
                    () -> this.registerModelWrapper(wfm, responseTimeout, synchronous));
            ++submitted;
        }
        for (int i = 0; i < submitted; ++i) {
            ModelRegistrationResult result = completionService.take().get();
            StatusResponse response = result.getResponse();
            if (response.getHttpResponseCode() == 200) {
                successNodes.add(result.getModelName());
            } else {
                failedMessages.add(
                        response.getStatus() == null
                                ? "Failed to register the model "
                                        + result.getModelName()
                                        + ". Check error logs."
                                : response.getStatus());
            }
        }
    }

    /**
     * Undoes a partially successful registration and fills {@code status} with a 500 describing
     * every failure, including any failure of the rollback itself.
     */
    private void rollbackFailedRegistration(
            StatusResponse status,
            String registeredName,
            WorkFlow workflow,
            ArrayList<String> successNodes,
            ArrayList<String> failedMessages) {
        try {
            this.removeArtifacts(registeredName, workflow, successNodes);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            failedMessages.add(
                    "Error while doing rollback of failed workflow. Details: " + e.getMessage());
        }
        String message =
                String.format(
                        "Workflow %s has failed to register. Failures: %s",
                        registeredName, failedMessages);
        status.setHttpResponseCode(500);
        status.setStatus(message);
        status.setE(new WorkflowException(message));
    }

    /**
     * Registers one workflow node model, converting any exception into a 500 status instead of
     * propagating it.
     *
     * @return the node name paired with its registration status
     */
    public ModelRegistrationResult registerModelWrapper(
            WorkflowModel wfm, int responseTimeout, boolean synchronous) {
        StatusResponse status = new StatusResponse();
        try {
            status =
                    ApiUtils.handleRegister(
                            wfm.getUrl(),
                            wfm.getName(),
                            null,
                            wfm.getHandler(),
                            wfm.getBatchSize(),
                            wfm.getMaxBatchDelay(),
                            responseTimeout,
                            wfm.getMaxWorkers(),
                            synchronous,
                            true,
                            false);
        } catch (Exception e) {
            status.setHttpResponseCode(500);
            String msg = e.getMessage() == null ? "Check error logs." : e.getMessage();
            status.setStatus(
                    String.format(
                            "Workflow Node %s failed to register. Details: %s",
                            wfm.getName(), msg));
            status.setE(e);
            logger.error("Model '{}' failed to register.", wfm.getName(), e);
        }
        return new ModelRegistrationResult(wfm.getName(), status);
    }

    /** Returns the live (mutable, concurrent) map of registered workflows keyed by name. */
    public ConcurrentHashMap<String, WorkFlow> getWorkflows() {
        return this.workflowMap;
    }

    /**
     * Removes a registered workflow, deletes its archive and unregisters its node models.
     *
     * @param successNodes nodes known to be registered. Not-found errors for other nodes are
     *     ignored; when null, every not-found error is treated as a failure.
     * @throws WorkflowNotFoundException if no workflow with that name is registered
     */
    public void unregisterWorkflow(String workflowName, ArrayList<String> successNodes)
            throws WorkflowNotFoundException, InterruptedException, ExecutionException {
        // Single atomic remove, so two concurrent unregisters cannot both proceed.
        WorkFlow workflow = this.workflowMap.remove(workflowName);
        if (workflow == null) {
            throw new WorkflowNotFoundException("Workflow not found: " + workflowName);
        }
        this.removeArtifacts(workflowName, workflow, successNodes);
    }

    /** Deletes the workflow archive from the store and unregisters all of its node models. */
    public void removeArtifacts(
            String workflowName, WorkFlow workflow, ArrayList<String> successNodes)
            throws ExecutionException, InterruptedException {
        WorkflowArchive.removeWorkflow(
                this.configManager.getWorkflowStore(), workflow.getWorkflowArchive().getUrl());
        Map<String, Node> nodes = workflow.getDag().getNodes();
        this.unregisterModels(workflowName, nodes, successNodes);
    }

    /**
     * Unregisters the model of every node in parallel.
     *
     * @param workflowName used only in the error message
     * @param successNodes see {@link #unregisterWorkflow(String, ArrayList)}
     * @throws InternalServerException listing every node that failed to unregister
     */
    public void unregisterModels(
            String workflowName, Map<String, Node> nodes, ArrayList<String> successNodes)
            throws InterruptedException, ExecutionException {
        ExecutorService executorService = Executors.newFixedThreadPool(NODE_REGISTRATION_THREADS);
        ArrayList<String> failedMessages = new ArrayList<>();
        try {
            ExecutorCompletionService<ModelRegistrationResult> completionService =
                    new ExecutorCompletionService<>(executorService);
            int submitted = 0;
            for (Node node : nodes.values()) {
                WorkflowModel wfm = node.getWorkflowModel();
                completionService.submit(() -> unregisterNodeModel(wfm, successNodes));
                ++submitted;
            }
            for (int i = 0; i < submitted; ++i) {
                ModelRegistrationResult result = completionService.take().get();
                if (result.getResponse().getHttpResponseCode() != 200) {
                    failedMessages.add(result.getResponse().getStatus());
                }
            }
        } finally {
            // Previously skipped on failure, leaking the pool's threads.
            executorService.shutdown();
        }
        if (!failedMessages.isEmpty()) {
            // List#toString renders the messages. The old toArray().toString() printed an array identity.
            throw new InternalServerException(
                    "Error while unregistering the workflow "
                            + workflowName
                            + ". Details: "
                            + failedMessages);
        }
    }

    /**
     * Unregisters a single node model and converts the outcome into a status. A not-found error
     * counts as success when the node is absent from a non-null {@code successNodes}.
     */
    private static ModelRegistrationResult unregisterNodeModel(
            WorkflowModel wfm, ArrayList<String> successNodes) {
        StatusResponse status = new StatusResponse();
        try {
            ApiUtils.unregisterModel(wfm.getName(), null);
            status.setHttpResponseCode(200);
            status.setStatus(String.format("Unregistered workflow node %s", wfm.getName()));
        } catch (ModelNotFoundException | ModelVersionNotFoundException e) {
            if (successNodes == null || successNodes.contains(wfm.getName())) {
                status.setHttpResponseCode(500);
                status.setStatus(
                        String.format("Error while unregistering workflow node %s", wfm.getName()));
                status.setE(e);
                logger.error("Model '{}' failed to unregister.", wfm.getName(), e);
            } else {
                // The node never registered (e.g. during rollback), so "not found" is expected.
                status.setHttpResponseCode(200);
                status.setStatus(
                        String.format(
                                "Error while unregistering workflow node %s but can be ignored.",
                                wfm.getName()));
                status.setE(e);
            }
        } catch (Exception e) {
            status.setHttpResponseCode(500);
            status.setStatus(
                    String.format("Error while unregistering workflow node %s", wfm.getName()));
            status.setE(e);
            logger.error("Model '{}' failed to unregister.", wfm.getName(), e);
        }
        return new ModelRegistrationResult(wfm.getName(), status);
    }

    /** Returns the registered workflow with that name, or {@code null}. */
    public WorkFlow getWorkflow(String workflowName) {
        return this.workflowMap.get(workflowName);
    }

    /**
     * Asynchronously runs {@code input} through the named workflow's DAG and writes the result
     * (or a 500 error) to {@code ctx}.
     *
     * @throws WorkflowNotFoundException if no workflow with that name is registered
     */
    public void predict(ChannelHandlerContext ctx, String wfName, RequestInput input)
            throws WorkflowNotFoundException {
        WorkFlow wf = this.workflowMap.get(wfName);
        if (wf == null) {
            throw new WorkflowNotFoundException("Workflow not found: " + wfName);
        }
        DagExecutor dagExecutor = new DagExecutor(wf.getDag());
        CompletableFuture<ArrayList<NodeOutput>> predictionFuture =
                CompletableFuture.supplyAsync(() -> dagExecutor.execute(input, null));
        predictionFuture
                .thenAcceptAsync(
                        predictions -> sendPredictions(ctx, predictions),
                        this.inferenceExecutorService)
                .exceptionally(
                        ex -> {
                            logger.error("Inference for workflow '{}' failed.", wfName, ex);
                            NettyUtils.sendError(
                                    ctx,
                                    HttpResponseStatus.INTERNAL_SERVER_ERROR,
                                    new InternalServerException(extractErrorMessage(ex)));
                            return null;
                        });
    }

    /**
     * Writes DAG outputs to the channel. A single output is sent verbatim as JSON. Multiple
     * outputs are combined into one JSON object keyed by node name; each output must be a JSON
     * object encoded as UTF-8.
     *
     * @throws InternalServerException if there is no output
     */
    private static void sendPredictions(
            ChannelHandlerContext ctx, ArrayList<NodeOutput> predictions) {
        if (predictions == null || predictions.isEmpty()) {
            throw new InternalServerException(INFERENCE_FAILED_MESSAGE);
        }
        if (predictions.size() == 1) {
            DefaultFullHttpResponse resp =
                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, true);
            resp.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
            resp.content().writeBytes((byte[]) predictions.get(0).getData());
            NettyUtils.sendHttpResponse(ctx, resp, true);
            return;
        }
        JsonObject result = new JsonObject();
        for (NodeOutput prediction : predictions) {
            String val = new String((byte[]) prediction.getData(), StandardCharsets.UTF_8);
            result.add(prediction.getNodeName(), JsonParser.parseString(val).getAsJsonObject());
        }
        NettyUtils.sendJsonResponse(ctx, result);
    }

    /**
     * Returns the text after the last ':' in the failure message, which strips wrapper class-name
     * prefixes such as {@code "java.util.concurrent.CompletionException: ...: msg"}. Falls back to
     * a generic message instead of throwing when the message is null or has no usable segment.
     */
    private static String extractErrorMessage(Throwable ex) {
        String message = ex.getMessage();
        if (message == null) {
            return INFERENCE_FAILED_MESSAGE;
        }
        String[] parts = message.split(":");
        return parts.length == 0 ? INFERENCE_FAILED_MESSAGE : parts[parts.length - 1].strip();
    }
}

