/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.archive.model;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Per-model settings read from the model's YAML configuration.
 *
 * <p>Each field starts at the default given by its initializer (or the Java default of 0 /
 * {@code false} / {@code null}). {@link #build(Map)} overrides only recognized keys whose values
 * have the expected type. Wrong types, and values rejected by the setters, are reported with a
 * warning instead of an exception, so a partially invalid config still loads.
 */
public class ModelConfig {
    private static final Logger logger = LoggerFactory.getLogger(ModelConfig.class);

    private int minWorkers;
    private int maxWorkers;
    private int batchSize;
    private int maxBatchDelay;
    private int responseTimeout = 120;
    private DeviceType deviceType = DeviceType.NONE;
    /** {@code null} until configured, and set back to {@code null} if a configured list is invalid. */
    private List<Integer> deviceIds;
    private int parallelLevel;
    private ParallelType parallelType = ParallelType.NONE;
    /** {@code null} unless {@link #build(Map)} saw a {@code torchrun} map. */
    private TorchRun torchRun;
    private int maxRetryTimeoutInSec = 300;
    private long clientTimeoutInMills;
    private int jobQueueSize;
    private boolean useJobTicket;
    private long sequenceMaxIdleMSec;
    private int maxSequenceJobQueueSize = 1;
    private int maxNumSequence = 1;
    private boolean continuousBatching;
    private boolean useVenv;

    /**
     * Creates a configuration from the top-level entries of a model's YAML config.
     *
     * <p>Unrecognized keys are ignored without a warning. A recognized key whose value has the
     * wrong type is logged, and the field keeps its default. Values of the right type are passed
     * to the matching setter, which may clamp or reject them.
     *
     * @param yamlMap parsed YAML entries; {@code null} is treated like an empty map
     * @return a new configuration, never {@code null}
     */
    public static ModelConfig build(Map<String, Object> yamlMap) {
        ModelConfig modelConfig = new ModelConfig();
        if (yamlMap == null) {
            return modelConfig;
        }
        yamlMap.forEach((k, v) -> {
            if (k == null) {
                // A String switch throws NPE on null, and a null key cannot name any setting.
                return;
            }
            switch (k) {
                case "minWorkers":
                    if (v instanceof Integer) {
                        modelConfig.setMinWorkers((Integer) v);
                    } else {
                        logger.warn("Invalid minWorkers: {}, should be integer", v);
                    }
                    break;
                case "maxWorkers":
                    if (v instanceof Integer) {
                        modelConfig.setMaxWorkers((Integer) v);
                    } else {
                        logger.warn("Invalid maxWorkers: {}, should be integer", v);
                    }
                    break;
                case "batchSize":
                    if (v instanceof Integer) {
                        modelConfig.setBatchSize((Integer) v);
                    } else {
                        logger.warn("Invalid batchSize: {}, should be integer", v);
                    }
                    break;
                case "maxBatchDelay":
                    if (v instanceof Integer) {
                        modelConfig.setMaxBatchDelay((Integer) v);
                    } else {
                        logger.warn("Invalid maxBatchDelay: {}, should be integer", v);
                    }
                    break;
                case "responseTimeout":
                    if (v instanceof Integer) {
                        modelConfig.setResponseTimeout((Integer) v);
                    } else {
                        logger.warn("Invalid responseTimeout: {}, should be integer", v);
                    }
                    break;
                case "deviceType":
                    if (v instanceof String) {
                        modelConfig.setDeviceType((String) v);
                    } else {
                        logger.warn("Invalid deviceType: {}, should be cpu, or gpu", v);
                    }
                    break;
                case "parallelType":
                    if (v instanceof String) {
                        modelConfig.setParallelMode((String) v);
                    } else {
                        logger.warn("Invalid parallelType: {}, should be pp, tp, or pptp", v);
                    }
                    break;
                case "deviceIds":
                    if (v instanceof List) {
                        modelConfig.setDeviceIds((List<?>) v);
                    } else {
                        logger.warn("Invalid deviceIds: {}, should be list of integer", v);
                    }
                    break;
                case "torchrun":
                    if (v instanceof Map) {
                        modelConfig.torchRun = TorchRun.build((Map<?, ?>) v);
                        // parallelLevel mirrors torchrun's nproc-per-node setting.
                        modelConfig.setParallelLevel(modelConfig.torchRun.getNprocPerNode());
                    } else {
                        logger.warn("Invalid torchrun: {}, should be Torchrun parameters", v);
                    }
                    break;
                case "maxRetryTimeoutInSec":
                    if (v instanceof Integer) {
                        modelConfig.setMaxRetryTimeoutInSec((Integer) v);
                    } else {
                        logger.warn("Invalid maxRetryTimeoutInSec: {}, should be integer", v);
                    }
                    break;
                case "clientTimeoutInMills":
                    // YAML parsers return Long for integers beyond the int range; the field is a long.
                    if (isWholeNumber(v)) {
                        modelConfig.setClientTimeoutInMills(((Number) v).longValue());
                    } else {
                        logger.warn("Invalid clientTimeoutInMills: {}, should be positive long", v);
                    }
                    break;
                case "jobQueueSize":
                    if (v instanceof Integer) {
                        modelConfig.setJobQueueSize((Integer) v);
                    } else {
                        logger.warn("Invalid jobQueueSize: {}, should be positive int", v);
                    }
                    break;
                case "useJobTicket":
                    if (v instanceof Boolean) {
                        modelConfig.setUseJobTicket((Boolean) v);
                    } else {
                        logger.warn("Invalid useJobTicket: {}, should be true or false", v);
                    }
                    break;
                case "sequenceMaxIdleMSec":
                    // Same reasoning as clientTimeoutInMills: accept Long as well as Integer.
                    if (isWholeNumber(v)) {
                        modelConfig.setSequenceMaxIdleMSec(((Number) v).longValue());
                    } else {
                        logger.warn("Invalid sequenceMaxIdleMSec: {}, should be positive long", v);
                    }
                    break;
                case "maxSequenceJobQueueSize":
                    if (v instanceof Integer) {
                        modelConfig.setMaxSequenceJobQueueSize((Integer) v);
                    } else {
                        logger.warn("Invalid maxSequenceJobQueueSize: {}, should be positive int", v);
                    }
                    break;
                case "maxNumSequence":
                    if (v instanceof Integer) {
                        modelConfig.setMaxNumSequence((Integer) v);
                    } else {
                        logger.warn("Invalid maxNumSequence: {}, should be positive int", v);
                    }
                    break;
                case "continuousBatching":
                    if (v instanceof Boolean) {
                        modelConfig.setContinuousBatching((Boolean) v);
                    } else {
                        logger.warn("Invalid continuousBatching: {}, should be true or false", v);
                    }
                    break;
                case "useVenv":
                    if (v instanceof Boolean) {
                        modelConfig.setUseVenv((Boolean) v);
                    } else {
                        logger.warn("Invalid useVenv: {}, should be true or false", v);
                    }
                    break;
                default:
                    // NOTE(review): unrecognized keys are deliberately not warned about here
                    // (TorchRun.build does warn); presumably other consumers read extra keys
                    // from the same YAML. Confirm before adding a warning.
                    break;
            }
        });
        return modelConfig;
    }

    /** Returns {@code true} if {@code v} is a boxed {@code Integer} or {@code Long}. */
    private static boolean isWholeNumber(Object v) {
        return v instanceof Integer || v instanceof Long;
    }

    public int getMinWorkers() {
        return this.minWorkers;
    }

    /** Sets the minimum worker count; values below 1 are raised to 1. */
    public void setMinWorkers(int minWorkers) {
        this.minWorkers = Math.max(1, minWorkers);
    }

    public int getMaxWorkers() {
        return this.maxWorkers;
    }

    /** Sets the maximum worker count; a negative value is logged and ignored. */
    public void setMaxWorkers(int maxWorkers) {
        if (maxWorkers < 0) {
            logger.warn("Invalid maxWorkers:{}", maxWorkers);
            return;
        }
        this.maxWorkers = maxWorkers;
    }

    public int getBatchSize() {
        return this.batchSize;
    }

    /** Sets the batch size; values below 1 are raised to 1. */
    public void setBatchSize(int batchSize) {
        this.batchSize = Math.max(1, batchSize);
    }

    public int getMaxBatchDelay() {
        return this.maxBatchDelay;
    }

    /** Sets the maximum batch delay; a negative value is logged and ignored. */
    public void setMaxBatchDelay(int maxBatchDelay) {
        if (maxBatchDelay < 0) {
            logger.warn("Invalid maxBatchDelay:{}", maxBatchDelay);
            return;
        }
        this.maxBatchDelay = maxBatchDelay;
    }

    public int getResponseTimeout() {
        return this.responseTimeout;
    }

    /** Sets the response timeout; a non-positive value is logged and ignored. */
    public void setResponseTimeout(int responseTimeout) {
        if (responseTimeout <= 0) {
            logger.warn("Invalid responseTimeout:{}", responseTimeout);
            return;
        }
        this.responseTimeout = responseTimeout;
    }

    /** Returns the configured device ids, or {@code null} if none were (validly) configured. */
    public List<Integer> getDeviceIds() {
        return this.deviceIds;
    }

    /**
     * Sets the device ids from a list whose elements must all be {@link Integer}s.
     *
     * <p>If the list is {@code null} or contains a non-integer element, a warning is logged and
     * the device ids are cleared to {@code null}. A partially valid list is never stored.
     *
     * @param deviceIds candidate device ids
     */
    public void setDeviceIds(List<?> deviceIds) {
        if (deviceIds == null) {
            logger.warn("Invalid deviceIds:{}, should be list of integer", deviceIds);
            this.deviceIds = null;
            return;
        }
        List<Integer> ids = new ArrayList<>(deviceIds.size());
        for (Object id : deviceIds) {
            if (!(id instanceof Integer)) {
                logger.warn("Invalid deviceIds:{}, should be list of integer", id);
                this.deviceIds = null;
                return;
            }
            ids.add((Integer) id);
        }
        this.deviceIds = ids;
    }

    public int getParallelLevel() {
        return this.parallelLevel;
    }

    /** Sets the parallel level; a negative value is logged and replaced by 0. */
    public void setParallelLevel(int parallelLevel) {
        if (parallelLevel < 0) {
            logger.warn("Invalid parallelLevel:{}, set as 0", parallelLevel);
            // Match the log message: a negative level really becomes 0.
            this.parallelLevel = 0;
            return;
        }
        this.parallelLevel = parallelLevel;
    }

    /** Sets the parallel type by name; unknown names resolve to {@link ParallelType#NONE}. */
    public void setParallelMode(String parallelMode) {
        this.parallelType = ParallelType.get(parallelMode);
    }

    public ParallelType getParallelType() {
        return this.parallelType;
    }

    /** Sets the device type by name; unknown names resolve to {@link DeviceType#NONE}. */
    public void setDeviceType(String deviceType) {
        this.deviceType = DeviceType.get(deviceType);
    }

    public DeviceType getDeviceType() {
        return this.deviceType;
    }

    /** Returns the torchrun settings, or {@code null} if the config had no {@code torchrun} map. */
    public TorchRun getTorchRun() {
        return this.torchRun;
    }

    public int getMaxRetryTimeoutInSec() {
        return this.maxRetryTimeoutInSec;
    }

    /** Sets the max retry timeout in seconds; negative values are raised to 0. */
    public void setMaxRetryTimeoutInSec(int maxRetryTimeoutInSec) {
        this.maxRetryTimeoutInSec = Math.max(0, maxRetryTimeoutInSec);
    }

    public long getClientTimeoutInMills() {
        return this.clientTimeoutInMills;
    }

    /** Sets the client timeout in milliseconds; negative values are raised to 0. */
    public void setClientTimeoutInMills(long clientTimeoutInMills) {
        this.clientTimeoutInMills = Math.max(0L, clientTimeoutInMills);
    }

    public int getJobQueueSize() {
        return this.jobQueueSize;
    }

    /** Sets the job queue size; negative values are raised to 0. */
    public void setJobQueueSize(int jobQueueSize) {
        this.jobQueueSize = Math.max(0, jobQueueSize);
    }

    public boolean isUseJobTicket() {
        return this.useJobTicket;
    }

    public void setUseJobTicket(boolean useJobTicket) {
        this.useJobTicket = useJobTicket;
    }

    public long getSequenceMaxIdleMSec() {
        return this.sequenceMaxIdleMSec;
    }

    /** Sets the maximum sequence idle time in milliseconds; negative values are raised to 0. */
    public void setSequenceMaxIdleMSec(long sequenceMaxIdleMSec) {
        this.sequenceMaxIdleMSec = Math.max(0L, sequenceMaxIdleMSec);
    }

    public int getMaxSequenceJobQueueSize() {
        return this.maxSequenceJobQueueSize;
    }

    /** Sets the per-sequence job queue size; values below 1 are raised to 1. */
    public void setMaxSequenceJobQueueSize(int maxsequenceJobQueueSize) {
        this.maxSequenceJobQueueSize = Math.max(1, maxsequenceJobQueueSize);
    }

    public boolean isContinuousBatching() {
        return this.continuousBatching;
    }

    public void setContinuousBatching(boolean continuousBatching) {
        this.continuousBatching = continuousBatching;
    }

    public int getMaxNumSequence() {
        return this.maxNumSequence;
    }

    /** Sets the maximum number of sequences; values below 1 are raised to 1. */
    public void setMaxNumSequence(int maxNumSequence) {
        this.maxNumSequence = Math.max(1, maxNumSequence);
    }

    public boolean getUseVenv() {
        return this.useVenv;
    }

    public void setUseVenv(boolean useVenv) {
        this.useVenv = useVenv;
    }

    /** Device type named in the config; {@link #NONE} is the default when none is configured. */
    public static enum DeviceType {
        NONE(""),
        CPU("cpu"),
        GPU("gpu");

        /** Lower-case config name of this device type. */
        private final String type;

        private DeviceType(String type) {
            this.type = type;
        }

        public String getDeviceType() {
            return this.type;
        }

        /**
         * Looks up a device type by its config name, ignoring case.
         *
         * @param deviceType name such as {@code "cpu"} or {@code "gpu"}; may be {@code null}
         * @return the matching constant, or {@link #NONE} (after logging a warning) if none matches
         */
        public static DeviceType get(String deviceType) {
            // equalsIgnoreCase is locale-independent and null-safe, unlike toLowerCase() + equals.
            for (DeviceType candidate : DeviceType.values()) {
                if (candidate.type.equalsIgnoreCase(deviceType)) {
                    return candidate;
                }
            }
            logger.warn("Invalid DeviceType:{}", deviceType);
            return NONE;
        }
    }

    /** Parallelism type named in the config; {@link #NONE} is the default when none is configured. */
    public static enum ParallelType {
        NONE(""),
        PP("pp"),
        TP("tp"),
        PPTP("pptp");

        /** Lower-case config name of this parallel type. */
        private final String type;

        private ParallelType(String type) {
            this.type = type;
        }

        public String getParallelType() {
            return this.type;
        }

        /**
         * Looks up a parallel type by its config name, ignoring case.
         *
         * @param parallelType name such as {@code "pp"}, {@code "tp"} or {@code "pptp"}; may be
         *     {@code null}
         * @return the matching constant, or {@link #NONE} (after logging a warning) if none matches
         */
        public static ParallelType get(String parallelType) {
            // equalsIgnoreCase is locale-independent and null-safe, unlike toLowerCase() + equals.
            for (ParallelType candidate : ParallelType.values()) {
                if (candidate.type.equalsIgnoreCase(parallelType)) {
                    return candidate;
                }
            }
            logger.warn("Invalid ParallelType:{}", parallelType);
            return NONE;
        }
    }

    /**
     * Launcher options parsed from the {@code torchrun} section of the model config.
     *
     * <p>Setters that validate their input fall back to the field's default (and log a warning)
     * when given an out-of-range value.
     */
    public static class TorchRun {
        private static final int DEFAULT_NNODES = 1;
        private static final int DEFAULT_NPROC_PER_NODE = 1;
        private static final int DEFAULT_MONITOR_INTERVAL = 5;
        private static final int DEFAULT_NODE_RANK = 0;
        private static final int DEFAULT_OMP_NUMBER_THREADS = 1;

        private int nnodes = DEFAULT_NNODES;
        private int nprocPerNode = DEFAULT_NPROC_PER_NODE;
        private String rdzvId;
        private String rdzvEndpoint;
        private String rdzvBackend = "c10d";
        private String rdzvConf;
        private int monitorInterval = DEFAULT_MONITOR_INTERVAL;
        private int nodeRank = DEFAULT_NODE_RANK;
        private String masterAddr;
        private int masterPort;
        private int ompNumberThreads = DEFAULT_OMP_NUMBER_THREADS;

        /**
         * Creates torchrun settings from the entries of the {@code torchrun} YAML map.
         *
         * <p>Unsupported or non-String keys are logged and skipped. A supported key with a value
         * of the wrong type is logged, and the field keeps its default.
         *
         * @param torchRunMap entries of the {@code torchrun} section; {@code null} yields defaults
         * @return a new {@code TorchRun}, never {@code null}
         */
        public static TorchRun build(Map<?, ?> torchRunMap) {
            TorchRun torchRun = new TorchRun();
            if (torchRunMap == null) {
                return torchRun;
            }
            torchRunMap.forEach((k, v) -> {
                if (!(k instanceof String)) {
                    // YAML permits non-string (and null) keys; none of them name a torchrun option.
                    logger.warn("unsupported parameter {}", k);
                    return;
                }
                switch ((String) k) {
                    case "nnodes":
                        if (v instanceof Integer) {
                            torchRun.setNnodes((Integer) v);
                        } else {
                            logger.warn("Invalid torchrun.nnodes:{}, reset to 1", v);
                        }
                        break;
                    case "nproc-per-node":
                        if (v instanceof Integer) {
                            torchRun.setNprocPerNode((Integer) v);
                        } else {
                            logger.warn("Invalid torchrun.nproc-per-node:{}, reset to 1", v);
                        }
                        break;
                    case "rdzv-backend":
                        if (v instanceof String) {
                            torchRun.setRdzvBackend((String) v);
                        } else {
                            logger.warn("Invalid torchrun.rdzv-backend:{}, reset to c10d", v);
                        }
                        break;
                    case "rdzv-endpoint":
                        if (v instanceof String) {
                            torchRun.setRdzvEndpoint((String) v);
                        } else {
                            logger.warn("Invalid torchrun.rdzv-endpoint:{}", v);
                        }
                        break;
                    case "rdzv-conf":
                        if (v instanceof String) {
                            torchRun.setRdzvConf((String) v);
                        } else {
                            logger.warn("Invalid torchrun.rdzv-conf:{}", v);
                        }
                        break;
                    case "monitor-interval":
                        if (v instanceof Integer) {
                            torchRun.setMonitorInterval((Integer) v);
                        } else {
                            logger.warn("Invalid torchrun.monitor-interval:{}, reset to 5", v);
                        }
                        break;
                    case "node-rank":
                        if (v instanceof Integer) {
                            torchRun.setNodeRank((Integer) v);
                        } else {
                            logger.warn("Invalid torchrun.node-rank:{}, reset to 0", v);
                        }
                        break;
                    case "OMP_NUMBER_THREADS":
                        if (v instanceof Integer) {
                            torchRun.setOmpNumberThreads((Integer) v);
                        } else {
                            logger.warn("Invalid OMP_NUMBER_THREADS:{}, reset to 1", v);
                        }
                        break;
                    default:
                        logger.warn("unsupported parameter {}", k);
                        break;
                }
            });
            return torchRun;
        }

        public int getNnodes() {
            return this.nnodes;
        }

        /** Sets the node count; a non-positive value is logged and replaced by the default 1. */
        public void setNnodes(int nnodes) {
            if (nnodes <= 0) {
                logger.warn("Invalid torchrun.nnodes:{}, reset to 1", nnodes);
                this.nnodes = DEFAULT_NNODES;
                return;
            }
            this.nnodes = nnodes;
        }

        public int getNprocPerNode() {
            return this.nprocPerNode;
        }

        /** Sets processes per node; a non-positive value is logged and replaced by the default 1. */
        public void setNprocPerNode(int nprocPerNode) {
            if (nprocPerNode <= 0) {
                logger.warn("Invalid torchrun.nproc-per-node:{}, reset to 1", nprocPerNode);
                this.nprocPerNode = DEFAULT_NPROC_PER_NODE;
                return;
            }
            this.nprocPerNode = nprocPerNode;
        }

        public String getRdzvId() {
            return this.rdzvId;
        }

        public void setRdzvId(String rdzvId) {
            this.rdzvId = rdzvId;
        }

        public String getRdzvEndpoint() {
            return this.rdzvEndpoint;
        }

        public void setRdzvEndpoint(String rdzvEndpoint) {
            this.rdzvEndpoint = rdzvEndpoint;
        }

        public String getRdzvBackend() {
            return this.rdzvBackend;
        }

        public void setRdzvBackend(String rdzvBackend) {
            this.rdzvBackend = rdzvBackend;
        }

        public String getRdzvConf() {
            return this.rdzvConf;
        }

        public void setRdzvConf(String rdzvConf) {
            this.rdzvConf = rdzvConf;
        }

        public int getMonitorInterval() {
            return this.monitorInterval;
        }

        /** Sets the monitor interval; a non-positive value is logged and replaced by the default 5. */
        public void setMonitorInterval(int monitorInterval) {
            if (monitorInterval <= 0) {
                logger.warn("Invalid torchrun.monitor-interval:{}, reset to 5", monitorInterval);
                this.monitorInterval = DEFAULT_MONITOR_INTERVAL;
                return;
            }
            this.monitorInterval = monitorInterval;
        }

        public int getNodeRank() {
            return this.nodeRank;
        }

        /** Sets the node rank; a negative value is logged and replaced by the default 0. */
        public void setNodeRank(int nodeRank) {
            if (nodeRank < 0) {
                logger.warn("Invalid torchrun.node-rank:{}, reset to 0", nodeRank);
                this.nodeRank = DEFAULT_NODE_RANK;
                return;
            }
            this.nodeRank = nodeRank;
        }

        public String getMasterAddr() {
            return this.masterAddr;
        }

        public void setMasterAddr(String masterAddr) {
            this.masterAddr = masterAddr;
        }

        public int getMasterPort() {
            return this.masterPort;
        }

        public void setMasterPort(int masterPort) {
            this.masterPort = masterPort;
        }

        public int getOmpNumberThreads() {
            return this.ompNumberThreads;
        }

        /** Sets the OMP thread count; a value below 1 is logged and replaced by the default 1. */
        public void setOmpNumberThreads(int ompNumberThreads) {
            if (ompNumberThreads < 1) {
                logger.warn("Invalid OMP_NUMBER_THREADS:{}, reset to 1", ompNumberThreads);
                this.ompNumberThreads = DEFAULT_OMP_NUMBER_THREADS;
                return;
            }
            this.ompNumberThreads = ompNumberThreads;
        }
    }
}

