/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms.wlm;

import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.wlm.BatchAggregator;
import com.amazonaws.ml.mms.wlm.Model;
import com.amazonaws.ml.mms.wlm.WorkerState;
import com.amazonaws.ml.mms.wlm.WorkerStateListener;
import com.amazonaws.ml.mms.wlm.WorkerThread;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Tracks the {@link WorkerThread}s registered for each model and adds or removes workers when a
 * model's min/max worker settings change.
 *
 * <p>Per-model worker lists are {@link CopyOnWriteArrayList}s, so the read-only accessors can
 * iterate them without locking while {@link #modelChanged(Model)} mutates them.
 */
public class WorkLoadManager {
    /** Runs worker threads and any task passed to {@link #scheduleAsync(Runnable)}. */
    private final ExecutorService threadPool;
    /** Model name -> workers currently tracked for that model. */
    private final ConcurrentHashMap<String, List<WorkerThread>> workers;
    /**
     * Model name -> dedicated monitor that serializes {@link #modelChanged(Model)} per model.
     * Locking on the model-name String itself is unsafe: equal names are not necessarily the
     * same object (no mutual exclusion), and an interned String may be locked by unrelated code.
     */
    private final ConcurrentHashMap<String, Object> modelLocks;
    private final ConfigManager configManager;
    private final EventLoopGroup backendGroup;
    /** Next backend port to hand out; starts at 9000. */
    private final AtomicInteger port;
    /** Round-robin cursor used to pick a GPU id for each new worker. */
    private final AtomicInteger gpuCounter;

    public WorkLoadManager(ConfigManager configManager, EventLoopGroup backendGroup) {
        this.configManager = configManager;
        this.backendGroup = backendGroup;
        this.port = new AtomicInteger(9000);
        this.gpuCounter = new AtomicInteger(0);
        this.threadPool = Executors.newCachedThreadPool();
        this.workers = new ConcurrentHashMap<>();
        this.modelLocks = new ConcurrentHashMap<>();
    }

    /**
     * Returns a snapshot of the workers registered for a model.
     *
     * @param modelName the model name
     * @return a new mutable copy of the worker list, or an empty list if the model has none
     */
    public List<WorkerThread> getWorkers(String modelName) {
        List<WorkerThread> list = workers.get(modelName);
        if (list == null) {
            return Collections.emptyList();
        }
        return new ArrayList<>(list);
    }

    /**
     * Returns every tracked worker across all models, keyed by {@link WorkerThread#getPid()}.
     * If two workers report the same pid, the one visited later replaces the earlier entry.
     *
     * @return a new map from pid to worker
     */
    public Map<Integer, WorkerThread> getWorkers() {
        Map<Integer, WorkerThread> map = new HashMap<>();
        for (List<WorkerThread> workerThreads : workers.values()) {
            for (WorkerThread worker : workerThreads) {
                map.put(worker.getPid(), worker);
            }
        }
        return map;
    }

    /**
     * @param modelName the model name
     * @return {@code true} if the model has no entry or an empty worker list
     */
    public boolean hasNoWorker(String modelName) {
        List<WorkerThread> list = workers.get(modelName);
        return list == null || list.isEmpty();
    }

    /**
     * Counts the model's workers that are not in a stopped, error or scaled-down state.
     *
     * @param modelName the model name
     * @return the number of such workers; 0 if the model has none
     */
    public int getNumRunningWorkers(String modelName) {
        List<WorkerThread> threads = workers.get(modelName);
        if (threads == null) {
            return 0;
        }
        int numWorking = 0;
        for (WorkerThread thread : threads) {
            // Read the state once so all three comparisons see the same value.
            WorkerState state = thread.getState();
            if (state != WorkerState.WORKER_STOPPED
                    && state != WorkerState.WORKER_ERROR
                    && state != WorkerState.WORKER_SCALED_DOWN) {
                ++numWorking;
            }
        }
        return numWorking;
    }

    /**
     * Reconciles the model's worker count with its current min/max worker settings.
     *
     * <p>If fewer than {@code minWorkers} workers exist, the missing ones are started and the
     * returned future is completed by the {@link WorkerStateListener} created in
     * {@link #addThreads}. Otherwise, workers above {@code maxWorkers} are removed and shut
     * down, and the future is completed with {@code OK} immediately. The model's entry is
     * dropped only once its worker list is empty, so workers that are kept stay tracked.
     *
     * @param model the model whose settings changed
     * @return a future with the resulting HTTP status
     */
    public CompletableFuture<HttpResponseStatus> modelChanged(Model model) {
        String modelName = model.getModelName();
        synchronized (lockFor(modelName)) {
            CompletableFuture<HttpResponseStatus> future = new CompletableFuture<>();
            int minWorker = model.getMinWorkers();
            int maxWorker = model.getMaxWorkers();
            List<WorkerThread> threads;
            if (minWorker == 0) {
                // Nothing needs to be added; look up the existing list without creating one.
                threads = workers.get(modelName);
                if (threads == null) {
                    future.complete(HttpResponseStatus.OK);
                    return future;
                }
            } else {
                threads = workers.computeIfAbsent(modelName, k -> new CopyOnWriteArrayList<>());
            }

            int currentWorkers = threads.size();
            if (currentWorkers < minWorker) {
                addThreads(threads, model, minWorker - currentWorkers, future);
            } else {
                for (int i = currentWorkers - 1; i >= maxWorker; --i) {
                    threads.remove(i).shutdown();
                }
                // Untrack the model only when no workers remain. Removing it earlier would
                // orphan still-running workers (e.g. minWorkers == 0 with maxWorkers > 0).
                if (threads.isEmpty()) {
                    workers.remove(modelName, threads);
                }
                future.complete(HttpResponseStatus.OK);
            }
            return future;
        }
    }

    /** Returns the monitor dedicated to {@code modelName}, creating it on first use. */
    private Object lockFor(String modelName) {
        return modelLocks.computeIfAbsent(modelName, k -> new Object());
    }

    /**
     * Creates {@code count} workers for {@code model}, appends them to {@code threads} and
     * submits them to the thread pool. All of them share one listener that is given
     * {@code future}.
     *
     * <p>When the config reports GPUs, each worker gets the next id from a round-robin counter
     * modulo the GPU count; otherwise the id is -1. In debug mode every worker receives the
     * current port without advancing it; otherwise each worker takes the next port.
     */
    private void addThreads(
            List<WorkerThread> threads,
            Model model,
            int count,
            CompletableFuture<HttpResponseStatus> future) {
        WorkerStateListener listener = new WorkerStateListener(future, count);
        int maxGpu = configManager.getNumberOfGpu();
        for (int i = 0; i < count; ++i) {
            int gpuId = -1;
            if (maxGpu > 0) {
                gpuId = gpuCounter.accumulateAndGet(maxGpu, (prev, gpus) -> (prev + 1) % gpus);
            }
            int workerPort = configManager.isDebug() ? port.get() : port.getAndIncrement();
            BatchAggregator aggregator = new BatchAggregator(model);
            WorkerThread thread =
                    new WorkerThread(
                            configManager,
                            backendGroup,
                            workerPort,
                            gpuId,
                            model,
                            aggregator,
                            listener);
            threads.add(thread);
            threadPool.submit(thread);
        }
    }

    /** Runs {@code r} on this manager's cached thread pool. */
    public void scheduleAsync(Runnable r) {
        threadPool.execute(r);
    }
}

