/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.wlm;

import com.google.gson.JsonObject;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.commons.io.FileUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelConfig;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.http.ConflictStatusException;
import org.pytorch.serve.http.InvalidModelVersionException;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.messages.EnvironmentUtils;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelVersionName;
import org.pytorch.serve.wlm.ModelVersionedRefs;
import org.pytorch.serve.wlm.WorkLoadManager;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.pytorch.serve.wlm.WorkerThread;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Registry of served models and their versions.
 *
 * <p>Keeps one {@link ModelVersionedRefs} per model name, delegates worker scaling to the
 * {@link WorkLoadManager}, and prepares per-model Python environments (an optional virtualenv and
 * per-model pip requirements) as part of registration.
 *
 * <p>A single process-wide instance is created with {@link #init} and obtained with
 * {@link #getInstance}.
 */
public final class ModelManager {
    private static final Logger logger = LoggerFactory.getLogger(ModelManager.class);

    /** Caller-supplied batch size meaning "not set; resolve from the archive or config". */
    private static final int UNSET_BATCH_SIZE = -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE;

    /** Caller-supplied max batch delay meaning "not set; resolve from the archive or config". */
    private static final int UNSET_MAX_BATCH_DELAY =
            -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY;

    // HTTP-style status codes returned by unregisterModel.
    private static final int HTTP_OK = 200;
    private static final int HTTP_BAD_REQUEST = 400;
    private static final int HTTP_FORBIDDEN = 403;
    private static final int HTTP_NOT_FOUND = 404;
    private static final int HTTP_INTERNAL_ERROR = 500;

    private static ModelManager modelManager;

    private final ConfigManager configManager;
    private final WorkLoadManager wlm;
    private final ConcurrentHashMap<String, ModelVersionedRefs> modelsNameMap;
    private final HashSet<String> startupModels;
    private final ScheduledExecutorService scheduler;

    private ModelManager(ConfigManager configManager, WorkLoadManager wlm) {
        this.configManager = configManager;
        this.wlm = wlm;
        this.modelsNameMap = new ConcurrentHashMap<>();
        this.scheduler = Executors.newScheduledThreadPool(2);
        this.startupModels = new HashSet<>();
    }

    /**
     * Creates the process-wide instance, replacing any previously created one.
     *
     * @param configManager server configuration
     * @param wlm workload manager used to scale model workers
     */
    public static void init(ConfigManager configManager, WorkLoadManager wlm) {
        modelManager = new ModelManager(configManager, wlm);
    }

    /** Returns the instance created by {@link #init}, or {@code null} if init was never called. */
    public static ModelManager getInstance() {
        return modelManager;
    }

    /** Returns the two-thread scheduler owned by this manager. */
    public ScheduledExecutorService getScheduler() {
        return this.scheduler;
    }

    /**
     * Registers a model archive with default settings: batch size and max batch delay are resolved
     * from the archive/config, the response timeout is the configured default, duplicates are not
     * ignored, it is not a workflow model and S3 SSE-KMS is off.
     *
     * @param url archive location passed to {@link ModelArchive#downloadModel}
     * @param defaultModelName name used when the archive does not declare one
     * @return the registered archive
     */
    public ModelArchive registerModel(String url, String defaultModelName)
            throws ModelException, IOException, InterruptedException, DownloadArchiveException {
        return this.registerModel(
                url,
                null,
                null,
                null,
                UNSET_BATCH_SIZE,
                UNSET_MAX_BATCH_DELAY,
                this.configManager.getDefaultResponseTimeout(),
                defaultModelName,
                false,
                false,
                false);
    }

    /**
     * Registers a model described by {@code modelInfo} and scales it to its stored worker counts.
     *
     * <p>Reads the {@code marName} and {@code defaultVersion} fields of {@code modelInfo}; the
     * remaining state is applied through {@code Model.setModelState} (presumably including the
     * min/max worker counts used by the final scaling step — confirm in {@code Model}).
     *
     * @param modelName name to register the model under
     * @param modelInfo JSON description of the model
     */
    public void registerAndUpdateModel(String modelName, JsonObject modelInfo)
            throws ModelException, IOException, InterruptedException, DownloadArchiveException,
                    WorkerInitializationException {
        boolean defaultVersion = modelInfo.get("defaultVersion").getAsBoolean();
        String url = modelInfo.get("marName").getAsString();
        ModelArchive archive = this.createModelArchive(modelName, url, null, null, modelName, false);
        Model tempModel = this.createModel(archive, modelInfo);
        String versionId = archive.getModelVersion();
        this.createVersionedModel(tempModel, versionId);
        this.setupModelVenv(tempModel);
        this.setupModelDependencies(tempModel);
        if (defaultVersion) {
            // Use this instance rather than the static singleton field.
            this.setDefaultVersion(modelName, versionId);
        }
        logger.info("Model {} loaded.", tempModel.getModelName());
        this.updateModel(modelName, versionId, true);
    }

    /**
     * Registers a model (or a workflow function) and prepares its Python environment.
     *
     * <p>When {@code isWorkflowModel} is true and {@code url} is {@code null}, a synthetic manifest
     * (version "1.0") is built from {@code handler}, which is expected to look like
     * {@code path/to/file.py:function}. Otherwise the archive is downloaded from {@code url}.
     *
     * @param batchSize batch size, or {@code -DEFAULT_BATCH_SIZE} to resolve it from the archive
     *     or config
     * @param maxBatchDelay max batch delay, or {@code -DEFAULT_MAX_BATCH_DELAY} to resolve it
     * @param responseTimeout fallback response timeout if neither archive nor config sets one
     * @param ignoreDuplicate if true, an already-registered version is not an error
     * @return the registered archive
     */
    public ModelArchive registerModel(
            String url,
            String modelName,
            Manifest.RuntimeType runtime,
            String handler,
            int batchSize,
            int maxBatchDelay,
            int responseTimeout,
            String defaultModelName,
            boolean ignoreDuplicate,
            boolean isWorkflowModel,
            boolean s3SseKms)
            throws ModelException, IOException, InterruptedException, DownloadArchiveException {
        ModelArchive archive;
        if (isWorkflowModel && url == null) {
            Manifest manifest = new Manifest();
            manifest.getModel().setVersion("1.0");
            manifest.getModel().setModelVersion("1.0");
            manifest.getModel().setModelName(modelName);
            manifest.getModel().setHandler(new File(handler).getName());
            manifest.getModel().setEnvelope(this.configManager.getTsServiceEnvelope());
            // Strip the ":function" suffix to get the handler file path.
            File handlerFile = new File(handler.substring(0, handler.lastIndexOf(':')));
            archive = new ModelArchive(manifest, url, handlerFile.getParentFile(), true);
        } else {
            archive =
                    this.createModelArchive(
                            modelName, url, handler, runtime, defaultModelName, s3SseKms);
        }

        Model tempModel =
                this.createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel);
        String versionId = archive.getModelVersion();
        try {
            this.createVersionedModel(tempModel, versionId);
        } catch (ConflictStatusException e) {
            if (!ignoreDuplicate) {
                throw e;
            }
        }
        this.setupModelVenv(tempModel);
        this.setupModelDependencies(tempModel);
        logger.info("Model {} loaded.", tempModel.getModelName());
        return archive;
    }

    /**
     * Downloads/opens the archive at {@code url} and fills in manifest defaults: model name,
     * runtime, handler, envelope and model version ("1.0" when absent), then validates it.
     */
    private ModelArchive createModelArchive(
            String modelName,
            String url,
            String handler,
            Manifest.RuntimeType runtime,
            String defaultModelName,
            boolean s3SseKms)
            throws ModelException, IOException, DownloadArchiveException {
        ModelArchive archive =
                ModelArchive.downloadModel(
                        this.configManager.getAllowedUrls(),
                        this.configManager.getModelStore(),
                        url,
                        s3SseKms);
        Manifest.Model model = archive.getManifest().getModel();
        if (modelName == null || modelName.isEmpty()) {
            if (archive.getModelName() == null || archive.getModelName().isEmpty()) {
                model.setModelName(defaultModelName);
            }
        } else {
            model.setModelName(modelName);
        }
        if (runtime != null) {
            archive.getManifest().setRuntime(runtime);
        }
        if (handler != null) {
            model.setHandler(handler);
        } else if (archive.getHandler() == null || archive.getHandler().isEmpty()) {
            model.setHandler(this.configManager.getTsDefaultServiceHandler());
        }
        model.setEnvelope(this.configManager.getTsServiceEnvelope());
        if (model.getModelVersion() == null) {
            model.setModelVersion("1.0");
        }
        archive.validate();
        return archive;
    }

    /**
     * Creates the model's Python virtual environment via {@code ts/utils/setup_model_venv.py} when
     * the model is configured to use one; no-op otherwise.
     *
     * @throws ModelException if the venv path is outside the temp directory or the script fails
     */
    private void setupModelVenv(Model model)
            throws IOException, InterruptedException, ModelException {
        if (!model.isUseVenv()) {
            return;
        }
        File venvPath = EnvironmentUtils.getPythonVenvPath(model);
        if (!this.isValidDependencyPath(venvPath)) {
            throw new ModelException(
                    "Invalid python venv path for model "
                            + model.getModelName()
                            + ": "
                            + venvPath.toString());
        }

        List<String> commandParts = new ArrayList<>();
        commandParts.add(this.configManager.getPythonExecutable());
        commandParts.add(
                Paths.get(
                                this.configManager.getModelServerHome(),
                                "ts",
                                "utils",
                                "setup_model_venv.py")
                        .toAbsolutePath()
                        .toString());
        commandParts.add(venvPath.toString());

        ProcessBuilder processBuilder = new ProcessBuilder(commandParts);
        processBuilder.directory(venvPath.getParentFile());
        this.applyModelEnvironment(processBuilder, model);

        StringBuilder outputString = new StringBuilder();
        int exitCode = runAndCaptureOutput(processBuilder, outputString);
        if (exitCode != 0) {
            logger.error(
                    "Virtual environment creation for model {} at {} failed:\n{}",
                    model.getModelName(),
                    venvPath.toString(),
                    outputString.toString());
            throw new ModelException(
                    "Virtual environment creation failed for model " + model.getModelName());
        }
        logger.info(
                "Created virtual environment for model {}: {}",
                model.getModelName(),
                venvPath.toString());
    }

    /**
     * Installs the model's {@code requirementsFile} with pip when per-model dependency
     * installation is enabled and the manifest names a requirements file; no-op otherwise.
     *
     * <p>With a venv, packages go into the venv. Without one, they are installed with
     * {@code pip -t} into the model directory (or its parent if the model dir is a symlink).
     *
     * @throws ModelException if the target path is outside the temp directory or pip fails
     */
    private void setupModelDependencies(Model model)
            throws IOException, InterruptedException, ModelException {
        String requirementsFile =
                model.getModelArchive().getManifest().getModel().getRequirementsFile();
        if (!this.configManager.getInstallPyDepPerModel() || requirementsFile == null) {
            return;
        }
        String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);
        Path requirementsFilePath =
                Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile)
                        .toAbsolutePath();

        List<String> commandParts = new ArrayList<>();
        ProcessBuilder processBuilder = new ProcessBuilder();
        if (model.isUseVenv()) {
            if (!this.isValidDependencyPath(Paths.get(pythonRuntime).toFile())) {
                throw new ModelException(
                        "Invalid python venv runtime path for model "
                                + model.getModelName()
                                + ": "
                                + pythonRuntime);
            }
            processBuilder.directory(EnvironmentUtils.getPythonVenvPath(model).getParentFile());
            commandParts.add(pythonRuntime);
            commandParts.add("-m");
            commandParts.add("pip");
            commandParts.add("install");
            commandParts.add("-U");
            commandParts.add("--upgrade-strategy");
            commandParts.add("only-if-needed");
            commandParts.add("-r");
            commandParts.add(requirementsFilePath.toString());
        } else {
            File dependencyPath = model.getModelDir();
            if (Files.isSymbolicLink(dependencyPath.toPath())) {
                dependencyPath = dependencyPath.getParentFile();
            }
            dependencyPath = dependencyPath.getAbsoluteFile();
            if (!this.isValidDependencyPath(dependencyPath)) {
                throw new ModelException(
                        "Invalid 3rd party package installation path "
                                + dependencyPath.toString());
            }
            processBuilder.directory(dependencyPath);
            commandParts.add(pythonRuntime);
            commandParts.add("-m");
            commandParts.add("pip");
            commandParts.add("install");
            commandParts.add("-U");
            commandParts.add("-t");
            commandParts.add(dependencyPath.toString());
            commandParts.add("-r");
            commandParts.add(requirementsFilePath.toString());
        }
        processBuilder.command(commandParts);
        this.applyModelEnvironment(processBuilder, model);

        StringBuilder outputString = new StringBuilder();
        int exitCode = runAndCaptureOutput(processBuilder, outputString);
        if (exitCode != 0) {
            logger.error(
                    "Custom pip package installation failed for model {}:\n{}",
                    model.getModelName(),
                    outputString.toString());
            throw new ModelException(
                    "Custom pip package installation failed for model " + model.getModelName());
        }
        logger.info("Installed custom pip packages for model {}", model.getModelName());
    }

    /**
     * Copies the model's {@code KEY=VALUE} environment entries (from
     * {@link EnvironmentUtils#getEnvString}) into the builder's environment; malformed entries
     * without '=' are skipped.
     */
    private void applyModelEnvironment(ProcessBuilder processBuilder, Model model) {
        String[] envp =
                EnvironmentUtils.getEnvString(
                        this.configManager.getModelServerHome(),
                        model.getModelDir().getAbsolutePath(),
                        null);
        Map<String, String> environment = processBuilder.environment();
        for (String envVar : envp) {
            String[] parts = envVar.split("=", 2);
            if (parts.length == 2) {
                environment.put(parts[0], parts[1]);
            }
        }
    }

    /**
     * Starts the process with stderr merged into stdout, appends every output line to
     * {@code output}, and returns the exit code.
     *
     * <p>Output is drained <em>before</em> waiting: a child that writes more than the OS pipe
     * buffer (pip easily does) would otherwise block forever and {@code waitFor()} would never
     * return.
     */
    private static int runAndCaptureOutput(ProcessBuilder processBuilder, StringBuilder output)
            throws IOException, InterruptedException {
        processBuilder.redirectErrorStream(true);
        Process process = processBuilder.start();
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(process.getInputStream()))) {
            String line;
            while ((line = reader.readLine()) != null) {
                output.append(line).append('\n');
            }
        }
        return process.waitFor();
    }

    /** Returns true if the normalized path lies under the JVM temp directory. */
    private boolean isValidDependencyPath(File dependencyPath) {
        return dependencyPath
                .toPath()
                .normalize()
                .startsWith(FileUtils.getTempDirectory().toPath().normalize());
    }

    /**
     * Builds a {@link Model} from the archive, resolving batch size, max batch delay and response
     * timeout in the order: explicit argument (batch settings only, when not the "unset" sentinel)
     * → positive value from the archive's model config → JSON config → default.
     */
    private Model createModel(
            ModelArchive archive,
            int batchSize,
            int maxBatchDelay,
            int responseTimeout,
            boolean isWorkflowModel) {
        Model model = new Model(archive, this.configManager.getJobQueueSize());
        ModelConfig modelConfig = archive.getModelConfig();

        if (batchSize == UNSET_BATCH_SIZE) {
            int marBatchSize = modelConfig != null ? modelConfig.getBatchSize() : 0;
            batchSize =
                    this.resolveSetting(
                            archive,
                            marBatchSize,
                            "batchSize",
                            RegisterModelRequest.DEFAULT_BATCH_SIZE);
        }
        model.setBatchSize(batchSize);

        if (maxBatchDelay == UNSET_MAX_BATCH_DELAY) {
            int marMaxBatchDelay = modelConfig != null ? modelConfig.getMaxBatchDelay() : 0;
            maxBatchDelay =
                    this.resolveSetting(
                            archive,
                            marMaxBatchDelay,
                            "maxBatchDelay",
                            RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
        }
        model.setMaxBatchDelay(maxBatchDelay);

        // The response timeout is always resolved; the argument only acts as the final fallback.
        int marResponseTimeout = modelConfig != null ? modelConfig.getResponseTimeout() : 0;
        model.setResponseTimeout(
                this.resolveSetting(archive, marResponseTimeout, "responseTimeout", responseTimeout));

        model.setWorkflowModel(isWorkflowModel);
        model.setRuntimeType(
                this.configManager.getJsonRuntimeTypeValue(
                        archive.getModelName(),
                        archive.getModelVersion(),
                        "runtimeType",
                        archive.getManifest().getRuntime()));
        return model;
    }

    /**
     * Returns {@code marValue} if positive, otherwise the JSON-config value for {@code key} of this
     * model/version, falling back to {@code fallback}.
     */
    private int resolveSetting(ModelArchive archive, int marValue, String key, int fallback) {
        if (marValue > 0) {
            return marValue;
        }
        return this.configManager.getJsonIntValue(
                archive.getModelName(), archive.getModelVersion(), key, fallback);
    }

    /** Builds a non-workflow {@link Model} whose state is restored from {@code modelInfo}. */
    private Model createModel(ModelArchive archive, JsonObject modelInfo) {
        Model model = new Model(archive, this.configManager.getJobQueueSize());
        model.setModelState(modelInfo);
        model.setWorkflowModel(false);
        return model;
    }

    /**
     * Adds {@code model} as {@code versionId} under its model name, creating the per-name refs on
     * first use.
     *
     * <p>If another thread publishes refs for the same name between our lookup and
     * {@code putIfAbsent}, the version is added to the published refs instead of being silently
     * dropped along with our unpublished instance.
     *
     * @throws ConflictStatusException if the version is already registered
     */
    private void createVersionedModel(Model model, String versionId)
            throws ModelVersionNotFoundException, ConflictStatusException {
        String modelName = model.getModelName();
        ModelVersionedRefs modelVersionRef = this.modelsNameMap.get(modelName);
        if (modelVersionRef == null) {
            modelVersionRef = new ModelVersionedRefs();
        }
        modelVersionRef.addVersionModel(model, versionId);
        ModelVersionedRefs published = this.modelsNameMap.putIfAbsent(modelName, modelVersionRef);
        if (published != null && published != modelVersionRef) {
            published.addVersionModel(model, versionId);
        }
    }

    /** Unregisters a model version (default version when {@code versionId} is null). */
    public int unregisterModel(String modelName, String versionId) {
        return this.unregisterModel(modelName, versionId, false);
    }

    /**
     * Scales a model version to zero workers and removes it.
     *
     * <p>If worker teardown does not report 200, the version is re-registered and its archive is
     * left in the model store.
     *
     * @param versionId version to remove; the default version when {@code null}
     * @param isCleanUp passed to the workload manager; when true the archive is not removed from
     *     the model store
     * @return 200 on success; 404 if the model is unknown; 400 if the version is unknown; 403 when
     *     the default version cannot be removed; 500 if waiting for teardown was interrupted or
     *     failed; otherwise the status reported by the workload manager
     */
    public int unregisterModel(String modelName, String versionId, boolean isCleanUp) {
        ModelVersionedRefs vmodel = this.modelsNameMap.get(modelName);
        if (vmodel == null) {
            logger.warn("Model not found: {}", modelName);
            return HTTP_NOT_FOUND;
        }
        if (versionId == null) {
            versionId = vmodel.getDefaultVersion();
        }

        int httpResponseStatus;
        try {
            Model model = vmodel.removeVersionModel(versionId);
            model.setMinWorkers(0);
            model.setMaxWorkers(0);
            CompletableFuture<Integer> futureStatus =
                    this.wlm.modelChanged(model, false, isCleanUp);
            httpResponseStatus = futureStatus.get();
            boolean unregistered = httpResponseStatus == HTTP_OK;
            if (unregistered) {
                model.getModelArchive().clean();
                this.startupModels.remove(modelName);
                logger.info("Model {} unregistered.", modelName);
            } else {
                // Teardown failed: put the version back so the model stays registered.
                if (versionId == null) {
                    versionId = vmodel.getDefaultVersion();
                }
                vmodel.addVersionModel(model, versionId);
            }
            if (vmodel.getAllVersions().isEmpty()) {
                this.modelsNameMap.remove(modelName);
            }
            // Only delete the stored archive once the version is really gone; deleting it after a
            // failed teardown would leave a re-registered model without its archive.
            if (unregistered && !isCleanUp && model.getModelUrl() != null) {
                ModelArchive.removeModel(this.configManager.getModelStore(), model.getModelUrl());
            }
        } catch (ModelVersionNotFoundException e) {
            logger.warn("Model {} version {} not found.", modelName, versionId);
            httpResponseStatus = HTTP_BAD_REQUEST;
        } catch (InvalidModelVersionException e) {
            logger.warn("Cannot remove default version {} for model {}", versionId, modelName);
            httpResponseStatus = HTTP_FORBIDDEN;
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers up the stack can observe it.
            Thread.currentThread().interrupt();
            logger.warn("Process was interrupted while cleaning resources.", e);
            httpResponseStatus = HTTP_INTERNAL_ERROR;
        } catch (ExecutionException e) {
            logger.warn("Process was interrupted while cleaning resources.", e);
            httpResponseStatus = HTTP_INTERNAL_ERROR;
        }
        return httpResponseStatus;
    }

    /**
     * Makes {@code newModelVersion} the default version of {@code modelName}.
     *
     * @throws ModelNotFoundException if the model name is not registered
     */
    public void setDefaultVersion(String modelName, String newModelVersion)
            throws ModelNotFoundException, ModelVersionNotFoundException {
        ModelVersionedRefs vmodel = this.modelsNameMap.get(modelName);
        if (vmodel == null) {
            logger.warn("Model not found: {}", modelName);
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        vmodel.setDefaultVersion(newModelVersion);
    }

    /** Rescales a model version to the min/max worker counts already stored on its Model. */
    private CompletableFuture<Integer> updateModel(
            String modelName, String versionId, boolean isStartup)
            throws ModelVersionNotFoundException, WorkerInitializationException {
        Model model = this.getVersionModel(modelName, versionId);
        if (model == null) {
            throw new ModelVersionNotFoundException(
                    "Model version: " + versionId + " does not exist for model: " + modelName);
        }
        return this.updateModel(
                modelName,
                versionId,
                model.getMinWorkers(),
                model.getMaxWorkers(),
                isStartup,
                false);
    }

    /**
     * Sets the worker range of a model version and notifies the workload manager.
     *
     * <p>For GPU models with a positive parallel level, both counts are capped at
     * {@code numCores / parallelLevel}.
     *
     * @throws ModelVersionNotFoundException if the version does not exist
     * @throws WorkerInitializationException if there are not enough GPUs for one worker
     * @throws AssertionError if the model name is not registered
     */
    public CompletableFuture<Integer> updateModel(
            String modelName,
            String versionId,
            int minWorkers,
            int maxWorkers,
            boolean isStartup,
            boolean isCleanUp)
            throws ModelVersionNotFoundException, WorkerInitializationException {
        Model model = this.getVersionModel(modelName, versionId);
        if (model == null) {
            throw new ModelVersionNotFoundException(
                    "Model version: " + versionId + " does not exist for model: " + modelName);
        }
        if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
            int capacity = model.getNumCores() / model.getParallelLevel();
            if (capacity == 0) {
                logger.error(
                        "there are no enough gpu devices to support this parallelLever: {}",
                        model.getParallelLevel());
                throw new WorkerInitializationException(
                        "No enough gpu devices for model:"
                                + modelName
                                + " parallelLevel:"
                                + model.getParallelLevel());
            }
            minWorkers = Math.min(minWorkers, capacity);
            maxWorkers = Math.min(maxWorkers, capacity);
            logger.info(
                    "model {} set minWorkers: {}, maxWorkers: {} for parallelLevel: {} ",
                    modelName,
                    minWorkers,
                    maxWorkers,
                    model.getParallelLevel());
        }
        model.setMinWorkers(minWorkers);
        model.setMaxWorkers(maxWorkers);
        logger.debug("updateModel: {}, count: {}", modelName, minWorkers);
        return this.wlm.modelChanged(model, isStartup, isCleanUp);
    }

    /**
     * Looks up a model version, or returns whatever {@code ModelVersionedRefs.getVersionModel}
     * returns for an unknown version (null is handled by callers).
     *
     * @throws AssertionError if the model name is not registered
     */
    private Model getVersionModel(String modelName, String versionId) {
        ModelVersionedRefs vmodel = this.modelsNameMap.get(modelName);
        if (vmodel == null) {
            throw new AssertionError("Model not found: " + modelName);
        }
        return vmodel.getVersionModel(versionId);
    }

    /** Sets the worker range of a model version (not a startup or cleanup update). */
    public CompletableFuture<Integer> updateModel(
            String modelName, String versionId, int minWorkers, int maxWorkers)
            throws ModelVersionNotFoundException, WorkerInitializationException {
        return this.updateModel(modelName, versionId, minWorkers, maxWorkers, false, false);
    }

    /**
     * Returns a snapshot map of model name to its default-version {@link Model}.
     *
     * @param skipFunctions if true, models without a model URL (e.g. workflow functions) are
     *     omitted
     */
    public Map<String, Model> getDefaultModels(boolean skipFunctions) {
        ConcurrentHashMap<String, Model> defModelsMap = new ConcurrentHashMap<>();
        for (Map.Entry<String, ModelVersionedRefs> entry : this.modelsNameMap.entrySet()) {
            ModelVersionedRefs mvr = entry.getValue();
            if (mvr == null) {
                continue;
            }
            Model defaultModel = mvr.getDefaultModel();
            if (defaultModel == null || (skipFunctions && defaultModel.getModelUrl() == null)) {
                continue;
            }
            defModelsMap.put(entry.getKey(), defaultModel);
        }
        return defModelsMap;
    }

    /** Returns a snapshot of every model's default version, including workflow functions. */
    public Map<String, Model> getDefaultModels() {
        return this.getDefaultModels(false);
    }

    /** Returns the workers of the given model version, as reported by the workload manager. */
    public List<WorkerThread> getWorkers(ModelVersionName modelVersionName) {
        return this.wlm.getWorkers(modelVersionName);
    }

    /** Returns all workers, as reported by the workload manager. */
    public Map<Integer, WorkerThread> getWorkers() {
        return this.wlm.getWorkers();
    }

    /**
     * Queues a job on the model version it targets.
     *
     * @return false if the version has no workers, otherwise the result of {@code Model.addJob}
     * @throws ModelNotFoundException if the model name is not registered
     * @throws ModelVersionNotFoundException if the version does not exist
     */
    public boolean addJob(Job job) throws ModelNotFoundException, ModelVersionNotFoundException {
        String modelName = job.getModelName();
        String versionId = job.getModelVersion();
        Model model = this.getModel(modelName, versionId);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        if (this.wlm.hasNoWorker(model.getModelVersionName())) {
            return false;
        }
        return model.addJob(job);
    }

    /**
     * Returns true when the model version has at least its minimum number of running workers, or
     * when the model name or version is not (or no longer) registered.
     */
    public boolean scaleRequestStatus(String modelName, String versionId) {
        ModelVersionedRefs vmodel = this.modelsNameMap.get(modelName);
        Model model = vmodel == null ? null : vmodel.getVersionModel(versionId);
        if (model == null) {
            return true;
        }
        int numWorkers = this.wlm.getNumRunningWorkers(model.getModelVersionName());
        return model.getMinWorkers() <= numWorkers;
    }

    /** Hands {@code runnable} to the workload manager's async scheduler. */
    public void submitTask(Runnable runnable) {
        this.wlm.scheduleAsync(runnable);
    }

    /** Returns the live, mutable set of startup model names (a plain, unsynchronized HashSet). */
    public Set<String> getStartupModels() {
        return this.startupModels;
    }

    /**
     * Returns a registered model version.
     *
     * @return the model, or {@code null} if the model name is not registered
     * @throws ModelVersionNotFoundException if the name exists but the version does not
     */
    public Model getModel(String modelName, String versionId)
            throws ModelVersionNotFoundException {
        ModelVersionedRefs vmodel = this.modelsNameMap.get(modelName);
        if (vmodel == null) {
            return null;
        }
        Model model = vmodel.getVersionModel(versionId);
        if (model == null) {
            throw new ModelVersionNotFoundException(
                    "Model version: " + versionId + " does not exist for model: " + modelName);
        }
        return model;
    }

    /**
     * Returns all (version, model) entries of a model.
     *
     * @throws ModelNotFoundException if the model name is not registered
     */
    public Set<Map.Entry<String, Model>> getAllModelVersions(String modelName)
            throws ModelNotFoundException {
        ModelVersionedRefs vmodel = this.modelsNameMap.get(modelName);
        if (vmodel == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        return vmodel.getAllVersions();
    }

    /** Returns a live view of all (model name, versioned refs) entries. */
    public Set<Map.Entry<String, ModelVersionedRefs>> getAllModels() {
        return this.modelsNameMap.entrySet();
    }

    /** Returns the number of running workers for a model version. */
    public int getNumRunningWorkers(ModelVersionName modelVersionName) {
        return this.wlm.getNumRunningWorkers(modelVersionName);
    }

    /** Returns the number of healthy workers for a model version. */
    public int getNumHealthyWorkers(ModelVersionName modelVersionName) {
        return this.wlm.getNumHealthyWorkers(modelVersionName);
    }
}

