/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms;

import com.amazonaws.ml.mms.ServerInitializer;
import com.amazonaws.ml.mms.archive.ModelArchive;
import com.amazonaws.ml.mms.archive.ModelException;
import com.amazonaws.ml.mms.metrics.MetricManager;
import com.amazonaws.ml.mms.servingsdk.impl.PluginsManager;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.Connector;
import com.amazonaws.ml.mms.util.ConnectorType;
import com.amazonaws.ml.mms.util.ServerGroups;
import com.amazonaws.ml.mms.wlm.ModelManager;
import com.amazonaws.ml.mms.wlm.WorkLoadManager;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.Slf4JLoggerFactory;
import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.InvalidPropertiesFormatException;
import java.util.List;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;

/**
 * Entry point and lifecycle owner of the model server.
 *
 * <p>On {@link #start()} it validates the configuration and registers the startup models. It then
 * binds the inference and management listeners on the event-loop groups supplied by
 * {@link ServerGroups}. A single listener of type {@link ConnectorType#BOTH} is used when both
 * connectors are equal. {@link #stop()} closes the bound channels and re-initializes the groups so
 * the server can be started again.
 */
public class ModelServer {
    private static final Logger logger = LoggerFactory.getLogger(ModelServer.class);

    /** Non-word characters (and a leading underscore) are replaced by '_' in derived names. */
    private static final String MODEL_NAME_SANITIZER = "(\\W|^_)";

    private final ServerGroups serverGroups;
    /** Close futures of the currently bound server channels, filled by {@link #start()}. */
    private final List<ChannelFuture> futures = new ArrayList<>(2);
    private final AtomicBoolean stopped = new AtomicBoolean(false);
    private final ConfigManager configManager;

    public ModelServer(ConfigManager configManager) {
        this.configManager = configManager;
        this.serverGroups = new ServerGroups(configManager);
    }

