/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms;

import com.amazonaws.ml.mms.ServerInitializer;
import com.amazonaws.ml.mms.archive.ModelArchive;
import com.amazonaws.ml.mms.archive.ModelException;
import com.amazonaws.ml.mms.metrics.MetricManager;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.amazonaws.ml.mms.util.ServerGroups;
import com.amazonaws.ml.mms.wlm.ModelManager;
import com.amazonaws.ml.mms.wlm.WorkLoadManager;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.Slf4JLoggerFactory;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import java.util.InvalidPropertiesFormatException;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Entry point and lifecycle owner of the model server.
 *
 * <p>Parses the command line into a {@link ConfigManager}, loads the initially requested models
 * through {@link ModelManager}, and binds the inference and management endpoints using the event
 * loop groups held by {@link ServerGroups}.
 */
public class ModelServer {
    private final Logger logger = LoggerFactory.getLogger(ModelServer.class);
    private final ServerGroups serverGroups;
    /** Close futures of the bound endpoints; {@code null} until {@link #start()} succeeds. */
    private List<ChannelFuture> futures;
    private final AtomicBoolean stopped = new AtomicBoolean(false);
    private final ConfigManager configManager;

    /**
     * Creates a server for the given configuration. Nothing is bound until {@link #start()}.
     *
     * @param configManager configuration used for model loading, endpoints and event loop groups
     */
    public ModelServer(ConfigManager configManager) {
        this.configManager = configManager;
        this.serverGroups = new ServerGroups(configManager);
    }

