/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.wlm;

import java.util.Map;
import java.util.concurrent.ExecutionException;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.ModelInferenceRequest;
import org.pytorch.serve.util.messages.ModelLoadModelRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.util.messages.Predictions;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.pytorch.serve.wlm.BatchAggregator;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.WorkerState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link BatchAggregator} implementing continuous batching.
 *
 * <p>Jobs that are neither finished nor disconnected stay in {@code jobs} after a response and
 * are re-added to every following inference request with {@code cachedInBackend} set. New
 * inference jobs are polled only for the remaining batch capacity.
 */
public class ContinuousBatching extends BatchAggregator {
    private static final Logger logger = LoggerFactory.getLogger(ContinuousBatching.class);

    /** Prediction header whose value {@code "false"} removes the job from {@code jobs}. */
    private static final String STREAM_NEXT_HEADER = "ts_stream_next";

    public ContinuousBatching(Model model) {
        super(model);
    }

    /**
     * Builds the next request for the worker.
     *
     * <p>New jobs are polled with a quota of {@code model.getBatchSize()} minus the number of jobs
     * already in flight. If the job map then holds a single control command, a
     * {@link ModelLoadModelRequest} is returned for it. Otherwise every job (in-flight and newly
     * polled) is marked scheduled and packed into one {@link ModelInferenceRequest}. That request
     * uses the {@code STREAMPREDICT} command if any job asked for it.
     *
     * @param threadName name of the calling worker thread, forwarded to management-job polling
     * @param state current worker state; inference jobs are polled only when the model is loaded
     * @return a load request for a control command, otherwise an inference request. The inference
     *     request is empty when job tickets are in use and no job was polled; in that case a job
     *     ticket is also decremented.
     * @throws IllegalStateException if a control command is present together with other jobs
     * @throws InterruptedException if interrupted while polling for jobs
     * @throws ExecutionException propagated from job polling
     */
    @Override
    public BaseModelRequest getRequest(String threadName, WorkerState state)
            throws InterruptedException, ExecutionException {
        // In-flight jobs keep their slots; only the free capacity is offered to new jobs.
        int batchQuota = model.getBatchSize() - jobs.size();
        ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());
        pollBatch(threadName, state, batchQuota);

        if (model.isUseJobTicket() && jobs.isEmpty()) {
            model.decNumJobTickets();
            return req;
        }

        // Validate before touching any job, so that a rejected batch leaves no job marked scheduled.
        Job controlJob = findControlCommandJob();
        if (controlJob != null) {
            if (jobs.size() > 1) {
                throw new IllegalStateException(
                        "Received more than 1 control command. Control messages should be processed/retrieved one at a time.");
            }
            return createLoadModelRequest(controlJob);
        }

