/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.wlm;

import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.job.JobGroup;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.wlm.BatchAggregator;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.WorkerState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link BatchAggregator} for sequence (job-group) inference.
 *
 * <p>A background {@link EventDispatcher} consumes events from {@code eventJobGroupIds}.
 *
 * <ul>
 *   <li>An empty string means "pull new job groups from the model's pending job-group queue".
 *   <li>Any other value is a job group id, meaning "pull the next job of that group into
 *       {@code jobsQueue}".
 * </ul>
 *
 * <p>A group's id is given back to the dispatcher only after the batch that contained its job has
 * been answered, errored, or cleaned. So the next job of a group is polled only after its previous
 * job has left the worker.
 */
public class SequenceBatchAggregator extends BatchAggregator {
    private static final Logger logger = LoggerFactory.getLogger(SequenceBatchAggregator.class);

    /** Runs the (possibly blocking) polls issued by the dispatcher; sized batchSize + 1. */
    private final ExecutorService pollExecutors;

    /** Dispatcher events: "" = poll new job groups, otherwise a job group id to poll a job from. */
    private final LinkedBlockingDeque<String> eventJobGroupIds;

    /** Jobs already pulled out of their groups, waiting to be batched by {@link #pollInferJob()}. */
    private final LinkedBlockingDeque<Job> jobsQueue;

    private final Thread eventDispatcher;

    /** Ensures at most one {@link #pollJobGroup()} is in progress at a time. */
    private final AtomicBoolean isPollJobGroup = new AtomicBoolean(false);

    /** Group ids of the jobs in the in-flight batch; released to the dispatcher once it completes. */
    private final LinkedList<String> currentJobGroupIds = new LinkedList<>();

    /** Upper bound on queued jobs used when deciding how many extra job groups to take at once. */
    private final int localCapacity;

    /**
     * Creates the aggregator and immediately starts its event dispatcher thread.
     *
     * @param model the model whose pending job groups this aggregator serves
     */
    public SequenceBatchAggregator(Model model) {
        super(model);
        this.pollExecutors = Executors.newFixedThreadPool(model.getBatchSize() + 1);
        this.jobsQueue = new LinkedBlockingDeque<>();
        // Guard against minWorkers == 0, which would otherwise throw ArithmeticException here.
        this.localCapacity = model.getMaxNumSequence() / Math.max(1, model.getMinWorkers());
        this.eventJobGroupIds = new LinkedBlockingDeque<>();
        // Seed one "poll new job groups" event so the dispatcher's first action is a group poll.
        this.eventJobGroupIds.add("");
        this.eventDispatcher = new Thread(new EventDispatcher());
        this.eventDispatcher.start();
    }

    /**
     * Starts the event dispatcher if it has not been started yet.
     *
     * <p>The constructor already starts it, and {@link Thread#start()} may be called only once.
     * This method is therefore a no-op unless the thread is still in state {@code NEW}. It used to
     * throw {@link IllegalThreadStateException} unconditionally.
     */
    public void startEventDispatcher() {
        if (eventDispatcher.getState() == Thread.State.NEW) {
            eventDispatcher.start();
        } else {
            logger.debug("EventDispatcher already started (state: {})", eventDispatcher.getState());
        }
    }

    /** Interrupts the event dispatcher, which makes its loop exit. */
    public void stopEventDispatcher() {
        eventDispatcher.interrupt();
    }