    /**
     * Command-line entry point.
     *
     * <p>Prints usage and exits with status 1 when the arguments cannot be parsed. Prints the
     * message and exits with status 1 when the configuration is rejected with an
     * {@link IllegalArgumentException}.
     *
     * @param args command-line arguments, as defined by {@link ConfigManager.Arguments#getOptions()}
     */
    public static void main(String[] args) throws InterruptedException, IOException, GeneralSecurityException {
        Options options = ConfigManager.Arguments.getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
            ConfigManager configManager = new ConfigManager(arguments);
            InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
            new ModelServer(configManager).startAndWait();
        } catch (IllegalArgumentException e) {
            System.out.println("Invalid configuration: " + e.getMessage());
            // Report the failure to the launcher (and stop any threads already started) instead of
            // falling off the end of main with status 0.
            System.exit(1);
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
            System.exit(1);
        }
    }

    /**
     * Starts both endpoints, schedules metrics collection and blocks until the inference endpoint's
     * channel is closed.
     *
     * <p>Afterwards the server groups are shut down. If the wait ended normally, or an
     * {@link InvalidPropertiesFormatException} was logged, the JVM is then halted with status -1
     * via {@link Runtime#halt(int)}, which does not run shutdown hooks. Any other exception still
     * triggers the group shutdown but propagates to the caller instead of halting.
     */
    public void startAndWait() throws InterruptedException, IOException, GeneralSecurityException {
        try {
            List<ChannelFuture> channelFutures = this.start();
            MetricManager.scheduleMetrics(this.configManager);
            System.out.println("Model server started.");
            // Index 0 is the inference endpoint's close future (see start()).
            channelFutures.get(0).sync();
        } catch (InvalidPropertiesFormatException e) {
            this.logger.error("Invalid configuration", e);
        } finally {
            this.serverGroups.shutdown(true);
            this.logger.info("Model server stopped.");
        }
        Runtime.getRuntime().halt(-1);
    }

    /**
     * Initializes the {@link ModelManager} singleton and registers the models named by the
     * load-models setting.
     *
     * <p>The setting is either {@code ALL} (case-insensitive), meaning every loadable entry of the
     * model store, or a comma-separated list of {@code url} / {@code name=url} entries. Each loaded
     * model is scaled to the configured GPU count, or to the number of available processors when
     * that count is 0. Individual load failures are logged and do not abort the rest.
     */
    private void initModelStore() {
        WorkLoadManager wlm = new WorkLoadManager(this.configManager, this.serverGroups.getBackendGroup());
        ModelManager.init(this.configManager, wlm);
        String loadModels = this.configManager.getLoadModels();
        if (loadModels == null || loadModels.isEmpty()) {
            return;
        }
        ModelManager modelManager = ModelManager.getInstance();
        int workers = this.configManager.getNumberOfGpu();
        if (workers == 0) {
            workers = Runtime.getRuntime().availableProcessors();
        }
        if ("ALL".equalsIgnoreCase(loadModels)) {
            this.loadAllFromModelStore(modelManager, workers);
        } else {
            this.loadListedModels(modelManager, loadModels, workers);
        }
    }

    /**
     * Registers every non-hidden model-store entry that is either not a regular file, or is a
     * regular file ending in {@code .mar} or {@code .model}.
     */
    private void loadAllFromModelStore(ModelManager modelManager, int workers) {
        String modelStore = this.configManager.getModelStore();
        if (modelStore == null) {
            this.logger.warn("Model store is not configured.");
            return;
        }
        File modelStoreDir = new File(modelStore);
        if (!modelStoreDir.exists()) {
            this.logger.warn("Model store path is not found: {}", modelStore);
            return;
        }
        File[] files = modelStoreDir.listFiles();
        if (files == null) {
            // listFiles() returns null when the path is not a directory or cannot be read.
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            boolean unsupportedFile = file.isFile() && !fileName.endsWith(".mar") && !fileName.endsWith(".model");
            if (file.isHidden() || unsupportedFile) {
                continue;
            }
            try {
                this.logger.debug("Loading models from model store: {}", fileName);
                ModelArchive archive = modelManager.registerModel(fileName);
                modelManager.updateModel(archive.getModelName(), workers, workers);
            } catch (ModelException | IOException e) {
                this.logger.warn("Failed to load model: " + file.getAbsolutePath(), e);
            }
        }
    }

    /** Registers each {@code url} or {@code name=url} entry of a comma-separated list; empty URLs are skipped. */
    private void loadListedModels(ModelManager modelManager, String loadModels, int workers) {
        for (String model : loadModels.split(",")) {
            String[] pair = model.split("=", 2);
            String modelName = pair.length == 1 ? null : pair[0];
            String url = pair.length == 1 ? pair[0] : pair[1];
            if (url.isEmpty()) {
                continue;
            }
            try {
                this.logger.info("Loading initial models: {}", url);
                // NOTE(review): the meaning of the trailing (null, null, 1, 100) arguments is defined
                // by ModelManager.registerModel — confirm before changing.
                ModelArchive archive = modelManager.registerModel(url, modelName, null, null, 1, 100);
                modelManager.updateModel(archive.getModelName(), workers, workers);
            } catch (ModelException | IOException e) {
                this.logger.warn("Failed to load model: " + url, e);
            }
        }
    }

    /**
     * Binds one API endpoint and returns the future that completes when its channel closes.
     *
     * @param address endpoint to bind; an {@code https} scheme enables TLS with the configured SSL context
     * @param management {@code true} for the management API, {@code false} for the inference API
     * @param serverGroup Netty parent group (accepts connections)
     * @param workerGroup Netty child group (serves accepted connections)
     * @param channelClass server channel implementation to instantiate
     * @return the bound channel's close future
     * @throws IOException if the address cannot be bound; the message names the address
     * @throws InterruptedException if interrupted while waiting for the bind to complete
     */
    public ChannelFuture initializeServer(URI address, boolean management, EventLoopGroup serverGroup, EventLoopGroup workerGroup, Class<? extends ServerChannel> channelClass) throws InterruptedException, IOException, GeneralSecurityException {
        String purpose = management ? "Management" : "Inference";
        ServerBootstrap b = new ServerBootstrap();
        b.group(serverGroup, workerGroup)
                .channel(channelClass)
                .option(ChannelOption.SO_BACKLOG, 1024)
                .childOption(ChannelOption.SO_LINGER, 0)
                .childOption(ChannelOption.SO_REUSEADDR, true)
                .childOption(ChannelOption.SO_KEEPALIVE, true);
        SslContext sslCtx = null;
        if ("https".equalsIgnoreCase(address.getScheme())) {
            sslCtx = this.configManager.getSslContext();
        }
        b.childHandler(new ServerInitializer(sslCtx, management));

        ChannelFuture future;
        try {
            future = b.bind(address.getHost(), address.getPort()).sync();
        } catch (Exception e) {
            // sync() rethrows the raw bind failure (e.g. BindException), whose message does not say
            // which endpoint failed; add the address so the error is actionable.
            if (e instanceof IOException) {
                throw new IOException("Failed to bind to address: " + address, e);
            }
            throw e;
        }

        // Explicit parameter type: an implicitly typed lambda would be inferred as
        // Future<? super Void>, which has no channel() method.
        future.addListener((ChannelFuture f) -> {
            if (!f.isSuccess()) {
                // Defensive: sync() above already rethrows bind failures.
                this.logger.error(purpose + " server failed to start", f.cause());
                System.exit(-1);
            }
            this.serverGroups.registerChannel(f.channel());
        });

        ChannelFuture closeFuture = future.channel().closeFuture();
        closeFuture.addListener(ignored -> this.logger.info("{} model server stopped.", purpose));
        this.logger.info("{} API listening on port: {}", purpose, address.getPort());
        return closeFuture;
    }

    /**
     * Validates the configuration, loads the initial models and binds both endpoints.
     *
     * @return close futures of the endpoints: index 0 is inference, index 1 is management
     * @throws IllegalArgumentException if the inference and management ports are equal; this is
     *     checked before any model is loaded
     */
    public List<ChannelFuture> start() throws InterruptedException, IOException, GeneralSecurityException {
        this.stopped.set(false);
        this.configManager.validateConfigurations();
        this.logger.info(this.configManager.dumpConfigurations());
        URI inferenceAddress = this.configManager.getInferenceAddress();
        URI managementAddress = this.configManager.getManagementAddress();
        // Fail fast, before initModelStore() starts backend workers that would then be orphaned.
        if (inferenceAddress.getPort() == managementAddress.getPort()) {
            throw new IllegalArgumentException("Inference port must differ from the management port");
        }
        this.initModelStore();
        EventLoopGroup serverGroup = this.serverGroups.getServerGroup();
        EventLoopGroup workerGroup = this.serverGroups.getChildGroup();
        Class<? extends ServerChannel> channelClass = NettyUtils.getServerChannel();
        this.logger.info("Initialize servers with: {}.", channelClass.getSimpleName());
        ChannelFuture inference = this.initializeServer(inferenceAddress, false, serverGroup, workerGroup, channelClass);
        ChannelFuture managementFuture = this.initializeServer(managementAddress, true, serverGroup, workerGroup, channelClass);
        this.futures = Arrays.asList(inference, managementFuture);
        return this.futures;
    }

    /**
     * Reports whether the server is considered running.
     *
     * @return {@code false} only after {@link #stop()} has run and before the next {@link #start()}
     */
    public boolean isRunning() {
        return !this.stopped.get();
    }

    /**
     * Closes the endpoint channels and recycles the server groups. Only the first call after a
     * {@link #start()} has an effect. Calling it before {@code start()} does not fail.
     */
    public void stop() {
        // Atomic check-and-set so concurrent callers cannot both run the shutdown sequence.
        if (!this.stopped.compareAndSet(false, true)) {
            return;
        }
        if (this.futures != null) {
            for (ChannelFuture future : this.futures) {
                future.channel().close();
            }
        }
        this.serverGroups.shutdown(true);
        // Re-create the groups so that a subsequent start() can bind again.
        this.serverGroups.init();
    }
}

