/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms.wlm;

import com.amazonaws.ml.mms.util.messages.BaseModelRequest;
import com.amazonaws.ml.mms.util.messages.ModelInferenceRequest;
import com.amazonaws.ml.mms.util.messages.ModelLoadModelRequest;
import com.amazonaws.ml.mms.util.messages.ModelWorkerResponse;
import com.amazonaws.ml.mms.util.messages.Predictions;
import com.amazonaws.ml.mms.util.messages.RequestInput;
import com.amazonaws.ml.mms.wlm.Job;
import com.amazonaws.ml.mms.wlm.Model;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Collects jobs pulled from a {@link Model} into a single batch request and routes the
 * worker's response (or an error) back to each job in that batch.
 *
 * <p>The aggregator keeps the jobs of the batch currently in flight in an insertion-ordered
 * map keyed by job id. The map is reset at the start of every {@link #getRequest(String)}
 * call. The class performs no synchronization of its own.
 */
public class BatchAggregator {
    private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);

    private final Model model;
    /** Jobs of the batch currently in flight, keyed by job id, in arrival order. */
    private final Map<String, Job> jobs;

    /**
     * Creates an aggregator that pulls jobs from the given model.
     *
     * @param model the model whose job queues are polled and whose batch settings are used
     */
    public BatchAggregator(Model model) {
        this.model = model;
        this.jobs = new LinkedHashMap<>();
    }

    /**
     * Builds the next request to send to a worker.
     *
     * <p>The first job is taken with {@code model.nextJob(threadName)}. If it is a control
     * command, a {@link ModelLoadModelRequest} is returned immediately, using the optional
     * {@code "gpu"} string parameter of the payload as the GPU id ({@code -1} when absent).
     * Otherwise up to {@code batchSize - 1} additional jobs are polled from
     * {@code "DATA_QUEUE"}, and the remaining batch-delay budget is reduced by the time
     * already spent after each poll.
     *
     * @param threadName name passed to {@code model.nextJob} to fetch the first job
     * @return a load-model request for a control command, otherwise an inference request
     *     containing the payload of every collected job
     * @throws InterruptedException if the underlying {@code nextJob} call is interrupted
     * @throws NumberFormatException if the {@code "gpu"} parameter is present but not an integer
     */
    public BaseModelRequest getRequest(String threadName) throws InterruptedException {
        this.jobs.clear();

        Job job = this.model.nextJob(threadName);
        if (job.isControlCmd()) {
            RequestInput input = job.getPayload();
            int gpuId = -1;
            String gpu = input.getStringParameter("gpu");
            if (gpu != null) {
                gpuId = Integer.parseInt(gpu);
            }
            return new ModelLoadModelRequest(this.model, gpuId);
        }

        this.jobs.put(job.getJobId(), job);
        logger.trace("get first job: {}", job.getJobId());

        long maxBatchDelay = this.model.getMaxBatchDelay();
        int size = this.model.getBatchSize() - 1;
        long begin = System.currentTimeMillis();
        for (int i = 0; i < size; ++i) {
            job = this.model.nextJob("DATA_QUEUE", maxBatchDelay);
            if (job == null) {
                break;
            }
            this.jobs.put(job.getJobId(), job);

            // Charge the time spent on this poll against the remaining budget BEFORE
            // advancing the reference point; doing it the other way round subtracts zero
            // and lets every poll wait for the full delay.
            long end = System.currentTimeMillis();
            maxBatchDelay -= end - begin;
            begin = end;
            if (maxBatchDelay <= 0L) {
                break;
            }
        }

        logger.trace("sending jobs, size: {}", this.jobs.size());
        ModelInferenceRequest req = new ModelInferenceRequest(this.model.getModelName());
        for (Job j : this.jobs.values()) {
            j.setScheduled();
            req.addRequest(j.getPayload());
        }
        return req;
    }

    /**
     * Delivers a worker response to the jobs of the current batch.
     *
     * <p>On code {@code 200} each prediction is matched to its job by request id and the job
     * is removed from the batch. On any other code every job still in the batch receives an
     * error payload of the form {@code code:<code>,message:<message>} and the batch is emptied.
     *
     * @param message the response received from the worker
     * @throws IllegalStateException if a successful response contains a prediction whose
     *     request id does not belong to the current batch
     */
    public void sendResponse(ModelWorkerResponse message) {
        if (message.getCode() == 200) {
            if (this.jobs.isEmpty()) {
                return;
            }
            for (Predictions prediction : message.getPredictions()) {
                String jobId = prediction.getRequestId();
                Job job = this.jobs.remove(jobId);
                if (job == null) {
                    throw new IllegalStateException("Unexpected job: " + jobId);
                }
                job.response(prediction.getResp(), prediction.getContentType());
            }
        } else {
            // Snapshot and clear first: removing entries while iterating the map's own
            // key set throws ConcurrentModificationException for batches larger than one.
            Job[] pending = this.jobs.values().toArray(new Job[0]);
            this.jobs.clear();

            String err = "code:" + message.getCode() + ",message:" + message.getMessage();
            byte[] body = err.getBytes(StandardCharsets.UTF_8);
            for (Job j : pending) {
                j.response(body, "application/json");
            }
        }
    }

    /**
     * Reports a failure for the request that was being processed.
     *
     * <ul>
     *   <li>For a {@link ModelLoadModelRequest} the failure is only logged.</li>
     *   <li>For any other non-null request (cast to {@link ModelInferenceRequest}) each job
     *       named in the request batch is removed and sent the error; jobs left over in the
     *       batch afterwards are dropped and an error is logged.</li>
     *   <li>For a {@code null} request every job in the current batch is removed; control
     *       commands are sent the error, all other jobs are handed back to the model with
     *       {@code model.addFirst}.</li>
     * </ul>
     *
     * @param message the request that failed, or {@code null} if no request is available
     * @param error the error text to report
     */
    public void sendError(BaseModelRequest message, String error) {
        if (message instanceof ModelLoadModelRequest) {
            logger.warn("Load model failed: {}, error: {}", message.getModelName(), error);
            return;
        }

        if (message != null) {
            ModelInferenceRequest msg = (ModelInferenceRequest) message;
            for (RequestInput req : msg.getRequestBatch()) {
                String requestId = req.getRequestId();
                Job job = this.jobs.remove(requestId);
                if (job == null) {
                    logger.error("Unexpected job: {}", requestId);
                    continue;
                }
                job.sendError(error);
            }
            if (!this.jobs.isEmpty()) {
                this.jobs.clear();
                logger.error("Not all jobs get response.");
            }
        } else {
            // Snapshot and clear first: removing entries while iterating the map's own
            // entry set throws ConcurrentModificationException for batches larger than one.
            Job[] pending = this.jobs.values().toArray(new Job[0]);
            this.jobs.clear();

            // NOTE(review): jobs are handed back in batch order; if addFirst inserts at the
            // head of the queue the batch ends up reversed there - confirm whether the
            // relative order of re-queued jobs matters.
            for (Job job : pending) {
                if (job.isControlCmd()) {
                    job.sendError(error);
                } else {
                    this.model.addFirst(job);
                }
            }
        }
    }
}

