/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.wlm;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.pytorch.serve.archive.model.ModelConfig;
import org.pytorch.serve.metrics.Metric;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.Connector;
import org.pytorch.serve.util.messages.EnvironmentUtils;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Manages the operating-system process behind a single backend worker of a model.
 *
 * <p>An instance builds the worker command line (Python or C++ backend), spawns the process,
 * attaches one {@link ReaderThread} to each of its stdout and stderr streams, and waits for the
 * readers to report that the worker has started. It also exposes the process handle, the PID
 * reported by the worker on its output, and a way to tear everything down again.
 */
public class WorkerLifeCycle {
    private static final Logger logger = LoggerFactory.getLogger(WorkerLifeCycle.class);
    private static final Pattern PID_LOG_PATTERN = Pattern.compile(".*\\[PID\\](\\d+)$");
    private static final String METRIC_LOG_START_SUBSTRING = "[METRICS]";
    /** How long a freshly spawned worker gets to report a successful start. */
    private static final long WORKER_STARTUP_TIMEOUT_MINUTES = 2L;

    private final ConfigManager configManager;
    private final ModelManager modelManager = ModelManager.getInstance();
    private final Model model;
    private int pid = -1;
    private Process process;
    private CountDownLatch latch;
    // Written by the reader threads and read by the launching thread. After a timed-out await
    // there is no happens-before edge from the latch, hence volatile.
    private volatile boolean success;
    private Connector connector;
    private ReaderThread errReader;
    private ReaderThread outReader;
    private final int numWorker;
    private final int currNumRunningWorkers;

    /**
     * Creates a life cycle for one worker of {@code model}.
     *
     * <p>The model's minimum worker count and the number of workers the model manager currently
     * reports as running are captured here; they are later used for the {@code --ninstances} and
     * {@code --rank} launcher arguments.
     *
     * @param configManager server configuration used to build the worker command line
     * @param model the model this worker serves
     */
    public WorkerLifeCycle(ConfigManager configManager, Model model) {
        this.configManager = configManager;
        this.model = model;
        this.numWorker = model.getMinWorkers();
        this.currNumRunningWorkers = modelManager.getNumRunningWorkers(model.getModelVersionName());
    }

    /**
     * Returns the worker process handle.
     *
     * @return the spawned process, or {@code null} if no worker has been started yet
     */
    public Process getProcess() {
        return process;
    }

    /**
     * Builds the argument list that invokes the {@code torch.backends.xeon.run_cpu} launcher
     * module, followed by the user-supplied launcher arguments.
     *
     * <p>The arguments are split on runs of whitespace and empty tokens are skipped, so repeated,
     * leading or trailing blanks never produce empty command-line arguments.
     *
     * @param launcherArgs whitespace-separated extra launcher arguments; may be {@code null}
     * @return a new list starting with {@code -m torch.backends.xeon.run_cpu}
     */
    public ArrayList<String> launcherArgsToList(String launcherArgs) {
        ArrayList<String> arrlist = new ArrayList<>();
        arrlist.add("-m");
        arrlist.add("torch.backends.xeon.run_cpu");
        // Historical guard kept as is: strings of length <= 1 are treated as "no extra arguments".
        if (launcherArgs != null && launcherArgs.length() > 1) {
            for (String arg : launcherArgs.trim().split("\\s+")) {
                if (!arg.isEmpty()) {
                    arrlist.add(arg);
                }
            }
        }
        return arrlist;
    }