        for (Job job : jobs.values()) {
            if (job.getCmd() == WorkerCommands.STREAMPREDICT) {
                req.setCommand(WorkerCommands.STREAMPREDICT);
            }
            job.setScheduled();
            req.addRequest(job.getPayload());
        }
        return req;
    }

    /**
     * Dispatches a worker response to the jobs it belongs to.
     *
     * <p>On status 200:
     * <ul>
     *   <li>An empty prediction list while a control command is pending clears all jobs.
     *   <li>Otherwise each prediction is forwarded to its job, and the job is either removed or
     *       kept for the next iteration.
     * </ul>
     * On any other status, every job receives the error and all jobs are cleared.
     *
     * @param message response received from the worker
     * @return always {@code true}
     * @throws IllegalStateException if a prediction references an unknown job, or a job entry is
     *     {@code null} on a non-200 response
     */
    @Override
    public boolean sendResponse(ModelWorkerResponse message) {
        if (message.getCode() == 200) {
            if (message.getPredictions().isEmpty() && findControlCommandJob() != null) {
                cleanJobs();
                return true;
            }
            for (Predictions prediction : message.getPredictions()) {
                deliverPrediction(prediction);
            }
        } else {
            failAllJobs(message);
        }
        return true;
    }

    /**
     * Forwards one successful prediction to its job and decides whether the job stays in
     * {@code jobs}.
     *
     * <p>The job is removed when {@code ts_stream_next} is {@code "false"} or the client
     * connection is closed. Otherwise its payload is flagged as cached in the backend. Responses
     * for clients whose expiry timestamp has passed are dropped; the job is still kept or removed
     * by the same rules.
     */
    private void deliverPrediction(Predictions prediction) {
        String jobId = prediction.getRequestId();
        Job job = jobs.get(jobId);
        if (job == null) {
            throw new IllegalStateException(
                    "Unexpected job in sendResponse() with 200 status code: " + jobId);
        }

        if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
            job.response(
                    prediction.getResp(),
                    prediction.getContentType(),
                    prediction.getStatusCode(),
                    prediction.getReasonPhrase(),
                    prediction.getHeaders());
        } else {
            logger.warn(
                    "Drop response for inference request {} due to client timeout",
                    job.getPayload().getRequestId());
        }

        // NOTE(review): guard in case a prediction arrives without a header map; it is treated
        // like a missing header (the job stays cached), instead of failing with an NPE.
        String streamNext =
                prediction.getHeaders() == null
                        ? null
                        : prediction.getHeaders().get(STREAM_NEXT_HEADER);
        if ("false".equals(streamNext)) {
            jobs.remove(jobId);
        } else if (!job.isOpen()) {
            // Remove by the key the job was looked up under, so the entry is guaranteed to go.
            jobs.remove(jobId);
            logger.info(
                    "Connection to client got closed; Removing job: {}",
                    job.getPayload().getRequestId());
        } else {
            job.getPayload().setCachedInBackend(true);
        }
    }

    /**
     * Sends the worker's error code and message to every job whose client has not timed out,
     * then clears all jobs.
     */
    private void failAllJobs(ModelWorkerResponse message) {
        for (Map.Entry<String, Job> entry : jobs.entrySet()) {
            Job job = entry.getValue();
            if (job == null) {
                throw new IllegalStateException(
                        "Unexpected job in sendResponse() with non 200 status code: "
                                + entry.getKey());
            }
            if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
                job.sendError(message.getCode(), message.getMessage());
            } else {
                logger.warn(
                        "Drop error response for inference request {} due to client timeout",
                        job.getPayload().getRequestId());
            }
        }
        cleanJobs();
    }

    /** Returns the first control-command job in {@code jobs}, or {@code null} if there is none. */
    private Job findControlCommandJob() {
        for (Job job : jobs.values()) {
            if (job.isControlCmd()) {
                return job;
            }
        }
        return null;
    }

    /**
     * Builds a load request for a control command. The job's optional {@code "gpu"} string
     * parameter gives the GPU id; -1 is used when it is absent.
     *
     * @throws NumberFormatException if the {@code "gpu"} parameter is not an integer
     */
    private ModelLoadModelRequest createLoadModelRequest(Job controlJob) {
        String gpu = controlJob.getPayload().getStringParameter("gpu");
        int gpuId = gpu == null ? -1 : Integer.parseInt(gpu);
        return new ModelLoadModelRequest(model, gpuId);
    }

    /**
     * Fills {@code jobs} for the next iteration.
     *
     * <p>A management job is polled only when no job is in flight. The timeout passed is
     * {@code 0L} when the model is loaded and {@code Long.MAX_VALUE} otherwise. Inference jobs are
     * polled, with quota {@code batchSize}, only if no management job was obtained and the model
     * is loaded.
     */
    private void pollBatch(String threadName, WorkerState state, int batchSize)
            throws InterruptedException, ExecutionException {
        boolean pollMgmtJobStatus = false;
        if (jobs.isEmpty()) {
            long timeout = state == WorkerState.WORKER_MODEL_LOADED ? 0L : Long.MAX_VALUE;
            pollMgmtJobStatus = model.pollMgmtJob(threadName, timeout, jobs);
        }
        if (!pollMgmtJobStatus && state == WorkerState.WORKER_MODEL_LOADED) {
            model.pollInferJob(jobs, batchSize);
        }
    }
}

