/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.ensemble;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.ensemble.Dag;
import org.pytorch.serve.ensemble.NodeOutput;
import org.pytorch.serve.ensemble.WorkflowModel;
import org.pytorch.serve.http.InternalServerException;
import org.pytorch.serve.job.RestJob;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.messages.InputParameter;
import org.pytorch.serve.util.messages.RequestInput;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Walks a workflow {@link Dag} in dependency order, starting from the nodes returned by
 * {@code Dag.getStartNodeNames()} and releasing a child node once its in-degree counter
 * reaches zero.
 *
 * <p>Two modes are supported by {@link #execute(RequestInput, ArrayList)}:
 * <ul>
 *   <li><b>Real execution</b> ({@code topoSortedList == null}): every ready node is submitted
 *       to a short-lived thread pool and invoked through
 *       {@code ApiUtils.addRESTInferenceJob}; the byte output of a node becomes an input
 *       parameter of each of its children.</li>
 *   <li><b>Dry run</b> ({@code topoSortedList != null}): no model is invoked; node names are
 *       appended to {@code topoSortedList} in the order they are visited.</li>
 * </ul>
 *
 * <p>NOTE(review): per-node inputs are kept in an instance-level map that is never cleared,
 * so an instance is presumably meant to serve a single {@code execute} call - confirm
 * against callers.
 */
public class DagExecutor {
    private static final Logger logger = LoggerFactory.getLogger(DagExecutor.class);

    /** Number of worker threads in the pool created for each real (non dry-run) execution. */
    private static final int EXECUTOR_POOL_SIZE = 4;

    private final Dag dag;

    /** Pending request input per node name, filled in as parent nodes complete. */
    private final Map<String, RequestInput> inputRequestMap;

    /**
     * @param dag the workflow graph this executor traverses
     */
    public DagExecutor(Dag dag) {
        this.dag = dag;
        this.inputRequestMap = new ConcurrentHashMap<String, RequestInput>();
    }

    /**
     * Traverses the DAG, either invoking every node's model or only recording the visit order.
     *
     * @param input          the original request; its headers are copied onto every node
     *                       request and its parameters onto the start nodes' requests
     *                       (only read when {@code topoSortedList} is {@code null})
     * @param topoSortedList {@code null} to really execute the workflow; otherwise a list that
     *                       receives the visited node names (dry run, no model is invoked)
     * @return the outputs of the nodes that have no children; in dry-run mode their data is
     *         {@code null}
     * @throws InternalServerException if a node invocation fails or the calling thread is
     *                                 interrupted while waiting for a node to finish
     */
    public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoSortedList) {
        ExecutorService executorService = null;
        ExecutorCompletionService<NodeOutput> completionService = null;
        if (topoSortedList == null) {
            ThreadFactory namedThreadFactory = new ThreadFactoryBuilder().setNameFormat("wf-execute-thread-%d").build();
            executorService = Executors.newFixedThreadPool(EXECUTOR_POOL_SIZE, namedThreadFactory);
            completionService = new ExecutorCompletionService<NodeOutput>(executorService);
        }
        try {
            return this.traverse(input, topoSortedList, completionService);
        } finally {
            // Always release the pool, including when a node failure propagates as an
            // exception; otherwise its worker threads would linger after a failed request.
            if (executorService != null) {
                executorService.shutdown();
            }
        }
    }

    /**
     * Core traversal loop shared by real execution and dry run.
     *
     * @param completionService the completion service nodes are submitted to; only used
     *                          (and only required to be non-null) when
     *                          {@code topoSortedList} is {@code null}
     */
    private ArrayList<NodeOutput> traverse(RequestInput input, ArrayList<String> topoSortedList, ExecutorCompletionService<NodeOutput> completionService) {
        boolean dryRun = topoSortedList != null;
        // NOTE(review): both collections are mutated below; this assumes the Dag getters hand
        // out copies that are private to this call - confirm against Dag.
        Map<String, Integer> inDegreeMap = this.dag.getInDegreeMap();
        Set<String> zeroInDegree = this.dag.getStartNodeNames();
        HashSet<String> executing = new HashSet<String>();
        if (!dryRun) {
            // Every start node receives its own request carrying the caller's headers/parameters.
            for (String startNode : zeroInDegree) {
                RequestInput startInput = new RequestInput(UUID.randomUUID().toString());
                startInput.setHeaders(input.getHeaders());
                startInput.setParameters(input.getParameters());
                this.inputRequestMap.put(startNode, startInput);
            }
        }
        ArrayList<NodeOutput> leafOutputs = new ArrayList<NodeOutput>();
        while (!zeroInDegree.isEmpty()) {
            // Nodes stay in zeroInDegree until their output has been processed, so filter out
            // the ones that were already submitted in an earlier iteration.
            HashSet<String> readyToExecute = new HashSet<String>(zeroInDegree);
            readyToExecute.removeAll(executing);
            executing.addAll(readyToExecute);
            ArrayList<NodeOutput> outputs = new ArrayList<NodeOutput>();
            if (dryRun) {
                for (String name : readyToExecute) {
                    outputs.add(new NodeOutput(name, null));
                }
            } else {
                for (String name : readyToExecute) {
                    completionService.submit(() -> this.invokeModel(name, this.dag.getNodes().get(name).getWorkflowModel(), this.inputRequestMap.get(name), 0));
                }
                // Exactly one finished node is consumed per iteration; the others are picked
                // up by later iterations because they are still members of zeroInDegree.
                outputs.add(this.awaitNextOutput(completionService));
            }
            for (NodeOutput output : outputs) {
                String nodeName = output.getNodeName();
                executing.remove(nodeName);
                zeroInDegree.remove(nodeName);
                if (dryRun) {
                    topoSortedList.add(nodeName);
                }
                Set<String> childNodes = this.dag.getDagMap().get(nodeName).get("outDegree");
                if (childNodes.isEmpty()) {
                    leafOutputs.add(output);
                    continue;
                }
                for (String childName : childNodes) {
                    if (!dryRun) {
                        byte[] response = (byte[]) output.getData();
                        RequestInput childInput = this.inputRequestMap.get(childName);
                        if (childInput == null) {
                            ArrayList<InputParameter> params = new ArrayList<InputParameter>();
                            childInput = new RequestInput(UUID.randomUUID().toString());
                            // A child with a single pending parent gets the payload as "body";
                            // otherwise each parent's payload is keyed by the parent's name.
                            if (inDegreeMap.get(childName) == 1) {
                                params.add(new InputParameter("body", response));
                            } else {
                                params.add(new InputParameter(nodeName, response));
                            }
                            childInput.setParameters(params);
                            childInput.setHeaders(input.getHeaders());
                        } else {
                            childInput.addParameter(new InputParameter(nodeName, response));
                        }
                        this.inputRequestMap.put(childName, childInput);
                    }
                    inDegreeMap.replace(childName, inDegreeMap.get(childName) - 1);
                    if (inDegreeMap.get(childName) == 0) {
                        zeroInDegree.add(childName);
                    }
                }
            }
        }
        return leafOutputs;
    }

    /**
     * Blocks until the next submitted node finishes and returns its output.
     *
     * @throws InternalServerException if the node's task failed, or if the calling thread was
     *                                 interrupted while waiting (the interrupt flag is restored)
     */
    private NodeOutput awaitNextOutput(ExecutorCompletionService<NodeOutput> completionService) {
        try {
            // take() blocks until a task completes and never returns null.
            Future<NodeOutput> finished = completionService.take();
            return finished.get();
        } catch (InterruptedException e) {
            // Preserve the interruption for code further up the stack.
            Thread.currentThread().interrupt();
            logger.error(e.getMessage(), e);
            throw new InternalServerException(DagExecutor.lastMessageSegment(e));
        } catch (ExecutionException e) {
            logger.error(e.getMessage(), e);
            throw new InternalServerException(DagExecutor.lastMessageSegment(e));
        }
    }

    /**
     * Returns the text after the last {@code ':'} of the throwable's message.
     * Falls back to {@code toString()} when there is no message, and to the whole message
     * when splitting yields nothing (a message made only of colons).
     */
    private static String lastMessageSegment(Throwable e) {
        String message = e.getMessage();
        if (message == null) {
            return e.toString();
        }
        String[] segments = message.split(":");
        if (segments.length == 0) {
            return message;
        }
        return segments[segments.length - 1];
    }

    /**
     * Submits one inference job for the node's model and waits for its response, retrying
     * while {@code retryAttempt} is below {@code workflowModel.getRetryAttempts()}.
     *
     * @param nodeName      name of the DAG node, used for logging and in the returned output
     * @param workflowModel model backing the node; supplies name, timeout and retry limit
     * @param input         request to send to the model
     * @param retryAttempt  zero-based index of the current attempt
     * @return the node name paired with the raw response bytes
     * @throws ModelNotFoundException        rethrown unchanged after logging
     * @throws ModelVersionNotFoundException rethrown unchanged after logging
     * @throws InternalServerException       if the retries are exhausted or the thread is
     *                                       interrupted while waiting for the response
     */
    private NodeOutput invokeModel(String nodeName, WorkflowModel workflowModel, RequestInput input, int retryAttempt) throws ModelNotFoundException, ModelVersionNotFoundException, ExecutionException, InterruptedException {
        try {
            logger.info(String.format("Invoking -  %s for attempt %d", nodeName, retryAttempt));
            CompletableFuture<byte[]> respFuture = new CompletableFuture<byte[]>();
            RestJob job = ApiUtils.addRESTInferenceJob(null, workflowModel.getName(), null, input);
            job.setResponsePromise(respFuture);
            byte[] resp = respFuture.get(workflowModel.getTimeOutMs(), TimeUnit.MILLISECONDS);
            return new NodeOutput(nodeName, resp);
        } catch (InterruptedException e) {
            // An interrupt is a request to stop: restore the flag and fail without retrying.
            Thread.currentThread().interrupt();
            logger.error(nodeName + " : " + e.getMessage(), e);
            throw new InternalServerException(String.format("Failed to execute workflow Node after %d attempts : Error executing %s", retryAttempt, nodeName));
        } catch (ExecutionException | TimeoutException e) {
            logger.error(e.getMessage(), e);
            if (retryAttempt < workflowModel.getRetryAttempts()) {
                logger.error(String.format("Timed out while executing %s for attempt %d", nodeName, retryAttempt));
                return this.invokeModel(nodeName, workflowModel, input, retryAttempt + 1);
            }
            logger.error(nodeName + " : " + e.getMessage());
            throw new InternalServerException(String.format("Failed to execute workflow Node after %d attempts : Error executing %s", retryAttempt, nodeName));
        } catch (ModelNotFoundException e) {
            logger.error("Model not found.");
            logger.error(e.getMessage());
            throw e;
        } catch (ModelVersionNotFoundException e) {
            logger.error("Model version not found.");
            logger.error(e.getMessage());
            throw e;
        }
    }
}

