/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.wlm;

import io.netty.channel.EventLoopGroup;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.pytorch.serve.snapshot.SnapshotManager;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.OSUtils;
import org.pytorch.serve.wlm.BatchAggregator;
import org.pytorch.serve.wlm.ContinuousBatching;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelVersionName;
import org.pytorch.serve.wlm.SequenceBatchAggregator;
import org.pytorch.serve.wlm.WorkerLifeCycle;
import org.pytorch.serve.wlm.WorkerState;
import org.pytorch.serve.wlm.WorkerStateListener;
import org.pytorch.serve.wlm.WorkerThread;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tracks the backend {@link WorkerThread}s of every registered model version and scales them
 * up or down in response to model changes.
 *
 * <p>Worker lists are stored per {@link ModelVersionName}; mutation of a model's list happens
 * inside {@link #modelChanged}, which synchronizes on the model's {@code ModelVersionName}.
 */
public class WorkLoadManager {
    private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class);

    private final ExecutorService threadPool;
    private final ConcurrentHashMap<ModelVersionName, List<WorkerThread>> workers;
    private final ConfigManager configManager;
    private final EventLoopGroup backendGroup;
    /** Next port handed to a non-distributed worker. */
    private final AtomicInteger port;
    /** Next port handed to a distributed (parallelLevel &gt; 0) worker. */
    private final AtomicInteger distributionPort;
    /** Manager-wide round-robin GPU counter used when the model has no configured device ids. */
    private final AtomicInteger gpuCounter;

    /**
     * Creates a manager whose worker ports start at the values configured in {@code configManager}.
     *
     * @param configManager server configuration (ports, timeouts, debug mode, ...)
     * @param backendGroup Netty event loop group handed to every created worker
     */
    public WorkLoadManager(ConfigManager configManager, EventLoopGroup backendGroup) {
        this.configManager = configManager;
        this.backendGroup = backendGroup;
        this.port = new AtomicInteger(configManager.getInitialWorkerPort());
        this.distributionPort = new AtomicInteger(configManager.getInitialDistributionPort());
        this.gpuCounter = new AtomicInteger(0);
        this.threadPool = Executors.newCachedThreadPool();
        this.workers = new ConcurrentHashMap<>();
    }

    /**
     * Returns a snapshot copy of the workers of one model version.
     *
     * @param modelVersionName model version to look up
     * @return a new list of its workers, or an empty list if the version has none registered
     */
    public List<WorkerThread> getWorkers(ModelVersionName modelVersionName) {
        List<WorkerThread> list = this.workers.get(modelVersionName);
        if (list == null) {
            return Collections.emptyList();
        }
        return new ArrayList<>(list);
    }

    /**
     * Returns every known worker across all model versions, keyed by {@link WorkerThread#getPid()}.
     * If two workers report the same pid, the one visited last wins.
     *
     * @return a new mutable map from pid to worker
     */
    public Map<Integer, WorkerThread> getWorkers() {
        Map<Integer, WorkerThread> map = new HashMap<>();
        for (List<WorkerThread> workerThreads : this.workers.values()) {
            for (WorkerThread worker : workerThreads) {
                map.put(worker.getPid(), worker);
            }
        }
        return map;
    }

    /**
     * @param modelVersionName model version to check
     * @return {@code true} if the version has no worker list or its list is empty
     */
    public boolean hasNoWorker(ModelVersionName modelVersionName) {
        List<WorkerThread> worker = this.workers.get(modelVersionName);
        return worker == null || worker.isEmpty();
    }

    /**
     * Counts the workers of a model version that are not STOPPED, ERROR or SCALED_DOWN.
     *
     * @param modelVersionName model version to inspect
     * @return number of such workers (0 if the version is unknown)
     */
    public int getNumRunningWorkers(ModelVersionName modelVersionName) {
        int numWorking = 0;
        List<WorkerThread> threads = this.workers.get(modelVersionName);
        if (threads != null) {
            for (WorkerThread thread : threads) {
                WorkerState state = thread.getState();
                if (state == WorkerState.WORKER_STOPPED
                        || state == WorkerState.WORKER_ERROR
                        || state == WorkerState.WORKER_SCALED_DOWN) {
                    continue;
                }
                ++numWorking;
            }
        }
        return numWorking;
    }

    /**
     * Counts the workers of a model version for which {@link WorkerThread#isHealthy()} is true.
     *
     * @param modelVersionName model version to inspect
     * @return number of healthy workers (0 if the version is unknown)
     */
    public int getNumHealthyWorkers(ModelVersionName modelVersionName) {
        int numHealthy = 0;
        List<WorkerThread> threads = this.workers.get(modelVersionName);
        if (threads != null) {
            for (WorkerThread thread : threads) {
                if (thread.isHealthy()) {
                    ++numHealthy;
                }
            }
        }
        return numHealthy;
    }

    /**
     * Decides whether a scale request must tear down and recreate all existing workers, which is
     * the case when the CPU launcher is enabled and workers already exist.
     *
     * @param currentWorkers number of workers currently registered for the model
     * @return {@code true} if all current workers should be restarted
     */
    public boolean isLauncherRestartWorkers(int currentWorkers) {
        return this.configManager.isCPULauncherEnabled() && currentWorkers > 0;
    }

    /**
     * Reconciles the number of workers of {@code model} with its configured min/max workers.
     *
     * <p>If there are fewer than {@code minWorkers}, new workers are started. Otherwise, workers
     * above {@code maxWorkers} are shut down and their backend processes killed. When the CPU
     * launcher restart path applies, all workers are removed first and {@code minWorkers} are
     * started afresh. A snapshot is saved afterwards unless this is startup, cleanup, or a
     * workflow model.
     *
     * @param model the model whose worker count changed
     * @param isStartup whether this call happens during server startup (suppresses snapshot)
     * @param isCleanUp whether this call is part of cleanup (suppresses snapshot)
     * @return a future completed with 200 on successful scale-down, 408 if killing a worker
     *     process timed out, 500 if the kill command failed or was interrupted; on scale-up the
     *     future is handed to the new workers' {@link WorkerStateListener}, which presumably
     *     completes it once they report their state
     */
    public CompletableFuture<Integer> modelChanged(Model model, boolean isStartup, boolean isCleanUp) {
        ModelVersionName modelVersionName = model.getModelVersionName();
        // Serializes concurrent scaling of the same model version. NOTE(review): relies on
        // Model returning the same ModelVersionName instance on every call — confirm.
        synchronized (modelVersionName) {
            boolean isSnapshotSaved = false;
            CompletableFuture<Integer> future = new CompletableFuture<>();
            int minWorker = model.getMinWorkers();
            int maxWorker = model.getMaxWorkers();
            int restartNumWorkers = minWorker;
            // Snapshots are only persisted for user-driven changes to non-workflow models.
            boolean shouldSaveSnapshot = !(isStartup || isCleanUp || model.isWorkflowModel());

            List<WorkerThread> threads;
            if (minWorker == 0) {
                threads = this.workers.remove(modelVersionName);
                if (threads == null) {
                    // Nothing was running: nothing to scale down.
                    future.complete(200);
                    if (shouldSaveSnapshot) {
                        SnapshotManager.getInstance().saveSnapshot();
                    }
                    return future;
                }
            } else {
                threads = this.workers.computeIfAbsent(modelVersionName, k -> new ArrayList<>());
            }

            int currentWorkers = threads.size();
            boolean isRestartWorkers = this.isLauncherRestartWorkers(currentWorkers);
            if (isRestartWorkers) {
                logger.warn(
                        "removing {} current thread(s) prior to restarting {} thread(s)",
                        currentWorkers,
                        minWorker);
                // Scale everything down first; restartNumWorkers are re-added below.
                maxWorker = 0;
                minWorker = 0;
            }

            if (currentWorkers < minWorker) {
                this.addThreads(threads, model, minWorker - currentWorkers, future);
            } else {
                // Remove from the tail so indices of the remaining workers stay valid.
                for (int i = currentWorkers - 1; i >= maxWorker; --i) {
                    WorkerThread thread = threads.remove(i);
                    int status = this.stopWorker(thread);
                    if (status != 200) {
                        future.complete(status);
                        return future;
                    }
                }
                if (shouldSaveSnapshot) {
                    SnapshotManager.getInstance().saveSnapshot();
                    isSnapshotSaved = true;
                }
                future.complete(200);
            }

            if (isRestartWorkers) {
                logger.warn("restarting {} thread(s)", restartNumWorkers);
                this.addThreads(threads, model, restartNumWorkers, future);
            }
            if (shouldSaveSnapshot && !isSnapshotSaved) {
                SnapshotManager.getInstance().saveSnapshot();
            }
            return future;
        }
    }

    /**
     * Shuts down {@code thread} and, if its backend process is still alive, kills it with the
     * OS-specific kill command and waits up to the configured unregister timeout.
     *
     * @param thread the worker to stop (already removed from its model's list)
     * @return 200 if the process is gone or was killed in time, 408 on timeout, 500 if the kill
     *     command could not be run or the wait was interrupted
     */
    private int stopWorker(WorkerThread thread) {
        WorkerLifeCycle lifecycle = thread.getLifeCycle();
        thread.shutdown();
        Process workerProcess = lifecycle.getProcess();
        if (workerProcess == null || !workerProcess.isAlive()) {
            return 200;
        }
        boolean workerDestroyed;
        try {
            String cmd = String.format(OSUtils.getKillCmd(), workerProcess.pid());
            Process workerKillProcess = Runtime.getRuntime().exec(cmd, null, null);
            workerDestroyed =
                    workerKillProcess.waitFor(
                            this.configManager.getUnregisterModelTimeout(), TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the calling thread/executor can still observe it.
            Thread.currentThread().interrupt();
            logger.warn("WorkerThread interrupted during waitFor, possible async resource cleanup.", e);
            return 500;
        } catch (IOException e) {
            logger.warn("Failed to run kill command for worker process {}.", workerProcess.pid(), e);
            return 500;
        }
        if (!workerDestroyed) {
            logger.warn("WorkerThread timed out while cleaning, please resend request.");
            return 408;
        }
        return 200;
    }

    /**
     * Creates {@code count} new workers for {@code model}, appends them to {@code threads} and
     * submits them to the thread pool.
     *
     * <p>GPU assignment: when the model has configured device ids or a parallel level, the
     * model's own counter advances by {@code stride} modulo the core count (mapped through the
     * configured device ids when parallelLevel is 0); otherwise the manager-wide counter is used
     * round-robin. In debug mode every worker reuses the current port instead of advancing it.
     *
     * @param threads the model's worker list to append to
     * @param model the model the workers serve
     * @param count number of workers to create
     * @param future completion future shared by all new workers through their listener
     */
    private void addThreads(
            List<WorkerThread> threads, Model model, int count, CompletableFuture<Integer> future) {
        WorkerStateListener listener = new WorkerStateListener(future, count);
        int maxGpu = model.getNumCores();
        int parallelLevel = model.getParallelLevel();
        int stride = parallelLevel > 0 ? parallelLevel : 1;
        for (int i = 0; i < count; ++i) {
            int gpuId = -1;
            if (maxGpu > 0) {
                if (model.isHasCfgDeviceIds() || parallelLevel > 0) {
                    gpuId =
                            model.getGpuCounter()
                                    .getAndAccumulate(
                                            stride, (prev, myStride) -> (prev + myStride) % maxGpu);
                    if (parallelLevel == 0) {
                        gpuId = model.getDeviceIds().get(gpuId);
                    }
                } else {
                    gpuId = this.gpuCounter.accumulateAndGet(maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId);
                }
            }

            BatchAggregator aggregator;
            if (model.isStateful()) {
                aggregator = new SequenceBatchAggregator(model);
            } else if (model.isContinuousBatching()) {
                aggregator = new ContinuousBatching(model);
            } else {
                aggregator = new BatchAggregator(model);
            }

            int currentPort;
            if (parallelLevel > 0) {
                // Distributed workers reserve a contiguous range of parallelLevel ports.
                currentPort =
                        this.configManager.isDebug()
                                ? this.distributionPort.get()
                                : this.distributionPort.getAndAdd(parallelLevel);
            } else {
                currentPort =
                        this.configManager.isDebug() ? this.port.get() : this.port.getAndIncrement();
            }

            WorkerThread thread =
                    new WorkerThread(
                            this.configManager,
                            this.backendGroup,
                            currentPort,
                            gpuId,
                            model,
                            aggregator,
                            listener);
            threads.add(thread);
            this.threadPool.submit(thread);
        }
    }

    /**
     * Runs {@code r} asynchronously on this manager's cached thread pool.
     *
     * @param r task to execute
     */
    public void scheduleAsync(Runnable r) {
        this.threadPool.execute(r);
    }
}