    /**
     * Probes whether the CPU launcher can be run by executing
     * {@code python -m torch.backends.xeon.run_cpu <launcherArgs> --no_python hostname} and
     * checking its exit code.
     *
     * @param launcherArgs extra launcher arguments, see {@link #launcherArgsToList(String)}
     * @return {@code true} if the probe process exited with status 0
     * @throws WorkerInitializationException if the probe process could not be started, or the
     *     wait for it was interrupted (the thread's interrupt flag is restored in that case)
     * @throws InterruptedException declared for interface compatibility; interruption is reported
     *     through {@link WorkerInitializationException}
     */
    public boolean isLauncherAvailable(String launcherArgs) throws WorkerInitializationException, InterruptedException {
        List<String> cmd = new ArrayList<>();
        cmd.add("python");
        cmd.addAll(launcherArgsToList(launcherArgs));
        cmd.add("--no_python");
        // Harmless command for the launcher to run; only the exit status matters.
        cmd.add("hostname");
        logger.debug("launcherAvailable cmdline: {}", cmd);
        try {
            Process processLauncher = Runtime.getRuntime().exec(cmd.toArray(new String[0]));
            return processLauncher.waitFor() == 0;
        }
        catch (InterruptedException e) {
            // Keep the interruption visible to the caller's thread even though it is wrapped.
            Thread.currentThread().interrupt();
            throw new WorkerInitializationException("Failed to start launcher", e);
        }
        catch (IOException e) {
            throw new WorkerInitializationException("Failed to start launcher", e);
        }
    }

    /**
     * Starts the backend worker process and blocks until it reports a successful start.
     *
     * <p>Models whose runtime type is {@code LSP} are started through the C++ backend; every
     * other runtime type is started through the Python backend.
     *
     * @param port port number used to create the worker's {@link Connector}
     * @param deviceIds value exported as {@code CUDA_VISIBLE_DEVICES} when the worker is launched
     *     through torchrun; may be {@code null}
     * @throws WorkerInitializationException if the process cannot be started, its output streams
     *     close before it reports a start, or it does not report a start in time
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void startWorker(int port, String deviceIds) throws WorkerInitializationException, InterruptedException {
        switch (this.model.getRuntimeType()) {
            case LSP: {
                logger.info("LSP startWorker");
                this.startWorkerCPP(port, "LSP", deviceIds);
                break;
            }
            default: {
                this.startWorkerPython(port, deviceIds);
            }
        }
    }

    /**
     * Builds the command line and environment for a Python backend worker and launches it.
     *
     * @param port port number used to create the worker's connector
     * @param deviceIds forwarded to {@link #attachRunner} when the model's parallel level is positive
     */
    private void startWorkerPython(int port, String deviceIds) throws WorkerInitializationException, InterruptedException {
        File workingDir = new File(configManager.getModelServerHome());
        setPort(port);

        File modelPath;
        try {
            modelPath = model.getModelDir();
            // Result intentionally unused: the call only checks that the path can be canonicalized.
            modelPath.getCanonicalFile();
        }
        catch (IOException e) {
            throw new WorkerInitializationException("Failed get TS home directory", e);
        }

        ArrayList<String> argl = new ArrayList<>();
        List<String> envp = new ArrayList<>();
        envp.addAll(Arrays.asList(EnvironmentUtils.getEnvString(
                workingDir.getAbsolutePath(),
                modelPath.getAbsolutePath(),
                model.getModelArchive().getManifest().getModel().getHandler())));

        int parallelLevel = model.getParallelLevel();
        if (parallelLevel > 0) {
            attachRunner(argl, envp, port, deviceIds);
        } else if (parallelLevel == 0) {
            argl.add(EnvironmentUtils.getPythonRunTime(model));
        }

        if (configManager.isCPULauncherEnabled()) {
            String launcherArgs = configManager.getCPULauncherArgs();
            if (isLauncherAvailable(launcherArgs)) {
                argl.addAll(launcherArgsToList(launcherArgs));
                if (numWorker > 1) {
                    argl.add("--ninstances");
                    argl.add(String.valueOf(numWorker));
                    argl.add("--rank");
                    argl.add(String.valueOf(currNumRunningWorkers));
                }
            } else {
                logger.warn("torch.backends.xeon.run_cpu is not available. Proceeding without worker core pinning. For better performance, please make sure torch.backends.xeon.run_cpu is available.");
            }
        }

        argl.add(new File(workingDir, "ts/model_service_worker.py").getAbsolutePath());
        argl.add("--sock-type");
        argl.add(connector.getSocketType());
        argl.add(connector.isUds() ? "--sock-name" : "--port");
        argl.add(connector.getSocketPath());
        argl.add("--metrics-config");
        argl.add(configManager.getMetricsConfigPath());

        // One start signal is awaited per parallel process, or a single one without parallelism.
        int expectedStartSignals = parallelLevel > 0 ? parallelLevel : 1;
        String[] envs = envp.toArray(new String[0]);
        launchAndAwaitStartup(argl, envs, modelPath, port, expectedStartSignals);
    }

