/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.wlm;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelConfig;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.job.JobGroup;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.pytorch.serve.wlm.ModelVersionName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Model {
    public static final String DEFAULT_DATA_QUEUE = "DATA_QUEUE";
    public static final String MIN_WORKERS = "minWorkers";
    public static final String MAX_WORKERS = "maxWorkers";
    public static final String BATCH_SIZE = "batchSize";
    public static final String MAX_BATCH_DELAY = "maxBatchDelay";
    public static final String RESPONSE_TIMEOUT = "responseTimeout";
    public static final String PARALLEL_LEVEL = "parallelLevel";
    public static final String DEFAULT_VERSION = "defaultVersion";
    public static final String MAR_NAME = "marName";
    public static final String RUNTIME_TYPE = "runtimeType";
    private static final Logger logger = LoggerFactory.getLogger(Model.class);
    private ModelArchive modelArchive;
    private int minWorkers;
    private int maxWorkers;
    private int batchSize;
    private int maxBatchDelay;
    private int parallelLevel;
    private long maxRetryTimeoutInMill = 300000L;
    private long clientTimeoutInMills;
    private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE;
    private ModelConfig.DeviceType deviceType = ConfigManager.getInstance().getNumberOfGpu() > 0 ? ModelConfig.DeviceType.GPU : ModelConfig.DeviceType.CPU;
    private List<Integer> deviceIds;
    private int numCores;
    private ReentrantLock lock;
    private ReentrantLock jobGroupLock;
    private int responseTimeout;
    private long sequenceMaxIdleMSec;
    private int maxNumSequence;
    private int maxSequenceJobQueueSize;
    private boolean stateful;
    private ConcurrentMap<String, JobGroup> jobGroups;
    private LinkedBlockingDeque<String> pendingJobGroups;
    private ModelVersionName modelVersionName;
    private AtomicInteger gpuCounter = new AtomicInteger(0);
    private boolean hasCfgDeviceIds;
    private boolean isWorkflowModel;
    private Manifest.RuntimeType runtimeType;
    private AtomicInteger failedInfReqs;
    private ConcurrentMap<String, LinkedBlockingDeque<Job>> jobsDb;
    private boolean useJobTicket;
    private AtomicInteger numJobTickets;
    private boolean continuousBatching;
    private boolean useVenv;

    public Model(ModelArchive modelArchive, int queueSize) {
        this.modelArchive = modelArchive;
        if (modelArchive != null && modelArchive.getModelConfig() != null) {
            this.continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
            this.useVenv = modelArchive.getModelConfig().getUseVenv();
            if (modelArchive.getModelConfig().getParallelLevel() > 0 && modelArchive.getModelConfig().getParallelType() != ModelConfig.ParallelType.NONE) {
                this.parallelLevel = modelArchive.getModelConfig().getParallelLevel();
                this.parallelType = modelArchive.getModelConfig().getParallelType();
            }
            if (modelArchive.getModelConfig().getDeviceType() != ModelConfig.DeviceType.NONE) {
                this.deviceType = modelArchive.getModelConfig().getDeviceType() == ModelConfig.DeviceType.GPU && ConfigManager.getInstance().getNumberOfGpu() > 0 ? ModelConfig.DeviceType.GPU : ModelConfig.DeviceType.CPU;
            }
            this.deviceIds = modelArchive.getModelConfig().getDeviceIds();
            if (this.deviceIds != null && this.deviceIds.size() > 0) {
                this.hasCfgDeviceIds = true;
                for (Integer deviceId : this.deviceIds) {
                    if (deviceId >= 0 && deviceId < ConfigManager.getInstance().getNumberOfGpu()) continue;
                    logger.warn("Invalid deviceId:{}, ignore deviceIds list", (Object)deviceId);
                    this.deviceIds = null;
                    this.hasCfgDeviceIds = false;
                    break;
                }
            }
            this.maxRetryTimeoutInMill = modelArchive.getModelConfig().getMaxRetryTimeoutInSec() * 1000;
            this.clientTimeoutInMills = modelArchive.getModelConfig().getClientTimeoutInMills();
            if (modelArchive.getModelConfig().getJobQueueSize() > 0) {
                queueSize = modelArchive.getModelConfig().getJobQueueSize();
            }
            this.useJobTicket = modelArchive.getModelConfig().isUseJobTicket();
            if (modelArchive.getModelConfig().getSequenceMaxIdleMSec() > 0L) {
                this.sequenceMaxIdleMSec = modelArchive.getModelConfig().getSequenceMaxIdleMSec();
                this.maxSequenceJobQueueSize = modelArchive.getModelConfig().getMaxSequenceJobQueueSize();
                this.maxNumSequence = Math.max(modelArchive.getModelConfig().getMaxNumSequence(), this.batchSize * this.maxWorkers);
                this.jobGroups = new ConcurrentHashMap<String, JobGroup>(this.maxNumSequence);
                this.pendingJobGroups = new LinkedBlockingDeque(this.maxNumSequence);
                this.jobGroupLock = new ReentrantLock();
                this.stateful = true;
            }
        } else {
            this.batchSize = 1;
            this.maxBatchDelay = 100;
        }
        if (ConfigManager.getInstance().getNumberOfGpu() > 0 && this.deviceType != ModelConfig.DeviceType.CPU) {
            this.numCores = this.hasCfgDeviceIds ? this.deviceIds.size() : ConfigManager.getInstance().getNumberOfGpu();
        }
        this.jobsDb = new ConcurrentHashMap<String, LinkedBlockingDeque<Job>>();
        this.jobsDb.putIfAbsent(DEFAULT_DATA_QUEUE, new LinkedBlockingDeque(queueSize));
        this.failedInfReqs = new AtomicInteger(0);
        this.numJobTickets = new AtomicInteger(0);
        this.lock = new ReentrantLock();
        this.modelVersionName = new ModelVersionName(this.modelArchive.getModelName(), this.modelArchive.getModelVersion());
        this.runtimeType = modelArchive.getManifest().getRuntime();
    }

    public JsonObject getModelState(boolean isDefaultVersion) {
        JsonObject modelInfo = new JsonObject();
        modelInfo.addProperty(DEFAULT_VERSION, isDefaultVersion);
        modelInfo.addProperty(MAR_NAME, ArchiveUtils.getFilenameFromUrl(this.getModelUrl()));
        modelInfo.addProperty(MIN_WORKERS, this.getMinWorkers());
        modelInfo.addProperty(MAX_WORKERS, this.getMaxWorkers());
        modelInfo.addProperty(BATCH_SIZE, this.getBatchSize());
        modelInfo.addProperty(MAX_BATCH_DELAY, this.getMaxBatchDelay());
        modelInfo.addProperty(RESPONSE_TIMEOUT, this.getResponseTimeout());
        modelInfo.addProperty(RUNTIME_TYPE, this.getRuntimeType().getValue());
        if (this.parallelLevel > 0) {
            modelInfo.addProperty(PARALLEL_LEVEL, this.parallelLevel);
        }
        return modelInfo;
    }

    public void setModelState(JsonObject modelInfo) {
        this.minWorkers = modelInfo.get(MIN_WORKERS).getAsInt();
        this.maxWorkers = modelInfo.get(MAX_WORKERS).getAsInt();
        this.maxBatchDelay = modelInfo.get(MAX_BATCH_DELAY).getAsInt();
        this.responseTimeout = modelInfo.get(RESPONSE_TIMEOUT).getAsInt();
        this.batchSize = modelInfo.get(BATCH_SIZE).getAsInt();
        JsonElement runtime = modelInfo.get(RUNTIME_TYPE);
        String runtime_str = Manifest.RuntimeType.PYTHON.getValue();
        if (runtime != null) {
            runtime_str = runtime.getAsString();
        }
        this.runtimeType = Manifest.RuntimeType.fromValue(runtime_str);
        if (modelInfo.get(PARALLEL_LEVEL) != null) {
            this.parallelLevel = modelInfo.get(PARALLEL_LEVEL).getAsInt();
        }
    }

    public String getModelName() {
        return this.modelArchive.getModelName();
    }

    public ModelVersionName getModelVersionName() {
        return this.modelVersionName;
    }

    public String getVersion() {
        return this.modelArchive.getModelVersion();
    }

    public File getModelDir() {
        return this.modelArchive.getModelDir();
    }

    public String getModelUrl() {
        return this.modelArchive.getUrl();
    }

    public ModelArchive getModelArchive() {
        return this.modelArchive;
    }

    public int getMinWorkers() {
        return this.minWorkers;
    }

    public void setMinWorkers(int minWorkers) {
        this.minWorkers = minWorkers;
    }

    public int getMaxWorkers() {
        return this.maxWorkers;
    }

    public void setMaxWorkers(int maxWorkers) {
        this.maxWorkers = maxWorkers;
    }

    public int getBatchSize() {
        return this.batchSize;
    }

    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    public int getMaxBatchDelay() {
        return this.maxBatchDelay;
    }

    public void setMaxBatchDelay(int maxBatchDelay) {
        this.maxBatchDelay = maxBatchDelay;
    }

    public boolean isWorkflowModel() {
        return this.isWorkflowModel;
    }

    public void setWorkflowModel(boolean workflowModel) {
        this.isWorkflowModel = workflowModel;
    }

    public Manifest.RuntimeType getRuntimeType() {
        return this.runtimeType;
    }

    public void setRuntimeType(Manifest.RuntimeType runtimeType) {
        this.runtimeType = runtimeType;
    }

    public void addJob(String threadId, Job job) {
        LinkedBlockingDeque<Job> blockingDeque = (LinkedBlockingDeque<Job>)this.jobsDb.get(threadId);
        if (blockingDeque == null) {
            blockingDeque = new LinkedBlockingDeque<Job>();
            this.jobsDb.put(threadId, blockingDeque);
        }
        blockingDeque.offer(job);
    }

    public void removeJobQueue(String threadId) {
        if (!threadId.equals(DEFAULT_DATA_QUEUE)) {
            this.jobsDb.remove(threadId);
        }
    }

    public boolean addJob(Job job) {
        if (this.isUseJobTicket() && !this.getJobTickets()) {
            logger.info("There are no job tickets available");
            return false;
        }
        if (job.getGroupId() != null) {
            return this.addJobInGroup(job);
        }
        return ((LinkedBlockingDeque)this.jobsDb.get(DEFAULT_DATA_QUEUE)).offer(job);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private boolean addJobInGroup(Job job) {
        try {
            this.jobGroupLock.lockInterruptibly();
            JobGroup jobGroup = (JobGroup)this.jobGroups.get(job.getGroupId());
            if (jobGroup == null) {
                if (this.jobGroups.size() < this.maxNumSequence) {
                    jobGroup = new JobGroup(job.getGroupId(), this.maxSequenceJobQueueSize);
                    this.jobGroups.put(job.getGroupId(), jobGroup);
                    this.pendingJobGroups.offer(job.getGroupId());
                    logger.info("added jobGroup for sequenceId:{}", (Object)job.getGroupId());
                } else {
                    logger.warn("Skip the requestId: {} for sequence: {} due to exceeding maxNumSequence: {}", job.getJobId(), job.getGroupId(), this.maxNumSequence);
                    boolean bl = false;
                    return bl;
                }
            }
            boolean bl = jobGroup.appendJob(job);
            return bl;
        }
        catch (InterruptedException | NullPointerException e) {
            logger.error("Skip the requestId: {} for sequence: {} due to exception", job.getJobId(), job.getGroupId(), e);
            boolean bl = false;
            return bl;
        }
        finally {
            if (this.jobGroupLock.isHeldByCurrentThread()) {
                this.jobGroupLock.unlock();
            }
        }
    }

    public void addFirst(Job job) {
        ((LinkedBlockingDeque)this.jobsDb.get(DEFAULT_DATA_QUEUE)).addFirst(job);
    }

    public boolean pollMgmtJob(String threadId, long waitTime, Map<String, Job> jobsRepo) throws InterruptedException {
        Job j;
        if (jobsRepo == null || threadId == null || threadId.isEmpty()) {
            throw new IllegalArgumentException("Invalid input given provided");
        }
        if (!jobsRepo.isEmpty()) {
            throw new IllegalArgumentException("The jobs repo provided contains stale jobs. Clear them!!");
        }
        LinkedBlockingDeque jobsQueue = (LinkedBlockingDeque)this.jobsDb.get(threadId);
        if (jobsQueue != null && !jobsQueue.isEmpty() && (j = (Job)jobsQueue.poll(waitTime, TimeUnit.MILLISECONDS)) != null) {
            jobsRepo.put(j.getJobId(), j);
            return true;
        }
        return false;
    }

    public void pollInferJob(Map<String, Job> jobsRepo, int batchSize, LinkedBlockingDeque<Job> jobsQueue) throws InterruptedException {
        boolean pollNoWait = !jobsRepo.isEmpty();
        long maxDelay = this.maxBatchDelay;
        Job j = null;
        if (jobsRepo.isEmpty()) {
            j = jobsQueue.poll(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
            logger.trace("get first job: {}", (Object)Objects.requireNonNull(j).getJobId());
            jobsRepo.put(j.getJobId(), j);
            if (j.getCmd() == WorkerCommands.DESCRIBE) {
                if (jobsRepo.isEmpty()) {
                    jobsRepo.put(j.getJobId(), j);
                    return;
                }
                jobsQueue.addFirst(j);
                return;
            }
        }
        long begin = System.currentTimeMillis();
        for (int i = 0; i < batchSize - 1 && (j = pollNoWait ? jobsQueue.poll() : jobsQueue.poll(maxDelay, TimeUnit.MILLISECONDS)) != null; ++i) {
            long end = System.currentTimeMillis();
            if (j.getCmd() == WorkerCommands.DESCRIBE) {
                jobsQueue.addFirst(j);
                break;
            }
            maxDelay -= end - begin;
            begin = end;
            if (j.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
                jobsRepo.put(j.getJobId(), j);
            } else {
                logger.warn("Drop inference request {} due to client timeout", (Object)j.getPayload().getRequestId());
            }
            if (maxDelay <= 0L) break;
        }
        logger.trace("sending jobs, size: {}", (Object)jobsRepo.size());
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void pollInferJob(Map<String, Job> jobsRepo, int batchSize) throws InterruptedException {
        try {
            if (this.isUseJobTicket()) {
                this.incNumJobTickets();
            }
            this.lock.lockInterruptibly();
            LinkedBlockingDeque jobsQueue = (LinkedBlockingDeque)this.jobsDb.get(DEFAULT_DATA_QUEUE);
            this.pollInferJob(jobsRepo, batchSize, jobsQueue);
        }
        finally {
            if (this.lock.isHeldByCurrentThread()) {
                this.lock.unlock();
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo) throws InterruptedException {
        Job j;
        if (jobsRepo == null || threadId == null || threadId.isEmpty()) {
            throw new IllegalArgumentException("Invalid input given provided");
        }
        if (!jobsRepo.isEmpty()) {
            throw new IllegalArgumentException("The jobs repo provided contains stale jobs. Clear them!!");
        }
        LinkedBlockingDeque jobsQueue = (LinkedBlockingDeque)this.jobsDb.get(threadId);
        if (jobsQueue != null && !jobsQueue.isEmpty() && (j = (Job)jobsQueue.poll(waitTime, TimeUnit.MILLISECONDS)) != null) {
            jobsRepo.put(j.getJobId(), j);
            return;
        }
        try {
            if (this.isUseJobTicket()) {
                this.incNumJobTickets();
            }
            this.lock.lockInterruptibly();
            long maxDelay = this.maxBatchDelay;
            jobsQueue = (LinkedBlockingDeque)this.jobsDb.get(DEFAULT_DATA_QUEUE);
            Job j2 = (Job)jobsQueue.poll(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
            logger.trace("get first job: {}", (Object)Objects.requireNonNull(j2).getJobId());
            jobsRepo.put(j2.getJobId(), j2);
            if (j2.getCmd() == WorkerCommands.DESCRIBE || j2.getCmd() == WorkerCommands.STREAMPREDICT) {
                return;
            }
            long begin = System.currentTimeMillis();
            for (int i = 0; i < this.batchSize - 1 && (j2 = (Job)jobsQueue.poll(maxDelay, TimeUnit.MILLISECONDS)) != null; ++i) {
                long end = System.currentTimeMillis();
                if (j2.getCmd() == WorkerCommands.DESCRIBE || j2.getCmd() == WorkerCommands.STREAMPREDICT) {
                    jobsQueue.addFirst(j2);
                    break;
                }
                maxDelay -= end - begin;
                begin = end;
                if (j2.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
                    jobsRepo.put(j2.getJobId(), j2);
                } else {
                    logger.warn("Drop inference request {} due to client timeout", (Object)j2.getPayload().getRequestId());
                }
                if (maxDelay <= 0L) break;
            }
            logger.trace("sending jobs, size: {}", (Object)jobsRepo.size());
        }
        finally {
            if (this.lock.isHeldByCurrentThread()) {
                this.lock.unlock();
            }
        }
    }

    public int getJobQueueRemainingCapacity() {
        LinkedBlockingDeque jobsQueue = (LinkedBlockingDeque)this.jobsDb.get(DEFAULT_DATA_QUEUE);
        if (jobsQueue != null) {
            return jobsQueue.remainingCapacity();
        }
        return 0;
    }

    public int getPendingRequestsInJobQueue() {
        LinkedBlockingDeque jobsQueue = (LinkedBlockingDeque)this.jobsDb.get(DEFAULT_DATA_QUEUE);
        if (jobsQueue != null) {
            return jobsQueue.size();
        }
        return 0;
    }

    public int incrFailedInfReqs() {
        return this.failedInfReqs.incrementAndGet();
    }

    public void resetFailedInfReqs() {
        this.failedInfReqs.set(0);
    }

    public int getResponseTimeout() {
        return ConfigManager.getInstance().isDebug() ? Integer.MAX_VALUE : this.responseTimeout;
    }

    public void setResponseTimeout(int responseTimeout) {
        this.responseTimeout = responseTimeout;
    }

    public List<Integer> getDeviceIds() {
        return this.deviceIds;
    }

    public void setDeviceIds(List<Integer> deviceIds) {
        Collections.copy(this.deviceIds, deviceIds);
    }

    public int getParallelLevel() {
        return this.parallelLevel;
    }

    public ModelConfig.ParallelType getParallelType() {
        return this.parallelType;
    }

    public ModelConfig.DeviceType getDeviceType() {
        return this.deviceType;
    }

    public int getNumCores() {
        return this.numCores;
    }

    public AtomicInteger getGpuCounter() {
        return this.gpuCounter;
    }

    public boolean isHasCfgDeviceIds() {
        return this.hasCfgDeviceIds;
    }

    public long getMaxRetryTimeoutInMill() {
        return this.maxRetryTimeoutInMill;
    }

    public void setMaxRetryTimeoutInMill(long maxRetryTimeoutInMill) {
        this.maxRetryTimeoutInMill = maxRetryTimeoutInMill;
    }

    public long getClientTimeoutInMills() {
        return this.clientTimeoutInMills;
    }

    public void setClientTimeoutInMills(long clientTimeoutInMills) {
        this.clientTimeoutInMills = clientTimeoutInMills;
    }

    public boolean isUseJobTicket() {
        return this.useJobTicket;
    }

    public int incNumJobTickets() {
        return this.numJobTickets.incrementAndGet();
    }

    public int decNumJobTickets() {
        return this.numJobTickets.decrementAndGet();
    }

    public synchronized boolean getJobTickets() {
        if (this.numJobTickets.get() == 0) {
            return false;
        }
        this.numJobTickets.decrementAndGet();
        return true;
    }

    public long getSequenceMaxIdleMSec() {
        return this.sequenceMaxIdleMSec;
    }

    public void setSequenceMaxIdleMSec(long sequenceMaxIdleMSec) {
        this.sequenceMaxIdleMSec = sequenceMaxIdleMSec;
    }

    public boolean isStateful() {
        return this.stateful;
    }

    public int getMaxSequenceJobQueueSize() {
        return this.maxSequenceJobQueueSize;
    }

    public int getMaxNumSequence() {
        return this.maxNumSequence;
    }

    public LinkedBlockingDeque<String> getPendingJobGroups() {
        return this.pendingJobGroups;
    }

    public JobGroup getJobGroup(String groupId) {
        return (JobGroup)this.jobGroups.get(groupId);
    }

    public void removeJobGroup(String groupId) {
        this.jobGroups.remove(groupId);
    }

    public boolean isContinuousBatching() {
        return this.continuousBatching;
    }

    public boolean isUseVenv() {
        if (this.getRuntimeType() == Manifest.RuntimeType.PYTHON) {
            return this.useVenv;
        }
        return false;
    }

    public boolean hasTensorParallel() {
        switch (this.parallelType) {
            case PP: 
            case NONE: {
                return false;
            }
        }
        return true;
    }
}

