/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms.wlm;

import com.amazonaws.ml.mms.metrics.Dimension;
import com.amazonaws.ml.mms.metrics.Metric;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.Connector;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.amazonaws.ml.mms.util.codec.ModelRequestEncoder;
import com.amazonaws.ml.mms.util.codec.ModelResponseDecoder;
import com.amazonaws.ml.mms.util.messages.BaseModelRequest;
import com.amazonaws.ml.mms.util.messages.InputParameter;
import com.amazonaws.ml.mms.util.messages.ModelWorkerResponse;
import com.amazonaws.ml.mms.util.messages.RequestInput;
import com.amazonaws.ml.mms.util.messages.WorkerCommands;
import com.amazonaws.ml.mms.wlm.BatchAggregator;
import com.amazonaws.ml.mms.wlm.Job;
import com.amazonaws.ml.mms.wlm.Model;
import com.amazonaws.ml.mms.wlm.ModelManager;
import com.amazonaws.ml.mms.wlm.WorkerInitializationException;
import com.amazonaws.ml.mms.wlm.WorkerLifeCycle;
import com.amazonaws.ml.mms.wlm.WorkerState;
import com.amazonaws.ml.mms.wlm.WorkerStateListener;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runnable that drives a single backend worker for one {@link Model}.
 *
 * <p>Each execution of {@link #run()} connects to the backend on the configured port, then
 * repeatedly pulls a request from the {@link BatchAggregator}, writes it to the backend channel,
 * waits for the matching {@link ModelWorkerResponse} and hands it back to the aggregator. When the
 * loop ends for any reason the worker is marked stopped, the worker life cycle is asked to exit and
 * a retry is scheduled (unless the worker was scaled down).
 *
 * <p>Instances are touched from several threads: the thread executing {@link #run()}, the Netty
 * event loop (channel listeners and {@link WorkerHandler}) and whichever thread calls
 * {@link #shutdown()}.
 */
public class WorkerThread implements Runnable {
    static final Logger logger = LoggerFactory.getLogger(WorkerThread.class);
    private static final org.apache.log4j.Logger loggerMmsMetrics = org.apache.log4j.Logger.getLogger("MMS_METRICS");

    /** Delays, in seconds, applied between restarts; indexed by {@link #backoffIdx}. */
    private static final int[] BACK_OFF = {0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597};

    /** Minutes {@link #connect()} waits for the backend connection latch before giving up. */
    static final long WORKER_TIMEOUT = 2L;
    static final ModelRequestEncoder ENCODER = new ModelRequestEncoder();

    private final Metric workerLoadTime;
    private final ConfigManager configManager;
    private final EventLoopGroup backendEventGroup;
    private final int port;
    private final Model model;
    private Channel backendChannel;
    private final AtomicBoolean running = new AtomicBoolean(true);
    private int backoffIdx;
    private final BatchAggregator aggregator;
    private final WorkerStateListener listener;

    /** Single-slot hand-off from {@link WorkerHandler} (Netty thread) to {@link #run()}. */
    ArrayBlockingQueue<ModelWorkerResponse> replies;
    private final int gpuId;
    private long memory;
    private final long startTime;

    /** Thread currently executing {@link #run()}, or {@code null}; used to interrupt it. */
    private final AtomicReference<Thread> currentThread = new AtomicReference<>();
    private final String workerId;

    // Written by the worker thread and by shutdown() callers, read from the Netty close
    // listener and through getState(); volatile so every reader observes the latest value.
    private volatile WorkerState state;
    private final WorkerLifeCycle lifeCycle;

    /**
     * Returns the most recently recorded worker state.
     *
     * @return the current state, or {@code null} if no state has been set yet
     */
    public WorkerState getState() {
        return state;
    }

    /**
     * Creates a worker bound to the given port and model.
     *
     * @param configManager configuration source (debug flag, maximum response size)
     * @param backendEventGroup Netty event loop group used for the backend connection
     * @param port backend port; its decimal string form is also used as the worker id
     * @param gpuId GPU id passed to the backend on load when it is {@code >= 0}
     * @param model model served by this worker
     * @param aggregator source of requests and sink for responses/errors
     * @param listener callback notified on every state change
     */
    public WorkerThread(ConfigManager configManager, EventLoopGroup backendEventGroup, int port, int gpuId, Model model, BatchAggregator aggregator, WorkerStateListener listener) {
        this.workerId = String.valueOf(port);
        this.configManager = configManager;
        this.backendEventGroup = backendEventGroup;
        this.port = port;
        this.model = model;
        this.aggregator = aggregator;
        this.gpuId = gpuId;
        this.listener = listener;
        this.startTime = System.currentTimeMillis();
        this.lifeCycle = new WorkerLifeCycle(configManager, model);
        this.replies = new ArrayBlockingQueue<>(1);
        this.workerLoadTime = new Metric(getWorkerName(), String.valueOf(System.currentTimeMillis()), "ms", ConfigManager.getInstance().getHostName(), new Dimension("Level", "Host"));
    }

    /**
     * Connects to the backend and processes requests until the worker stops running or fails.
     *
     * <p>Whatever ends the loop, the {@code finally} block reports an error for any request still
     * in flight, moves the worker to {@link WorkerState#WORKER_STOPPED}, asks the life cycle to
     * exit and calls {@link #retry()}.
     *
     * <p>Decompiler note (CFR): "Removed try catching itself - possible behaviour change."
     */
    @Override
    public void run() {
        int responseTimeout = model.getResponseTimeout();
        Thread thread = Thread.currentThread();
        thread.setName(getWorkerName());
        currentThread.set(thread);
        // Non-null only while a request is in flight, so the finally block knows what to fail.
        BaseModelRequest req = null;
        HttpResponseStatus status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
        try {
            connect();
            while (isRunning()) {
                req = aggregator.getRequest(workerId, state);
                backendChannel.writeAndFlush(req).sync();
                long begin = System.currentTimeMillis();
                // responseTimeout is interpreted in seconds.
                ModelWorkerResponse reply = replies.poll(responseTimeout, TimeUnit.SECONDS);
                long duration = System.currentTimeMillis() - begin;
                logger.info("Backend response time: {}", duration);
                if (reply == null) {
                    int val = model.incrFailedInfReqs();
                    logger.error("Number or consecutive unsuccessful inference {}", val);
                    throw new WorkerInitializationException("Backend worker did not respond in given time");
                }
                aggregator.sendResponse(reply);
                switch (req.getCommand()) {
                    case PREDICT: {
                        model.resetFailedInfReqs();
                        break;
                    }
                    case LOAD: {
                        if (reply.getCode() == 200) {
                            setState(WorkerState.WORKER_MODEL_LOADED, HttpResponseStatus.OK);
                            // A successful load resets the restart back-off.
                            backoffIdx = 0;
                            break;
                        }
                        setState(WorkerState.WORKER_ERROR, HttpResponseStatus.valueOf(reply.getCode()));
                        status = HttpResponseStatus.valueOf(reply.getCode());
                        break;
                    }
                }
                req = null;
            }
        }
        catch (InterruptedException e) {
            if (state == WorkerState.WORKER_SCALED_DOWN) {
                logger.debug("Shutting down the thread .. Scaling down.");
            } else {
                logger.debug("Backend worker monitoring thread interrupted or backend worker process died.", e);
            }
        }
        catch (WorkerInitializationException e) {
            logger.error("Backend worker error", e);
        }
        catch (OutOfMemoryError oom) {
            logger.error("Out of memory error when creating workers", oom);
            status = HttpResponseStatus.INSUFFICIENT_STORAGE;
        }
        catch (Throwable t) {
            logger.warn("Backend worker thread exception.", t);
        }
        finally {
            // Drop the reference first so later interrupts cannot target this thread.
            currentThread.set(null);
            if (req != null) {
                aggregator.sendError(req, "Worker died.", status);
            }
            setState(WorkerState.WORKER_STOPPED, status);
            lifeCycle.exit();
            retry();
        }
    }

    /**
     * Returns the worker id, which is the backend port rendered as a string.
     *
     * @return the worker id
     */
    public String getWorkerId() {
        return workerId;
    }

    /**
     * Returns the value last stored through {@link #setMemory(long)}.
     *
     * @return the recorded memory value ({@code 0} if never set)
     */
    public long getMemory() {
        return memory;
    }

    /**
     * Records a memory value for this worker.
     *
     * @param memory value to store
     */
    public void setMemory(long memory) {
        this.memory = memory;
    }

    /**
     * Starts the backend worker (skipped in debug mode), opens the channel to it and queues the
     * initial {@link WorkerCommands#LOAD} job.
     *
     * @throws WorkerInitializationException if the latch is not released within
     *     {@link #WORKER_TIMEOUT} minutes, or if connecting fails with an {@link IOException}
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private void connect() throws WorkerInitializationException, InterruptedException {
        if (!configManager.isDebug()) {
            lifeCycle.startWorker(port);
        }
        String modelName = model.getModelName();
        setState(WorkerState.WORKER_STARTED, HttpResponseStatus.OK);
        // Released either when the LOAD job has been queued or when the channel closes.
        CountDownLatch latch = new CountDownLatch(1);
        final int responseBufferSize = configManager.getMaxResponseSize();
        try {
            Connector connector = new Connector(port);
            Bootstrap b = new Bootstrap();
            b.group(backendEventGroup).channel(connector.getClientChannel()).handler(new ChannelInitializer<Channel>(){

                @Override
                public void initChannel(Channel ch) {
                    ChannelPipeline p = ch.pipeline();
                    p.addLast(ENCODER);
                    p.addLast(new ModelResponseDecoder(responseBufferSize));
                    p.addLast(new WorkerHandler());
                }
            });
            SocketAddress address = connector.getSocketAddress();
            logger.info("Connecting to: {}", address);
            backendChannel = b.connect(address).sync().channel();
            backendChannel.closeFuture().addListener(future -> {
                latch.countDown();
                logger.info("{} Worker disconnected. {}", getWorkerId(), state);
                // Wake the worker thread out of any blocking call once the channel is gone.
                Thread thread = currentThread.getAndSet(null);
                if (thread != null) {
                    thread.interrupt();
                }
            });
            backendChannel.newSucceededFuture().addListener(future -> {
                RequestInput input = new RequestInput(UUID.randomUUID().toString());
                if (gpuId >= 0) {
                    input.addParameter(new InputParameter("gpu", String.valueOf(gpuId)));
                }
                Job job = new Job(null, modelName, WorkerCommands.LOAD, input);
                model.addJob(workerId, job);
                latch.countDown();
            });
            if (!latch.await(WORKER_TIMEOUT, TimeUnit.MINUTES)) {
                throw new WorkerInitializationException("Worker failed to initialize within " + WORKER_TIMEOUT + " mins");
            }
            running.set(true);
        }
        catch (Throwable t) {
            // IOException is not declared by the try body but is still checked for at runtime
            // and translated; everything else is rethrown unchanged.
            if (t instanceof IOException) {
                throw new WorkerInitializationException("Failed to connect to worker.", t);
            }
            throw t;
        }
    }

    /**
     * Tells whether the request loop in {@link #run()} should keep going.
     *
     * @return {@code true} until {@link #shutdown()} has been called
     */
    public boolean isRunning() {
        return running.get();
    }

    /**
     * Returns the GPU id supplied at construction.
     *
     * @return the GPU id; values below zero are not forwarded to the backend
     */
    public int getGpuId() {
        return gpuId;
    }

    /**
     * Returns the construction time of this worker.
     *
     * @return {@link System#currentTimeMillis()} captured in the constructor
     */
    public long getStartTime() {
        return startTime;
    }

    /**
     * Returns the process id reported by the worker life cycle.
     *
     * @return the pid as returned by {@link WorkerLifeCycle#getPid()}
     */
    public int getPid() {
        return lifeCycle.getPid();
    }

    /**
     * Stops this worker as part of a scale-down: clears the running flag, records
     * {@link WorkerState#WORKER_SCALED_DOWN}, closes the backend channel if one exists and, when a
     * worker thread is still registered, interrupts it, reports an error to the aggregator and
     * removes the worker's job queue from the model.
     */
    public void shutdown() {
        running.set(false);
        setState(WorkerState.WORKER_SCALED_DOWN, HttpResponseStatus.OK);
        if (backendChannel != null) {
            backendChannel.close();
        }
        Thread thread = currentThread.getAndSet(null);
        if (thread != null) {
            thread.interrupt();
            aggregator.sendError(null, "Worker scaled down.", HttpResponseStatus.INTERNAL_SERVER_ERROR);
            model.removeJobQueue(workerId);
        }
    }

    /**
     * Builds the display name {@code W-<port>-<modelName>}, with the model name cut to at most 25
     * characters.
     *
     * @return the worker name used for the thread name, metrics and logging
     */
    private final String getWorkerName() {
        String modelName = model.getModelName();
        if (modelName.length() > 25) {
            modelName = modelName.substring(0, 25);
        }
        return "W-" + port + '-' + modelName;
    }

    /**
     * Notifies the listener of a state change and records the new state.
     *
     * <p>The listener is always notified, but once the state is
     * {@link WorkerState#WORKER_SCALED_DOWN} the stored state is never overwritten. When the
     * resulting state is {@link WorkerState#WORKER_MODEL_LOADED} the elapsed time since
     * construction is emitted on the metrics logger.
     *
     * @param newState state being entered
     * @param status HTTP status passed on to the listener
     */
    void setState(WorkerState newState, HttpResponseStatus status) {
        listener.notifyChangeState(model.getModelName(), newState, status);
        logger.debug("{} State change {} -> {}", getWorkerName(), state, newState);
        long timeTaken = System.currentTimeMillis() - startTime;
        if (state != WorkerState.WORKER_SCALED_DOWN) {
            state = newState;
        }
        if (state == WorkerState.WORKER_MODEL_LOADED) {
            workerLoadTime.setValue(String.valueOf(timeTaken));
            workerLoadTime.setTimestamp(String.valueOf(TimeUnit.MILLISECONDS.toSeconds(System.currentTimeMillis())));
            loggerMmsMetrics.info(workerLoadTime);
        }
    }

    /**
     * Schedules this worker to be resubmitted to the {@link ModelManager} after the next back-off
     * delay. Does nothing when the worker was scaled down.
     */
    void retry() {
        if (state == WorkerState.WORKER_SCALED_DOWN) {
            logger.debug("Worker terminated due to scale-down call.");
            return;
        }
        ModelManager manager = ModelManager.getInstance();
        // Advance the back-off, saturating at the last (largest) delay.
        if (backoffIdx < BACK_OFF.length - 1) {
            ++backoffIdx;
        }
        manager.getScheduler().schedule(() -> manager.submitTask(this), BACK_OFF[backoffIdx], TimeUnit.SECONDS);
        logger.info("Retry worker: {} in {} seconds.", workerId, BACK_OFF[backoffIdx]);
    }

    /**
     * Inbound handler at the end of the backend pipeline: passes each decoded response to the
     * enclosing worker through {@link WorkerThread#replies}.
     */
    @ChannelHandler.Sharable
    private class WorkerHandler
    extends SimpleChannelInboundHandler<ModelWorkerResponse> {
        private WorkerHandler() {
        }

        /**
         * Offers the response to the single-slot reply queue.
         *
         * @throws IllegalStateException if a previous reply has not been consumed yet
         */
        @Override
        public void channelRead0(ChannelHandlerContext ctx, ModelWorkerResponse msg) {
            if (!replies.offer(msg)) {
                throw new IllegalStateException("Reply queue is full.");
            }
        }

        /**
         * Logs the failure, sends an error through {@link NettyUtils} when it is an
         * {@link OutOfMemoryError}, and closes the channel in every case.
         */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            logger.error("Unknown exception", cause);
            if (cause instanceof OutOfMemoryError) {
                NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, cause);
            }
            ctx.close();
        }
    }
}