    /**
     * Builds the command line and environment for a C++ backend worker and launches it.
     *
     * @param port port number used to create the worker's connector
     * @param runtimeType value passed to the backend as {@code --runtime_type}
     * @param deviceIds not used by this backend
     */
    private void startWorkerCPP(int port, String runtimeType, String deviceIds) throws WorkerInitializationException, InterruptedException {
        File workingDir = new File(configManager.getModelServerHome());
        setPort(port);

        File modelPath;
        try {
            modelPath = model.getModelDir().getCanonicalFile();
        }
        catch (IOException e) {
            throw new WorkerInitializationException("Failed get TS home directory", e);
        }

        File cppBackendBin = new File(workingDir, "ts/cpp/bin/model_worker_socket");
        File cppBackendLib = new File(workingDir, "ts/cpp/lib");
        if (!cppBackendBin.exists()) {
            throw new WorkerInitializationException("model_worker_socket not found");
        }
        if (!cppBackendLib.exists()) {
            throw new WorkerInitializationException("model_worker cpp library not found");
        }

        ArrayList<String> argl = new ArrayList<>();
        argl.add(cppBackendBin.getAbsolutePath());
        argl.add("--sock_type");
        argl.add(connector.getSocketType());
        argl.add(connector.isUds() ? "--sock_name" : "--port");
        argl.add(connector.getSocketPath());
        argl.add("--runtime_type");
        argl.add(runtimeType);
        argl.add("--model_dir");
        argl.add(modelPath.getAbsolutePath());
        if (ConfigManager.getInstance().getTsCppLogConfig() != null) {
            argl.add("--logger_config_path");
            argl.add(ConfigManager.getInstance().getTsCppLogConfig());
        }
        argl.add("--metrics_config_path");
        argl.add(configManager.getMetricsConfigPath());

        String[] envp = EnvironmentUtils.getCppEnvString(cppBackendLib.getAbsolutePath());
        launchAndAwaitStartup(argl, envp, modelPath, port, 1);
    }

    /**
     * Spawns the worker process, attaches the stdout/stderr reader threads and waits until the
     * readers have counted the startup latch down, or the startup timeout elapses.
     *
     * <p>If the worker has not reported success when this method exits, for any reason,
     * {@link #exit()} is called to clean up.
     *
     * @param argl full command line of the worker
     * @param envp environment of the worker in {@code NAME=value} form
     * @param modelPath working directory of the worker process
     * @param port port number, used only to name the reader threads
     * @param expectedStartSignals initial count of the startup latch
     * @throws WorkerInitializationException if the process cannot be spawned, the latch is
     *     released without success being reported, or the timeout elapses
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private void launchAndAwaitStartup(List<String> argl, String[] envp, File modelPath, int port, int expectedStartSignals)
            throws WorkerInitializationException, InterruptedException {
        try {
            latch = new CountDownLatch(expectedStartSignals);
            String[] args = argl.toArray(new String[0]);
            logger.debug("Worker cmdline: {}", argl);

            synchronized (this) {
                process = Runtime.getRuntime().exec(args, envp, modelPath);
                String threadName = "W-" + port + '-' + model.getModelVersionName().getVersionedModelName();
                errReader = new ReaderThread(threadName, process.getErrorStream(), true, this);
                outReader = new ReaderThread(threadName, process.getInputStream(), false, this);
                errReader.start();
                outReader.start();
            }

            if (!latch.await(WORKER_STARTUP_TIMEOUT_MINUTES, TimeUnit.MINUTES)) {
                throw new WorkerInitializationException("Backend worker startup time out.");
            }
            // The latch is also released when a reader stops, so success must be checked.
            if (!success) {
                throw new WorkerInitializationException("Backend stream closed.");
            }
        }
        catch (IOException e) {
            throw new WorkerInitializationException("Failed start worker process", e);
        }
        finally {
            if (!success) {
                exit();
            }
        }
    }

    /**
     * Prepends a {@code torchrun} invocation to the worker command line and adds the environment
     * variables that go with it.
     *
     * @param argl command line under construction; torchrun and its options are appended
     * @param envp environment under construction; {@code LOGLEVEL}, {@code OMP_NUM_THREADS} and,
     *     when {@code deviceIds} is non-null, {@code CUDA_VISIBLE_DEVICES} are appended
     * @param port port number, used as part of the rendezvous id
     * @param deviceIds value for {@code CUDA_VISIBLE_DEVICES}; may be {@code null}
     */
    private void attachRunner(ArrayList<String> argl, List<String> envp, int port, String deviceIds) {
        envp.add("LOGLEVEL=INFO");
        if (deviceIds != null) {
            envp.add("CUDA_VISIBLE_DEVICES=" + deviceIds);
        }
        ModelConfig.TorchRun torchRun = model.getModelArchive().getModelConfig().getTorchRun();
        envp.add(String.format("OMP_NUM_THREADS=%d", torchRun.getOmpNumberThreads()));

        argl.add("torchrun");
        argl.add("--nnodes");
        argl.add(String.valueOf(torchRun.getNnodes()));
        argl.add("--nproc-per-node");
        argl.add(String.valueOf(torchRun.getNprocPerNode()));
        argl.add("--log-dir");
        argl.add(ConfigManager.getInstance().getTorchRunLogDir());
        argl.add("--rdzv-backend");
        argl.add(torchRun.getRdzvBackend());
        if (torchRun.getRdzvEndpoint() != null) {
            argl.add("--rdzv-endpoint");
            argl.add(torchRun.getRdzvEndpoint());
        }
        argl.add("--rdzv-id");
        argl.add(String.format("%s_%d", model.getModelName(), port));
        if (torchRun.getMasterAddr() != null) {
            argl.add("--master-addr");
            argl.add(torchRun.getMasterAddr());
            argl.add("--master-port");
            argl.add(String.valueOf(torchRun.getMasterPort()));
        }
        argl.add("--max-restarts");
        argl.add(String.valueOf(1));
    }