    /**
     * Blocks until a pending job group is available and registers it with the dispatcher. It may
     * also take a fair share of further pending groups, bounded by the free local capacity.
     *
     * <p>Only one call is active at a time; concurrent calls return immediately. The guard flag is
     * reset in {@code finally}, so an interrupt or a runtime failure cannot leave group polling
     * disabled forever.
     *
     * @throws InterruptedException if interrupted while waiting for a pending job group
     */
    private void pollJobGroup() throws InterruptedException {
        if (isPollJobGroup.getAndSet(true)) {
            return;
        }
        try {
            String jobGroupId =
                    model.getPendingJobGroups().poll(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
            if (jobGroupId == null) {
                return;
            }
            addJobGroup(jobGroupId);

            // Take this worker's share of the remaining pending groups, limited by free capacity.
            int perWorkerShare =
                    model.getPendingJobGroups().size() / Math.max(1, model.getMaxWorkers());
            int quota = Math.min(localCapacity - jobsQueue.size(), perWorkerShare);
            if (quota > 0) {
                LinkedHashSet<String> extraGroups = new LinkedHashSet<>();
                model.getPendingJobGroups().drainTo(extraGroups, quota);
                for (String extraGroupId : extraGroups) {
                    addJobGroup(extraGroupId);
                }
            }
        } finally {
            isPollJobGroup.set(false);
        }
    }

    /**
     * Fills {@code jobs} with up to batchSize jobs taken from {@code jobsQueue}. It then records the
     * group id of every grouped job so those groups can be released when the batch completes.
     *
     * @throws InterruptedException if interrupted while polling
     */
    private void pollInferJob() throws InterruptedException {
        model.pollInferJob(jobs, model.getBatchSize(), jobsQueue);
        for (Job job : jobs.values()) {
            if (job.getGroupId() != null) {
                currentJobGroupIds.add(job.getGroupId());
            }
        }
    }

    /**
     * Polls the next batch for a worker.
     *
     * <p>A management job is polled first, but only when no jobs are pending. That poll does not
     * block when the model is loaded; otherwise it waits indefinitely. Inference jobs are polled
     * only when no management job was obtained and the worker has the model loaded.
     */
    @Override
    public void pollBatch(String threadName, WorkerState state)
            throws InterruptedException, ExecutionException {
        boolean pollMgmtJobStatus = false;
        if (jobs.isEmpty()) {
            long timeout = state == WorkerState.WORKER_MODEL_LOADED ? 0L : Long.MAX_VALUE;
            pollMgmtJobStatus = model.pollMgmtJob(threadName, timeout, jobs);
        }
        if (!pollMgmtJobStatus && state == WorkerState.WORKER_MODEL_LOADED) {
            pollInferJob();
        }
    }

    /** Removes the given job group from the model; a null id is ignored. */
    private void cleanJobGroup(String jobGroupId) {
        logger.debug("Clean jobGroup: {}", jobGroupId);
        if (jobGroupId != null) {
            model.removeJobGroup(jobGroupId);
        }
    }

    /**
     * Handles a job whose processing failed.
     *
     * <p>Ungrouped jobs are put back at the front of the model's queue. Grouped (sequence) jobs are
     * not retried; they are only logged.
     */
    @Override
    public void handleErrorJob(Job job) {
        if (job.getGroupId() == null) {
            model.addFirst(job);
        } else {
            logger.error(
                    "Failed to process requestId: {}, sequenceId: {}",
                    job.getPayload().getRequestId(),
                    job.getGroupId());
        }
    }

    /**
     * Forwards the response. Once the batch is reported done, releases its job groups back to
     * the dispatcher.
     */
    @Override
    public boolean sendResponse(ModelWorkerResponse message) {
        boolean jobDone = super.sendResponse(message);
        if (jobDone) {
            releaseCurrentJobGroups();
        }
        return jobDone;
    }

    /** Forwards the error and releases the batch's job groups back to the dispatcher. */
    @Override
    public void sendError(BaseModelRequest message, String error, int status) {
        super.sendError(message, error, status);
        releaseCurrentJobGroups();
    }

    /** Cleans pending jobs and releases the batch's job groups back to the dispatcher. */
    @Override
    public void cleanJobs() {
        super.cleanJobs();
        releaseCurrentJobGroups();
    }

    /** Shuts down the poll executor; the dispatcher stops once it can no longer submit work. */
    public void shutdownExecutors() {
        pollExecutors.shutdown();
    }

    /** Queues the given job group id as a dispatcher event; a null id is ignored. */
    private void addJobGroup(String jobGroupId) {
        if (jobGroupId != null) {
            eventJobGroupIds.add(jobGroupId);
        }
    }

    /** Hands the in-flight batch's group ids back to the dispatcher and forgets them locally. */
    private void releaseCurrentJobGroups() {
        if (!currentJobGroupIds.isEmpty()) {
            eventJobGroupIds.addAll(currentJobGroupIds);
            currentJobGroupIds.clear();
        }
    }

    /** Turns dispatcher events into asynchronous polls on {@code pollExecutors}. */
    class EventDispatcher implements Runnable {

        /**
         * Loops until interrupted (see {@link #stopEventDispatcher()}) or until the poll executor
         * rejects work (see {@link #shutdownExecutors()}).
         *
         * <p>A timeout or an empty event triggers a job-group poll. Any other event polls the
         * next job of that group.
         */
        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                String jobGroupId;
                try {
                    jobGroupId =
                            eventJobGroupIds.poll(model.getMaxBatchDelay(), TimeUnit.MILLISECONDS);
                } catch (InterruptedException e) {
                    // Interrupts come from stopEventDispatcher(): treat them as a stop request
                    // instead of spinning forever as before.
                    logger.info("EventDispatcher interrupted, stopping");
                    Thread.currentThread().interrupt();
                    return;
                }
                try {
                    if (jobGroupId == null || jobGroupId.isEmpty()) {
                        CompletableFuture.runAsync(
                                () -> {
                                    try {
                                        pollJobGroup();
                                    } catch (InterruptedException e) {
                                        logger.error("Failed to poll a job group", e);
                                        Thread.currentThread().interrupt();
                                    }
                                },
                                pollExecutors);
                    } else {
                        CompletableFuture.runAsync(
                                () -> pollJobFromJobGroup(jobGroupId), pollExecutors);
                    }
                } catch (RejectedExecutionException e) {
                    // The executor was shut down; there is nowhere left to run polls.
                    logger.info("Poll executors shut down, stopping EventDispatcher");
                    return;
                }
            }
        }

        /**
         * Polls the next job of a group, waiting up to the model's sequence max idle time.
         *
         * <p>If a job arrives, it is queued for batching. If none arrives, the group is cleaned
         * and its slot is turned back into a "poll new job groups" event.
         */
        private void pollJobFromJobGroup(String jobGroupId) {
            JobGroup jobGroup = model.getJobGroup(jobGroupId);
            if (jobGroup == null) {
                // NOTE(review): presumably the group was already removed. Release its slot instead
                // of letting an NPE be silently swallowed by the CompletableFuture.
                logger.debug("JobGroup {} not found, releasing its slot", jobGroupId);
                eventJobGroupIds.add("");
                return;
            }
            Job job = jobGroup.pollJob(model.getSequenceMaxIdleMSec());
            if (job == null) {
                cleanJobGroup(jobGroupId);
                eventJobGroupIds.add("");
            } else {
                jobsQueue.add(job);
            }
        }
    }
}

