/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms.wlm;

import com.amazonaws.ml.mms.archive.Manifest;
import com.amazonaws.ml.mms.archive.ModelArchive;
import com.amazonaws.ml.mms.archive.ModelException;
import com.amazonaws.ml.mms.archive.ModelNotFoundException;
import com.amazonaws.ml.mms.http.BadRequestException;
import com.amazonaws.ml.mms.http.StatusResponse;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.amazonaws.ml.mms.wlm.Job;
import com.amazonaws.ml.mms.wlm.Model;
import com.amazonaws.ml.mms.wlm.WorkLoadManager;
import com.amazonaws.ml.mms.wlm.WorkerThread;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide registry of the models known to the server.
 *
 * <p>Keeps a name-to-{@link Model} map and forwards worker scaling, job submission and
 * status queries to the {@link WorkLoadManager} supplied at {@link #init} time. The single
 * instance is created by {@link #init(ConfigManager, WorkLoadManager)} and obtained through
 * {@link #getInstance()}.
 */
public final class ModelManager {
    private static final Logger logger = LoggerFactory.getLogger(ModelManager.class);

    /** Batch size applied by {@link #registerModel(String, String)}. */
    private static final int DEFAULT_BATCH_SIZE = 1;

    /**
     * Maximum batch delay applied by {@link #registerModel(String, String)}.
     * NOTE(review): unit is presumably milliseconds -- confirm against {@code Model}.
     */
    private static final int DEFAULT_MAX_BATCH_DELAY = 100;

    /** Number of threads in the executor returned by {@link #getScheduler()}. */
    private static final int SCHEDULER_POOL_SIZE = 2;

    private static ModelManager modelManager;

    private final ConfigManager configManager;
    private final WorkLoadManager wlm;
    private final ConcurrentHashMap<String, Model> models;
    // NOTE(review): plain HashSet that is mutated in unregisterModel and handed out by
    // getStartupModels(); confirm that all access happens on a single thread.
    private final HashSet<String> startupModels;
    private final ScheduledExecutorService scheduler;

    private ModelManager(ConfigManager configManager, WorkLoadManager wlm) {
        this.configManager = configManager;
        this.wlm = wlm;
        this.models = new ConcurrentHashMap<>();
        this.scheduler = Executors.newScheduledThreadPool(SCHEDULER_POOL_SIZE);
        this.startupModels = new HashSet<>();
    }

    /**
     * Returns the scheduled executor owned by this manager.
     *
     * @return the scheduler created when this manager was constructed
     */
    public ScheduledExecutorService getScheduler() {
        return scheduler;
    }

    /**
     * Creates the singleton instance, replacing any instance created by an earlier call.
     *
     * @param configManager configuration source used for model store and default settings
     * @param wlm workload manager that worker and job operations are delegated to
     */
    public static void init(ConfigManager configManager, WorkLoadManager wlm) {
        modelManager = new ModelManager(configManager, wlm);
    }

    /**
     * Returns the singleton instance.
     *
     * @return the instance created by {@link #init}, or {@code null} if {@code init} has not
     *     been called yet
     */
    public static ModelManager getInstance() {
        return modelManager;
    }

    /**
     * Registers a model using the default batch size, default maximum batch delay and the
     * configured default response timeout, keeping the runtime and handler declared by the
     * archive.
     *
     * @param url location passed to {@link ModelArchive#downloadModel}
     * @param defaultModelName name to use when the archive does not declare one
     * @return the archive backing the newly registered model
     * @throws ModelException if the archive cannot be loaded or fails validation
     * @throws IOException if reading the archive fails
     */
    public ModelArchive registerModel(String url, String defaultModelName)
            throws ModelException, IOException {
        return registerModel(
                url,
                null,
                null,
                null,
                DEFAULT_BATCH_SIZE,
                DEFAULT_MAX_BATCH_DELAY,
                configManager.getDefaultResponseTimeout(),
                defaultModelName);
    }

    /**
     * Loads a model archive, applies the supplied overrides to its manifest, validates it and
     * registers the resulting model under its effective name.
     *
     * <p>The effective name is {@code modelName} when it is non-empty; otherwise the name
     * declared by the archive, falling back to {@code defaultModelName} when the archive
     * declares none.
     *
     * @param url location passed to {@link ModelArchive#downloadModel}
     * @param modelName name to register under; {@code null} or empty to use the archive's name
     * @param runtime runtime override; {@code null} keeps the manifest value
     * @param handler handler override; {@code null} keeps the manifest value, or applies the
     *     configured default service handler when the manifest has none
     * @param batchSize batch size to set on the model
     * @param maxBatchDelay maximum batch delay to set on the model
     * @param responseTimeout response timeout to set on the model
     * @param defaultModelName name used when neither {@code modelName} nor the archive
     *     provides one
     * @return the archive backing the newly registered model
     * @throws BadRequestException if a model with the same name is already registered
     * @throws ModelException if the archive cannot be loaded or fails validation
     * @throws IOException if reading the archive fails
     */
    public ModelArchive registerModel(
            String url,
            String modelName,
            Manifest.RuntimeType runtime,
            String handler,
            int batchSize,
            int maxBatchDelay,
            int responseTimeout,
            String defaultModelName)
            throws ModelException, IOException {
        ModelArchive archive = ModelArchive.downloadModel(configManager.getModelStore(), url);

        String effectiveName;
        if (modelName == null || modelName.isEmpty()) {
            String archiveName = archive.getModelName();
            if (archiveName == null || archiveName.isEmpty()) {
                archive.getManifest().getModel().setModelName(defaultModelName);
            }
            // Re-read so that a default name written just above is picked up.
            effectiveName = archive.getModelName();
        } else {
            archive.getManifest().getModel().setModelName(modelName);
            effectiveName = modelName;
        }

        if (runtime != null) {
            archive.getManifest().setRuntime(runtime);
        }

        if (handler != null) {
            archive.getManifest().getModel().setHandler(handler);
        } else if (archive.getHandler() == null || archive.getHandler().isEmpty()) {
            archive.getManifest()
                    .getModel()
                    .setHandler(configManager.getMmsDefaultServiceHandler());
        }

        archive.validate();

        Model model = new Model(archive, configManager.getJobQueueSize());
        model.setBatchSize(batchSize);
        model.setMaxBatchDelay(maxBatchDelay);
        model.setResponseTimeout(responseTimeout);

        // putIfAbsent keeps check-and-insert atomic, so concurrent registrations of the same
        // name cannot both succeed.
        Model existingModel = models.putIfAbsent(effectiveName, model);
        if (existingModel != null) {
            // NOTE(review): the archive obtained above is not cleaned on this path -- confirm
            // whether that is intended.
            throw new BadRequestException("Model " + effectiveName + " is already registered.");
        }

        logger.info("Model {} loaded.", model.getModelName());
        return archive;
    }

    /**
     * Removes a model from the registry, scales its workers to zero and cleans its archive.
     *
     * @param modelName name of the model to remove
     * @return {@code true} if the model was registered and has been removed; {@code false} if
     *     no model with that name was registered
     */
    public boolean unregisterModel(String modelName) {
        Model model = models.remove(modelName);
        if (model == null) {
            logger.warn("Model not found: {}", modelName);
            return false;
        }

        // Request zero workers and notify the workload manager before cleaning the archive.
        model.setMinWorkers(0);
        model.setMaxWorkers(0);
        wlm.modelChanged(model);

        model.getModelArchive().clean();
        startupModels.remove(modelName);
        logger.info("Model {} unregistered.", modelName);
        return true;
    }

    /**
     * Sets the worker bounds of a registered model and notifies the workload manager.
     *
     * @param modelName name of the model to update
     * @param minWorkers new minimum number of workers
     * @param maxWorkers new maximum number of workers
     * @return the future returned by {@link WorkLoadManager#modelChanged}
     * @throws AssertionError if no model with that name is registered
     */
    public CompletableFuture<HttpResponseStatus> updateModel(
            String modelName, int minWorkers, int maxWorkers) {
        Model model = models.get(modelName);
        if (model == null) {
            throw new AssertionError("Model not found: " + modelName);
        }
        model.setMinWorkers(minWorkers);
        model.setMaxWorkers(maxWorkers);
        logger.debug("updateModel: {}, count: {}", modelName, minWorkers);
        return wlm.modelChanged(model);
    }

    /**
     * Returns the registry of models keyed by name.
     *
     * @return the live internal map, not a copy
     */
    public Map<String, Model> getModels() {
        return models;
    }

    /**
     * Returns the workers the workload manager reports for one model.
     *
     * @param modelName name of the model
     * @return the result of {@link WorkLoadManager#getWorkers(String)}
     */
    public List<WorkerThread> getWorkers(String modelName) {
        return wlm.getWorkers(modelName);
    }

    /**
     * Returns all workers the workload manager reports.
     *
     * @return the result of {@link WorkLoadManager#getWorkers()}
     */
    public Map<Integer, WorkerThread> getWorkers() {
        return wlm.getWorkers();
    }

    /**
     * Submits a job to the model named by the job.
     *
     * @param job job to enqueue
     * @return {@code false} if the workload manager reports no worker for the model;
     *     otherwise the result of {@link Model#addJob}
     * @throws ModelNotFoundException if the job's model is not registered
     */
    public boolean addJob(Job job) throws ModelNotFoundException {
        String modelName = job.getModelName();
        Model model = models.get(modelName);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        if (wlm.hasNoWorker(modelName)) {
            return false;
        }
        return model.addJob(job);
    }

    /**
     * Schedules an asynchronous health report that is written to the given channel.
     *
     * <p>The report compares the sum of every model's minimum worker count with the number of
     * running workers: {@code "Healthy"} by default, {@code "Partial Healthy"} when some but
     * fewer than the minimum are running, and {@code "Unhealthy"} when none are running
     * although at least one is required. The response is always sent with HTTP status 200.
     *
     * @param ctx channel context the JSON response is sent on
     */
    public void workerStatus(ChannelHandlerContext ctx) {
        Runnable report =
                () -> {
                    int numWorking = 0;
                    int numScaled = 0;
                    for (Model model : models.values()) {
                        numScaled += model.getMinWorkers();
                        numWorking += wlm.getNumRunningWorkers(model.getModelName());
                    }

                    String response = "Healthy";
                    if (numWorking > 0 && numWorking < numScaled) {
                        response = "Partial Healthy";
                    } else if (numWorking == 0 && numScaled > 0) {
                        response = "Unhealthy";
                    }

                    NettyUtils.sendJsonResponse(
                            ctx, new StatusResponse(response), HttpResponseStatus.OK);
                };
        wlm.scheduleAsync(report);
    }

    /**
     * Reports whether a model currently has at least its minimum number of workers running.
     *
     * @param modelName name of the model
     * @return {@code true} if the model is not registered or its minimum worker count does not
     *     exceed the number of running workers; {@code false} otherwise
     */
    public boolean scaleRequestStatus(String modelName) {
        // Consult this instance's own registry rather than going through the global singleton.
        Model model = models.get(modelName);
        int numWorkers = wlm.getNumRunningWorkers(modelName);
        return model == null || model.getMinWorkers() <= numWorkers;
    }

    /**
     * Hands a task to the workload manager for asynchronous execution.
     *
     * @param runnable task to run
     */
    public void submitTask(Runnable runnable) {
        wlm.scheduleAsync(runnable);
    }

    /**
     * Returns the set of startup model names.
     *
     * @return the live internal set, not a copy; {@link #unregisterModel} removes names from it
     */
    public Set<String> getStartupModels() {
        return startupModels;
    }
}