    /**
     * Command-line entry point.
     *
     * <p>Parses the arguments, initializes the global {@link ConfigManager} and the plugins, and
     * runs the server via {@link #startAndWait()}. The {@code finally} clause always calls
     * {@code System.exit(1)}. So the process exits with status 1 both after a startup failure and
     * after {@code startAndWait()} returns.
     *
     * @param args command-line arguments understood by {@link ConfigManager.Arguments}
     */
    public static void main(String[] args) {
        Options options = ConfigManager.Arguments.getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
            ConfigManager.init(arguments);
            ConfigManager configManager = ConfigManager.getInstance();
            PluginsManager.getInstance().initialize();
            InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
            new ModelServer(configManager).startAndWait();
        } catch (IllegalArgumentException e) {
            System.out.println("Invalid configuration: " + e.getMessage());
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
        } catch (Throwable t) {
            // Last-resort boundary: logging may not be configured yet, so print directly.
            t.printStackTrace();
        } finally {
            System.exit(1);
        }
    }

    /**
     * Starts the server, schedules metrics, and blocks until the first bound channel closes.
     *
     * <p>Only {@code channelFutures.get(0)} is awaited. When separate inference and management
     * listeners are bound, that is the inference listener. The server groups are always shut
     * down on exit.
     *
     * @throws InterruptedException if interrupted while waiting
     * @throws IOException if a listener fails to bind
     * @throws GeneralSecurityException if the SSL context cannot be created
     */
    public void startAndWait() throws InterruptedException, IOException, GeneralSecurityException {
        try {
            List<ChannelFuture> channelFutures = start();
            MetricManager.scheduleMetrics(configManager);
            System.out.println("Model server started.");
            channelFutures.get(0).sync();
        } catch (InvalidPropertiesFormatException e) {
            logger.error("Invalid configuration", e);
        } finally {
            serverGroups.shutdown(true);
            logger.info("Model server stopped.");
        }
    }

    /**
     * Derives a model name from a model file name or URL.
     *
     * <p>Takes the text after the last {@code '/'}. When the name mentions {@code .mar} or
     * {@code .model}, it also drops everything from the last {@code '.'}. Finally it sanitizes the
     * result with {@link #MODEL_NAME_SANITIZER}.
     */
    private String getDefaultModelName(String name) {
        int begin = name.lastIndexOf('/') + 1;
        int dot = name.lastIndexOf('.');
        boolean isArchive = name.contains(".model") || name.contains(".mar");
        // Only strip an extension found in the last path segment. A dot in an earlier segment
        // (e.g. "https://my.models.host/resnet") would make substring(begin, dot) throw.
        String base = isArchive && dot >= begin ? name.substring(begin, dot) : name.substring(begin);
        return base.replaceAll(MODEL_NAME_SANITIZER, "_");
    }

    /**
     * Initializes the {@link ModelManager} and registers the startup models.
     *
     * <p>The models come from {@code configManager.getLoadModels()}. The value {@code ALL}
     * (case-insensitive) loads everything in the model store. Otherwise the value is treated as a
     * comma-separated list. Failures to load an individual model are logged and skipped.
     */
    private void initModelStore() {
        WorkLoadManager wlm = new WorkLoadManager(configManager, serverGroups.getBackendGroup());
        ModelManager.init(configManager, wlm);
        Set<String> startupModels = ModelManager.getInstance().getStartupModels();

        String loadModels = configManager.getLoadModels();
        if (loadModels == null || loadModels.isEmpty()) {
            return;
        }

        ModelManager modelManager = ModelManager.getInstance();
        int workers = configManager.getDefaultWorkers();
        if ("ALL".equalsIgnoreCase(loadModels)) {
            loadModelStoreModels(modelManager, startupModels, workers);
        } else {
            loadListedModels(loadModels, modelManager, startupModels, workers);
        }
    }

    /** Registers every non-hidden directory or .mar/.model file directly inside the model store. */
    private void loadModelStoreModels(
            ModelManager modelManager, Set<String> startupModels, int workers) {
        String modelStore = configManager.getModelStore();
        if (modelStore == null) {
            logger.warn("Model store is not configured.");
            return;
        }
        File modelStoreDir = new File(modelStore);
        if (!modelStoreDir.exists()) {
            logger.warn("Model store path is not found: {}", modelStore);
            return;
        }
        File[] files = modelStoreDir.listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            // Regular files must be archives; directories are passed to registerModel as-is.
            boolean nonArchiveFile =
                    file.isFile() && !fileName.endsWith(".mar") && !fileName.endsWith(".model");
            if (file.isHidden() || nonArchiveFile) {
                continue;
            }
            try {
                logger.debug("Loading models from model store: {}", fileName);
                String defaultModelName = getDefaultModelName(fileName);
                ModelArchive archive = modelManager.registerModel(fileName, defaultModelName);
                modelManager.updateModel(archive.getModelName(), workers, workers);
                startupModels.add(archive.getModelName());
            } catch (ModelException | IOException e) {
                logger.warn("Failed to load model: " + file.getAbsolutePath(), e);
            }
        }
    }

    /**
     * Registers each comma-separated entry of {@code loadModels}.
     *
     * <p>Each entry is either {@code url} or {@code name=url}. Surrounding whitespace is ignored,
     * and entries with an empty URL are skipped.
     */
    private void loadListedModels(
            String loadModels, ModelManager modelManager, Set<String> startupModels, int workers) {
        for (String entry : loadModels.split(",")) {
            String[] pair = entry.split("=", 2);
            String modelName = null;
            String url;
            if (pair.length == 1) {
                url = pair[0].trim();
            } else {
                modelName = pair[0].trim();
                url = pair[1].trim();
            }
            if (url.isEmpty()) {
                continue;
            }
            try {
                logger.info("Loading initial models: {}", url);
                String defaultModelName = getDefaultModelName(url);
                // The literals 1 and 100 are forwarded unchanged; presumably batch size and max
                // batch delay -- NOTE(review): confirm against ModelManager.registerModel.
                ModelArchive archive =
                        modelManager.registerModel(
                                url,
                                modelName,
                                null,
                                null,
                                1,
                                100,
                                configManager.getDefaultResponseTimeout(),
                                defaultModelName);
                modelManager.updateModel(archive.getModelName(), workers, workers);
                startupModels.add(archive.getModelName());
            } catch (ModelException | IOException e) {
                logger.warn("Failed to load model: " + url, e);
            }
        }
    }

    /**
     * Binds one server channel for {@code connector} and returns its close future.
     *
     * @param connector address, channel class and SSL flag of the listener
     * @param serverGroup event-loop group accepting connections
     * @param workerGroup event-loop group serving accepted connections
     * @param type which API(s) the channel pipeline serves
     * @return the bound channel's close future
     * @throws InterruptedException if interrupted while binding
     * @throws IOException if binding fails (the message includes the connector)
     * @throws GeneralSecurityException if the SSL context cannot be created
     */
    public ChannelFuture initializeServer(
            Connector connector,
            EventLoopGroup serverGroup,
            EventLoopGroup workerGroup,
            ConnectorType type)
            throws InterruptedException, IOException, GeneralSecurityException {
        String purpose = connector.getPurpose();
        Class<? extends ServerChannel> channelClass = connector.getServerChannel();
        logger.info("Initialize {} server with: {}.", purpose, channelClass.getSimpleName());

        ServerBootstrap b = new ServerBootstrap();
        b.option(ChannelOption.SO_BACKLOG, 1024)
                .channel(channelClass)
                .childOption(ChannelOption.SO_LINGER, 0)
                .childOption(ChannelOption.SO_REUSEADDR, true)
                .childOption(ChannelOption.SO_KEEPALIVE, true);
        b.group(serverGroup, workerGroup);

        SslContext sslCtx = connector.isSsl() ? configManager.getSslContext() : null;
        b.childHandler(new ServerInitializer(sslCtx, type));

        ChannelFuture future;
        try {
            future = b.bind(connector.getSocketAddress()).sync();
        } catch (Exception e) {
            // Netty's sync() can rethrow a checked IOException (e.g. a bind failure) without
            // declaring it, hence the broad catch used to attach the connector to the message.
            if (e instanceof IOException) {
                throw new IOException("Failed to bind to address: " + connector, e);
            }
            throw e;
        }

        // The explicit ChannelFutureListener type is required: with the wildcard target type of
        // addListener, a bare lambda's parameter would be a plain Future without channel().
        future.addListener(
                (ChannelFutureListener)
                        f -> {
                            if (!f.isSuccess()) {
                                try {
                                    f.get();
                                } catch (InterruptedException | ExecutionException e) {
                                    logger.error("", e);
                                }
                                System.exit(-1);
                            }
                            serverGroups.registerChannel(f.channel());
                        });

        ChannelFuture closeFuture = future.channel().closeFuture();
        closeFuture.addListener(listener -> logger.info("{} model server stopped.", purpose));
        logger.info("{} API bind to: {}", purpose, connector);
        return closeFuture;
    }

    /**
     * Validates the configuration, loads startup models and binds the listeners.
     *
     * @return the close futures of the bound channels (the inference or shared one first); this
     *     is the server's internal list and is cleared on the next call
     * @throws InterruptedException if interrupted while binding
     * @throws IOException if a listener fails to bind
     * @throws GeneralSecurityException if the SSL context cannot be created
     */
    public List<ChannelFuture> start()
            throws InterruptedException, IOException, GeneralSecurityException {
        stopped.set(false);
        configManager.validateConfigurations();
        logger.info(configManager.dumpConfigurations());
        initModelStore();

        Connector inferenceConnector = configManager.getListener(false);
        Connector managementConnector = configManager.getListener(true);
        inferenceConnector.clean();
        managementConnector.clean();

        EventLoopGroup serverGroup = serverGroups.getServerGroup();
        EventLoopGroup workerGroup = serverGroups.getChildGroup();

        futures.clear();
        if (inferenceConnector.equals(managementConnector)) {
            futures.add(
                    initializeServer(
                            inferenceConnector, serverGroup, workerGroup, ConnectorType.BOTH));
        } else {
            futures.add(
                    initializeServer(
                            inferenceConnector,
                            serverGroup,
                            workerGroup,
                            ConnectorType.INFERENCE_CONNECTOR));
            futures.add(
                    initializeServer(
                            managementConnector,
                            serverGroup,
                            workerGroup,
                            ConnectorType.MANAGEMENT_CONNECTOR));
        }
        return futures;
    }

    /** Returns true if {@code a} is an {@link Endpoint} with a non-empty URL pattern of {@code type}. */
    private boolean validEndpoint(Annotation a, EndpointTypes type) {
        if (!(a instanceof Endpoint)) {
            return false;
        }
        Endpoint endpoint = (Endpoint) a;
        return !endpoint.urlPattern().isEmpty() && endpoint.endpointType().equals(type);
    }

    /**
     * Discovers {@link ModelServerEndpoint} implementations via {@link ServiceLoader} and maps each
     * matching {@link Endpoint} URL pattern to its implementation.
     */
    private HashMap<String, ModelServerEndpoint> registerEndpoints(EndpointTypes type) {
        HashMap<String, ModelServerEndpoint> ep = new HashMap<>();
        for (ModelServerEndpoint mep : ServiceLoader.load(ModelServerEndpoint.class)) {
            for (Annotation a : mep.getClass().getAnnotations()) {
                if (validEndpoint(a, type)) {
                    ep.put(((Endpoint) a).urlPattern(), mep);
                }
            }
        }
        return ep;
    }

    /**
     * Returns whether {@link #stop()} has not been called since the last {@link #start()}.
     *
     * <p>Note that this also returns true before {@code start()} has ever been called.
     */
    public boolean isRunning() {
        return !stopped.get();
    }

    /**
     * Closes all bound channels and shuts down, then re-initializes, the server groups.
     *
     * <p>Only the first caller after a start performs the shutdown. Later or concurrent calls
     * return immediately.
     */
    public void stop() {
        // compareAndSet makes the check-and-mark a single atomic step, so concurrent callers
        // cannot both run the shutdown sequence.
        if (!stopped.compareAndSet(false, true)) {
            return;
        }
        for (ChannelFuture future : futures) {
            future.channel().close();
        }
        serverGroups.shutdown(true);
        serverGroups.init();
    }
}