    /**
     * Asks both reader threads, if present, to stop. Each reader checks the request before
     * reading its next line.
     */
    public synchronized void terminateIOStreams() {
        if (errReader != null) {
            logger.warn("terminateIOStreams() threadName={}", errReader.getName());
            errReader.terminate();
        }
        if (outReader != null) {
            logger.warn("terminateIOStreams() threadName={}", outReader.getName());
            outReader.terminate();
        }
    }

    /**
     * Forcibly destroys the worker process, cleans the connector and asks the reader threads to
     * stop. Does nothing if no process has been started.
     */
    public synchronized void exit() {
        if (process != null) {
            process.destroyForcibly();
            connector.clean();
            terminateIOStreams();
        }
    }

    /**
     * Returns the exit code of the worker process.
     *
     * @return the exit value, or {@code null} if no process was started or it is still alive
     */
    public synchronized Integer getExitValue() {
        if (process != null && !process.isAlive()) {
            return process.exitValue();
        }
        return null;
    }

    /**
     * Records whether the worker started successfully and counts the startup latch down once.
     *
     * @param success {@code true} when the worker reported a start, {@code false} when a reader
     *     thread stopped
     */
    public void setSuccess(boolean success) {
        this.success = success;
        latch.countDown();
    }

    /**
     * Returns the PID parsed from the worker's output.
     *
     * @return the reported PID, or {@code -1} if none has been seen
     */
    public synchronized int getPid() {
        return pid;
    }

    /**
     * Stores the worker PID.
     *
     * @param pid PID of the worker process
     */
    public synchronized void setPid(int pid) {
        this.pid = pid;
    }

    /** Creates the connector for the given port. */
    private synchronized void setPort(int port) {
        connector = new Connector(port);
    }

