/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms.wlm;

import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.amazonaws.ml.mms.util.codec.ModelRequestEncoder;
import com.amazonaws.ml.mms.util.codec.ModelResponseDecoder;
import com.amazonaws.ml.mms.util.messages.InputParameter;
import com.amazonaws.ml.mms.util.messages.ModelWorkerResponse;
import com.amazonaws.ml.mms.util.messages.RequestInput;
import com.amazonaws.ml.mms.util.messages.WorkerCommands;
import com.amazonaws.ml.mms.wlm.BatchAggregator;
import com.amazonaws.ml.mms.wlm.Job;
import com.amazonaws.ml.mms.wlm.Model;
import com.amazonaws.ml.mms.wlm.ModelManager;
import com.amazonaws.ml.mms.wlm.WorkerInitializationException;
import com.amazonaws.ml.mms.wlm.WorkerLifeCycle;
import com.amazonaws.ml.mms.wlm.WorkerState;
import com.amazonaws.ml.mms.wlm.WorkerStateListener;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WorkerThread
implements Runnable {
    static final Logger logger = LoggerFactory.getLogger(WorkerThread.class);
    private static final int[] BACK_OFF = new int[]{0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597};
    static final long WORKER_TIMEOUT = 2L;
    static final ModelRequestEncoder ENCODER = new ModelRequestEncoder();
    private ConfigManager configManager;
    private EventLoopGroup backendEventGroup;
    private int port;
    private Model model;
    private Channel backendChannel;
    private AtomicBoolean running = new AtomicBoolean(true);
    private int backoffIdx;
    private BatchAggregator aggregator;
    private WorkerStateListener listener;
    ArrayBlockingQueue<ModelWorkerResponse> replies;
    private int gpuId;
    private long memory;
    private long startTime;
    private Thread currentThread;
    private String workerId;
    private WorkerState state;
    private WorkerLifeCycle lifeCycle;

    public WorkerState getState() {
        return this.state;
    }

    public WorkerThread(ConfigManager configManager, EventLoopGroup backendEventGroup, int port, int gpuId, Model model, BatchAggregator aggregator, WorkerStateListener listener) {
        this.workerId = String.valueOf(port);
        this.configManager = configManager;
        this.backendEventGroup = backendEventGroup;
        this.port = port;
        this.model = model;
        this.aggregator = aggregator;
        this.gpuId = gpuId;
        this.listener = listener;
        this.startTime = System.currentTimeMillis();
        this.lifeCycle = new WorkerLifeCycle(configManager, model);
        this.replies = new ArrayBlockingQueue(1);
    }

    /*
     * Exception decompiling
     */
    @Override
    public void run() {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Started 2 blocks at once
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    public String getWorkerId() {
        return this.workerId;
    }

    public long getMemory() {
        return this.memory;
    }

    public void setMemory(long memory) {
        this.memory = memory;
    }

    private void connect() throws WorkerInitializationException, InterruptedException {
        if (!this.configManager.isDebug()) {
            this.lifeCycle.startWorker(this.port);
        }
        String modelName = this.model.getModelName();
        this.setState(WorkerState.WORKER_STARTED);
        CountDownLatch latch = new CountDownLatch(1);
        try {
            Bootstrap b = new Bootstrap();
            ((Bootstrap)((Bootstrap)b.group(this.backendEventGroup)).channel(NettyUtils.getClientChannel())).handler(new ChannelInitializer<Channel>(){

                @Override
                public void initChannel(Channel ch) {
                    ChannelPipeline p = ch.pipeline();
                    p.addLast(ENCODER);
                    p.addLast(new ModelResponseDecoder());
                    p.addLast(new WorkerHandler());
                }
            });
            SocketAddress address = NettyUtils.getSocketAddress(this.port);
            logger.info("Connecting to: {}", (Object)address);
            this.backendChannel = b.connect(address).sync().channel();
            this.backendChannel.closeFuture().addListener(future -> {
                latch.countDown();
                logger.info("{} Worker disconnected.", (Object)this.getWorkerId());
                this.running.set(false);
                this.currentThread.interrupt();
            });
            this.backendChannel.newSucceededFuture().addListener(future -> {
                RequestInput input = new RequestInput(UUID.randomUUID().toString());
                if (this.gpuId >= 0) {
                    input.addParameter(new InputParameter("gpu", String.valueOf(this.gpuId)));
                }
                Job job = new Job(null, modelName, WorkerCommands.LOAD, input);
                this.model.addJob(this.workerId, job);
                latch.countDown();
            });
            if (!latch.await(2L, TimeUnit.MINUTES)) {
                throw new WorkerInitializationException("Worker failed to initialize within 2 mins");
            }
            this.running.set(true);
        }
        catch (Throwable t) {
            if (t instanceof IOException) {
                throw new WorkerInitializationException("Failed to connect to worker.", t);
            }
            throw t;
        }
    }

    public boolean isRunning() {
        return this.running.get();
    }

    public int getGpuId() {
        return this.gpuId;
    }

    public long getStartTime() {
        return this.startTime;
    }

    public int getPid() {
        return this.lifeCycle.getPid();
    }

    public void shutdown() {
        this.running.set(false);
        this.setState(WorkerState.WORKER_TERMINATED);
        if (this.backendChannel != null) {
            this.backendChannel.close();
        }
        if (this.currentThread != null) {
            this.currentThread.interrupt();
            this.aggregator.sendError(null, "Worker scaled down.");
            this.model.removeJobQueue(this.workerId);
        }
    }

    void setState(WorkerState newState) {
        this.listener.notifyChangeState(this.model.getModelName(), newState);
        if (this.state != WorkerState.WORKER_TERMINATED) {
            this.state = newState;
        }
    }

    void retry() {
        if (this.state == WorkerState.WORKER_TERMINATED) {
            return;
        }
        ModelManager manager = ModelManager.getInstance();
        if (this.backoffIdx < BACK_OFF.length - 1) {
            ++this.backoffIdx;
        }
        manager.getScheduler().schedule(() -> manager.submitTask(this), (long)BACK_OFF[this.backoffIdx], TimeUnit.SECONDS);
        logger.info("Retry worker: {} in {} seconds.", (Object)this.workerId, (Object)BACK_OFF[this.backoffIdx]);
    }

    @ChannelHandler.Sharable
    private class WorkerHandler
    extends SimpleChannelInboundHandler<ModelWorkerResponse> {
        private WorkerHandler() {
        }

        @Override
        public void channelRead0(ChannelHandlerContext ctx, ModelWorkerResponse msg) {
            if (!WorkerThread.this.replies.offer(msg)) {
                throw new IllegalStateException("Reply queue is full.");
            }
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            logger.error("Unknown exception", cause);
            ctx.close();
        }
    }
}

