/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve;

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.Slf4JLoggerFactory;
import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.InvalidPropertiesFormatException;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.pytorch.serve.ServerInitializer;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.grpcimpl.GRPCInterceptor;
import org.pytorch.serve.grpcimpl.GRPCServiceFactory;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.metrics.MetricManager;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
import org.pytorch.serve.servingsdk.impl.PluginsManager;
import org.pytorch.serve.snapshot.InvalidSnapshotException;
import org.pytorch.serve.snapshot.SnapshotManager;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.Connector;
import org.pytorch.serve.util.ConnectorType;
import org.pytorch.serve.util.ServerGroups;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkLoadManager;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.pytorch.serve.workflow.WorkflowManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ModelServer {
    private Logger logger = LoggerFactory.getLogger(ModelServer.class);
    private ServerGroups serverGroups;
    private Server inferencegRPCServer;
    private Server managementgRPCServer;
    private Server OIPgRPCServer;
    private List<ChannelFuture> futures = new ArrayList<ChannelFuture>(2);
    private AtomicBoolean stopped = new AtomicBoolean(false);
    private ConfigManager configManager;
    public static final int MAX_RCVBUF_SIZE = 4096;

    public ModelServer(ConfigManager configManager) {
        this.configManager = configManager;
        this.serverGroups = new ServerGroups(configManager);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static void main(String[] args) {
        Options options = ConfigManager.Arguments.getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
            ConfigManager.init(arguments);
            ConfigManager configManager = ConfigManager.getInstance();
            PluginsManager.getInstance().initialize();
            MetricCache.init();
            InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
            final ModelServer modelServer = new ModelServer(configManager);
            Runtime.getRuntime().addShutdownHook(new Thread(){

                @Override
                public void run() {
                    modelServer.stop();
                }
            });
            modelServer.startAndWait();
        }
        catch (IllegalArgumentException e) {
            System.out.println("Invalid configuration: " + e.getMessage());
        }
        catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
        }
        catch (Throwable t) {
            t.printStackTrace();
        }
        finally {
            System.exit(1);
        }
    }

    public void startAndWait() throws InterruptedException, IOException, GeneralSecurityException, InvalidSnapshotException {
        try {
            List<ChannelFuture> channelFutures = this.startRESTserver();
            this.startGRPCServers();
            if (!this.configManager.isSystemMetricsDisabled()) {
                MetricManager.scheduleMetrics(this.configManager);
            }
            System.out.println("Model server started.");
            channelFutures.get(0).sync();
        }
        catch (InvalidPropertiesFormatException e) {
            this.logger.error("Invalid configuration", e);
        }
        finally {
            this.serverGroups.shutdown(true);
            this.logger.info("Torchserve stopped.");
        }
    }

    private String getDefaultModelName(String name) {
        if (name.contains(".model") || name.contains(".mar")) {
            return name.substring(name.lastIndexOf(47) + 1, name.lastIndexOf(46)).replaceAll("(\\W|^_)", "_");
        }
        return name.substring(name.lastIndexOf(47) + 1).replaceAll("(\\W|^_)", "_");
    }

    private void initModelStore() throws InvalidSnapshotException, IOException {
        String[] models;
        WorkLoadManager wlm = new WorkLoadManager(this.configManager, this.serverGroups.getBackendGroup());
        ModelManager.init(this.configManager, wlm);
        WorkflowManager.init(this.configManager);
        SnapshotManager.init(this.configManager);
        Set<String> startupModels = ModelManager.getInstance().getStartupModels();
        String modelSnapshot = this.configManager.getModelSnapshot();
        if (modelSnapshot != null) {
            SnapshotManager.getInstance().restore(modelSnapshot);
            return;
        }
        String loadModels = this.configManager.getLoadModels();
        if (loadModels == null || loadModels.isEmpty()) {
            return;
        }
        ModelManager modelManager = ModelManager.getInstance();
        int workers = this.configManager.getDefaultWorkers();
        if ("ALL".equalsIgnoreCase(loadModels)) {
            String modelStore = this.configManager.getModelStore();
            if (modelStore == null) {
                this.logger.warn("Model store is not configured.");
                return;
            }
            File modelStoreDir = new File(modelStore);
            if (!modelStoreDir.exists()) {
                this.logger.warn("Model store path is not found: {}", (Object)modelStore);
                return;
            }
            File[] files = modelStoreDir.listFiles();
            if (files != null) {
                for (File file : files) {
                    if (file.isHidden()) continue;
                    String fileName = file.getName();
                    if (file.isFile() && !fileName.endsWith(".mar") && !fileName.endsWith(".model")) continue;
                    try {
                        this.logger.debug("Loading models from model store: {}", (Object)file.getName());
                        String defaultModelName = this.getDefaultModelName(fileName);
                        ModelArchive archive = modelManager.registerModel(file.getName(), defaultModelName);
                        int minWorkers = this.configManager.getJsonIntValue(archive.getModelName(), archive.getModelVersion(), "minWorkers", workers);
                        int maxWorkers = this.configManager.getJsonIntValue(archive.getModelName(), archive.getModelVersion(), "maxWorkers", workers);
                        if (archive.getModelConfig() != null) {
                            int marMinWorkers = archive.getModelConfig().getMinWorkers();
                            int marMaxWorkers = archive.getModelConfig().getMaxWorkers();
                            if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) {
                                minWorkers = marMinWorkers;
                                maxWorkers = marMaxWorkers;
                            }
                        }
                        modelManager.updateModel(archive.getModelName(), archive.getModelVersion(), minWorkers, maxWorkers, true, false);
                        startupModels.add(archive.getModelName());
                    }
                    catch (IOException | InterruptedException | DownloadArchiveException | ModelException | WorkerInitializationException e) {
                        this.logger.warn("Failed to load model: " + file.getAbsolutePath(), e);
                    }
                }
            }
            return;
        }
        for (String model : models = loadModels.split(",")) {
            String url;
            String[] pair = model.split("=", 2);
            String modelName = null;
            if (pair.length == 1) {
                url = pair[0];
            } else {
                modelName = pair[0];
                url = pair[1];
            }
            if (url.isEmpty()) continue;
            try {
                this.logger.info("Loading initial models: {}", (Object)url);
                String defaultModelName = this.getDefaultModelName(url);
                ModelArchive archive = modelManager.registerModel(url, modelName, null, null, -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE, -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY, this.configManager.getDefaultResponseTimeout(), defaultModelName, false, false, false);
                int minWorkers = this.configManager.getJsonIntValue(archive.getModelName(), archive.getModelVersion(), "minWorkers", workers);
                int maxWorkers = this.configManager.getJsonIntValue(archive.getModelName(), archive.getModelVersion(), "maxWorkers", workers);
                if (archive.getModelConfig() != null) {
                    int marMinWorkers = archive.getModelConfig().getMinWorkers();
                    int marMaxWorkers = archive.getModelConfig().getMaxWorkers();
                    if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) {
                        minWorkers = marMinWorkers;
                        maxWorkers = marMaxWorkers;
                    } else {
                        this.logger.warn("Invalid model config in mar, minWorkers:{}, maxWorkers:{}", (Object)marMinWorkers, (Object)marMaxWorkers);
                    }
                }
                modelManager.updateModel(archive.getModelName(), archive.getModelVersion(), minWorkers, maxWorkers, true, false);
                startupModels.add(archive.getModelName());
            }
            catch (IOException | InterruptedException | DownloadArchiveException | ModelException | WorkerInitializationException e) {
                this.logger.warn("Failed to load model: " + url, e);
            }
        }
    }

    public ChannelFuture initializeServer(Connector connector, EventLoopGroup serverGroup, EventLoopGroup workerGroup, ConnectorType type) throws InterruptedException, IOException, GeneralSecurityException {
        ChannelFuture future;
        String purpose = connector.getPurpose();
        Class<? extends ServerChannel> channelClass = connector.getServerChannel();
        this.logger.info("Initialize {} server with: {}.", (Object)purpose, (Object)channelClass.getSimpleName());
        ServerBootstrap b = new ServerBootstrap();
        ((ServerBootstrap)((ServerBootstrap)b.option(ChannelOption.SO_BACKLOG, 1024)).channel(channelClass)).childOption(ChannelOption.SO_LINGER, 0).childOption(ChannelOption.SO_REUSEADDR, true).childOption(ChannelOption.SO_KEEPALIVE, true).childOption(ChannelOption.RCVBUF_ALLOCATOR, new FixedRecvByteBufAllocator(4096));
        b.group(serverGroup, workerGroup);
        SslContext sslCtx = null;
        if (connector.isSsl()) {
            sslCtx = this.configManager.getSslContext();
        }
        b.childHandler(new ServerInitializer(sslCtx, type));
        try {
            future = b.bind(connector.getSocketAddress()).sync();
        }
        catch (Exception e) {
            if (e instanceof IOException) {
                throw new IOException("Failed to bind to address: " + connector, e);
            }
            throw e;
        }
        future.addListener(f -> {
            if (!f.isSuccess()) {
                try {
                    f.get();
                }
                catch (InterruptedException | ExecutionException e) {
                    this.logger.error("", e);
                }
                System.exit(-1);
            }
            this.serverGroups.registerChannel(f.channel());
        });
        future.sync();
        ChannelFuture f2 = future.channel().closeFuture();
        f2.addListener(listener -> this.logger.info("{} model server stopped.", (Object)purpose));
        this.logger.info("{} API bind to: {}", (Object)purpose, (Object)connector);
        return f2;
    }

    public List<ChannelFuture> startRESTserver() throws InterruptedException, IOException, GeneralSecurityException, InvalidSnapshotException {
        this.stopped.set(false);
        this.configManager.validateConfigurations();
        this.logger.info(this.configManager.dumpConfigurations());
        this.initModelStore();
        Connector inferenceConnector = this.configManager.getListener(ConnectorType.INFERENCE_CONNECTOR);
        Connector managementConnector = this.configManager.getListener(ConnectorType.MANAGEMENT_CONNECTOR);
        inferenceConnector.clean();
        managementConnector.clean();
        EventLoopGroup serverGroup = this.serverGroups.getServerGroup();
        EventLoopGroup workerGroup = this.serverGroups.getChildGroup();
        this.futures.clear();
        if (!inferenceConnector.equals(managementConnector)) {
            this.futures.add(this.initializeServer(inferenceConnector, serverGroup, workerGroup, ConnectorType.INFERENCE_CONNECTOR));
            this.futures.add(this.initializeServer(managementConnector, serverGroup, workerGroup, ConnectorType.MANAGEMENT_CONNECTOR));
        } else {
            this.futures.add(this.initializeServer(inferenceConnector, serverGroup, workerGroup, ConnectorType.ALL));
        }
        if (this.configManager.isMetricApiEnable()) {
            EventLoopGroup metricsGroup = this.serverGroups.getMetricsGroup();
            Connector metricsConnector = this.configManager.getListener(ConnectorType.METRICS_CONNECTOR);
            metricsConnector.clean();
            this.futures.add(this.initializeServer(metricsConnector, serverGroup, metricsGroup, ConnectorType.METRICS_CONNECTOR));
        }
        SnapshotManager.getInstance().saveStartupSnapshot();
        return this.futures;
    }

    public void startGRPCServers() throws IOException {
        this.inferencegRPCServer = this.startGRPCServer(ConnectorType.INFERENCE_CONNECTOR);
        this.managementgRPCServer = this.startGRPCServer(ConnectorType.MANAGEMENT_CONNECTOR);
    }

    private Server startGRPCServer(ConnectorType connectorType) throws IOException {
        ServerBuilder s2 = NettyServerBuilder.forPort(this.configManager.getGRPCPort(connectorType)).maxInboundMessageSize(this.configManager.getMaxRequestSize()).addService(ServerInterceptors.intercept(GRPCServiceFactory.getgRPCService(connectorType), new GRPCInterceptor()));
        if (connectorType == ConnectorType.INFERENCE_CONNECTOR && ConfigManager.getInstance().isOpenInferenceProtocol()) {
            ((ServerBuilder)s2.maxInboundMessageSize(this.configManager.getMaxRequestSize())).addService(ServerInterceptors.intercept(GRPCServiceFactory.getgRPCService(ConnectorType.OPEN_INFERENCE_CONNECTOR), new GRPCInterceptor()));
        }
        if (this.configManager.isGRPCSSLEnabled()) {
            s2.useTransportSecurity(new File(this.configManager.getCertificateFile()), new File(this.configManager.getPrivateKeyFile()));
        }
        Server server = s2.build();
        server.start();
        return server;
    }

    private boolean validEndpoint(Annotation a, EndpointTypes type) {
        return a instanceof Endpoint && !((Endpoint)a).urlPattern().isEmpty() && ((Endpoint)a).endpointType().equals((Object)type);
    }

    private HashMap<String, ModelServerEndpoint> registerEndpoints(EndpointTypes type) {
        ServiceLoader<ModelServerEndpoint> loader = ServiceLoader.load(ModelServerEndpoint.class);
        HashMap<String, ModelServerEndpoint> ep = new HashMap<String, ModelServerEndpoint>();
        for (ModelServerEndpoint mep : loader) {
            Annotation[] annotations;
            Class<?> modelServerEndpointClassObj = mep.getClass();
            for (Annotation a : annotations = modelServerEndpointClassObj.getAnnotations()) {
                if (!this.validEndpoint(a, type)) continue;
                ep.put(((Endpoint)a).urlPattern(), mep);
            }
        }
        return ep;
    }

    public boolean isRunning() {
        return !this.stopped.get();
    }

    private void stopgRPCServer(Server server) {
        if (server != null) {
            try {
                server.shutdown().awaitTermination();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    private void exitModelStore() throws ModelNotFoundException {
        ModelManager modelMgr = ModelManager.getInstance();
        Map<String, Model> defModels = modelMgr.getDefaultModels();
        for (Map.Entry<String, Model> m3 : defModels.entrySet()) {
            Set<Map.Entry<String, Model>> versionModels = modelMgr.getAllModelVersions(m3.getKey());
            String defaultVersionId = m3.getValue().getVersion();
            for (Map.Entry<String, Model> versionedModel : versionModels) {
                if (defaultVersionId.equals(versionedModel.getKey())) continue;
                this.logger.info("Unregistering model {} version {}", (Object)versionedModel.getValue().getModelName(), (Object)versionedModel.getKey());
                modelMgr.unregisterModel(versionedModel.getValue().getModelName(), versionedModel.getKey(), true);
            }
            this.logger.info("Unregistering model {} version {}", (Object)m3.getValue().getModelName(), (Object)defaultVersionId);
            modelMgr.unregisterModel(m3.getValue().getModelName(), defaultVersionId, true);
        }
    }

    public void stop() {
        if (this.stopped.get()) {
            return;
        }
        this.stopped.set(true);
        this.stopgRPCServer(this.inferencegRPCServer);
        this.stopgRPCServer(this.managementgRPCServer);
        for (ChannelFuture future : this.futures) {
            try {
                future.channel().close().sync();
            }
            catch (InterruptedException ignore) {
                ignore.printStackTrace();
            }
        }
        SnapshotManager.getInstance().saveShutdownSnapshot();
        this.serverGroups.shutdown(true);
        this.serverGroups.init();
        try {
            this.exitModelStore();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }
}