    /**
     * Thread that reads one output stream of the worker process line by line.
     *
     * <p>Metric lines are parsed and forwarded to the {@link MetricCache}. Every other line is
     * checked for the worker-started and PID markers and then written to the {@code MODEL_LOG}
     * logger, at WARN level for stderr and INFO level for stdout.
     */
    private static final class ReaderThread
    extends Thread {
        private static final Pattern METRIC_PATTERN = Pattern.compile("^(INFO > )?(\\[METRICS])(.*)");
        private static final Pattern WORKER_START_PATTERN = Pattern.compile("(.*)(INFO > )?(Torch worker started.)$");
        private static final Pattern WORKER_PID_PATTERN = Pattern.compile("^(INFO > )?(\\[PID])(\\d+)$");
        private static final Logger loggerModelOutput = LoggerFactory.getLogger("MODEL_LOG");
        private final MetricCache metricCache;
        private final InputStream is;
        private final boolean error;
        private final WorkerLifeCycle lifeCycle;
        private final AtomicBoolean isRunning = new AtomicBoolean(true);

        /**
         * @param name base thread name; {@code -stderr} or {@code -stdout} is appended
         * @param is stream to read
         * @param error {@code true} if {@code is} is the process's stderr
         * @param lifeCycle owner that is notified about start, PID and end of stream
         */
        public ReaderThread(String name, InputStream is, boolean error, WorkerLifeCycle lifeCycle) {
            super(name + (error ? "-stderr" : "-stdout"));
            this.is = is;
            this.error = error;
            this.lifeCycle = lifeCycle;
            this.metricCache = MetricCache.getInstance();
        }

        /** Requests the read loop to stop before it reads its next line. */
        public void terminate() {
            isRunning.set(false);
        }

        /**
         * Reads the stream until it ends or {@link #terminate()} is called. When the loop ends,
         * for whatever reason, the owner is told via {@code setSuccess(false)} and the stream is
         * closed.
         */
        @Override
        public void run() {
            try (Scanner scanner = new Scanner(is, StandardCharsets.UTF_8.name())) {
                while (isRunning.get() && scanner.hasNext()) {
                    String result = scanner.nextLine();

                    Matcher matcher = METRIC_PATTERN.matcher(result);
                    if (matcher.matches()) {
                        // Metric lines are consumed here and not echoed to the model log.
                        recordMetric(result, matcher);
                        continue;
                    }

                    if (WORKER_START_PATTERN.matcher(result).matches()) {
                        lifeCycle.setSuccess(true);
                    } else {
                        matcher = WORKER_PID_PATTERN.matcher(result);
                        if (matcher.matches()) {
                            lifeCycle.setPid(Integer.parseInt(matcher.group(3)));
                        }
                    }

                    if (error) {
                        loggerModelOutput.warn(result);
                    } else {
                        loggerModelOutput.info(result);
                    }
                }
            }
            catch (Exception e) {
                logger.error("Couldn't create scanner - {}", getName(), e);
            }
            finally {
                logger.info("Stopped Scanner - {}", getName());
                lifeCycle.setSuccess(false);
                try {
                    is.close();
                }
                catch (IOException e) {
                    logger.error("Failed to close stream for thread {}", getName(), e);
                }
            }
        }

        /**
         * Parses a metric line and updates the matching backend metric. Unknown metrics are
         * registered first when auto-detection is enabled and skipped otherwise.
         *
         * @param line the complete output line, used for logging
         * @param matcher a successful {@link #METRIC_PATTERN} match of {@code line}
         */
        private void recordMetric(String line, Matcher matcher) {
            logger.info("result={}, pattern={}", line, matcher.group(2));
            Metric parsedMetric = Metric.parse(matcher.group(3));
            if (parsedMetric == null) {
                logger.error("Failed to parse metrics line: \"{}\".", line);
                return;
            }
            try {
                if (metricCache.getMetricBackend(parsedMetric.getMetricName()) == null) {
                    if (!lifeCycle.configManager.isModelMetricsAutoDetectEnabled()) {
                        return;
                    }
                    logger.info("Registering auto detected backend metric: {}", parsedMetric);
                    metricCache.addAutoDetectMetricBackend(parsedMetric);
                }
                List<String> dimensionValues = parsedMetric.getDimensionValues();
                dimensionValues.add(parsedMetric.getHostName());
                metricCache.getMetricBackend(parsedMetric.getMetricName())
                        .addOrUpdate(dimensionValues, parsedMetric.getRequestId(), Double.parseDouble(parsedMetric.getValue()));
            }
            catch (Exception e) {
                // Placeholder for the name; the trailing throwable is logged with its stack trace.
                logger.error("Failed to update backend metric {}: ", parsedMetric.getMetricName(), e);
            }
        }
    }
}

