/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.job;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.CharsetUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.http.InternalServerException;
import org.pytorch.serve.http.messages.DescribeModelResponse;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.metrics.IMetric;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Job} whose results are returned to a REST caller.
 *
 * <p>When a channel context ({@code ctx}) is present, results are written back over that Netty
 * channel. When {@code ctx} is {@code null}, inference results are delivered through the
 * {@link CompletableFuture} installed with {@link #setResponsePromise(CompletableFuture)}, if any.
 *
 * <p>The response header {@code ts_stream_next} drives streamed inference responses. A value of
 * {@code "true"} marks an intermediate chunk; any other value marks the final chunk.
 */
public class RestJob extends Job {
    private static final Logger logger = LoggerFactory.getLogger(RestJob.class);

    /** Response header whose presence marks a streamed chunk; "true" means more chunks follow. */
    private static final String TS_STREAM_NEXT = "ts_stream_next";

    private final IMetric inferenceLatencyMetric;
    private final IMetric queueLatencyMetric;
    /** Dimension values (model name, model version or "default", host name) for latency metrics. */
    private final List<String> latencyMetricDimensionValues;
    private final IMetric queueTimeMetric;
    private final List<String> queueTimeMetricDimensionValues;
    /** Client channel context; {@code null} when the job is not tied to an HTTP connection. */
    private final ChannelHandlerContext ctx;
    private CompletableFuture<byte[]> responsePromise;
    /**
     * Streaming state. 0 means no stream chunk has been seen yet. n > 0 means n chunks marked
     * "more to come" have been seen. -1 means the final chunk has been seen.
     */
    private int numStreams;

    public RestJob(ChannelHandlerContext ctx, String modelName, String version, WorkerCommands cmd, RequestInput input) {
        super(modelName, version, cmd, input);
        this.ctx = ctx;
        this.inferenceLatencyMetric = MetricCache.getInstance().getMetricFrontend("ts_inference_latency_microseconds");
        this.queueLatencyMetric = MetricCache.getInstance().getMetricFrontend("ts_queue_latency_microseconds");
        this.latencyMetricDimensionValues = Arrays.asList(this.getModelName(), this.getModelVersion() == null ? "default" : this.getModelVersion(), ConfigManager.getInstance().getHostName());
        this.queueTimeMetric = MetricCache.getInstance().getMetricFrontend("QueueTime");
        this.queueTimeMetricDimensionValues = Arrays.asList("Host", ConfigManager.getInstance().getHostName());
        this.numStreams = 0;
    }

    /**
     * Dispatches a backend result by command type.
     *
     * <p>PREDICT results go to the inference path and DESCRIBE results go to the describe path.
     * Other commands are ignored here.
     */
    @Override
    public void response(byte[] body, CharSequence contentType, int statusCode, String statusPhrase, Map<String, String> responseHeaders) {
        if (this.getCmd() == WorkerCommands.PREDICT) {
            this.responseInference(body, contentType, statusCode, statusPhrase, responseHeaders);
        } else if (this.getCmd() == WorkerCommands.DESCRIBE) {
            this.responseDescribe(body, contentType, statusCode, statusPhrase, responseHeaders);
        }
    }

    /**
     * Sends the model description as pretty-printed JSON, followed by a trailing newline.
     *
     * <p>A non-empty backend {@code body} is attached as customized metadata, but only when
     * exactly one description is returned. A missing model or version is reported as 404.
     */
    private void responseDescribe(byte[] body, CharSequence contentType, int statusCode, String statusPhrase, Map<String, String> responseHeaders) {
        try {
            ArrayList<DescribeModelResponse> respList = ApiUtils.getModelDescription(this.getModelName(), this.getModelVersion());
            if (body != null && body.length != 0 && respList != null && respList.size() == 1) {
                respList.get(0).setCustomizedMetadata(body);
            }
            DefaultFullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, toStatus(statusCode, statusPhrase), true);
            // Default to JSON, since the body written below is always JSON.
            CharSequence effectiveType = contentType != null && contentType.length() > 0 ? contentType : HttpHeaderValues.APPLICATION_JSON;
            resp.headers().set(HttpHeaderNames.CONTENT_TYPE, effectiveType);
            copyHeaders(resp, responseHeaders);
            ByteBuf content = resp.content();
            content.writeCharSequence(JsonUtils.GSON_PRETTY.toJson(respList), CharsetUtil.UTF_8);
            content.writeByte('\n');
            NettyUtils.sendHttpResponse(this.ctx, resp, true);
        } catch (ModelNotFoundException | ModelVersionNotFoundException e) {
            logger.trace("", e);
            NettyUtils.sendError(this.ctx, HttpResponseStatus.NOT_FOUND, e);
        }
    }

    /**
     * Delivers one inference result, which is either a complete response or one chunk of a stream.
     *
     * <p>Latency and queue metrics are recorded only for non-streamed responses and for the
     * final chunk of a stream.
     */
    private void responseInference(byte[] body, CharSequence contentType, int statusCode, String statusPhrase, Map<String, String> responseHeaders) {
        long inferTime = System.nanoTime() - this.getBegin();
        HttpResponseStatus status = toStatus(statusCode, statusPhrase);
        // Capture this before updating numStreams. A stream whose first chunk is also its last
        // must still send its response headers.
        boolean streamStarted = this.numStreams != 0;
        DefaultHttpResponse resp;
        if (responseHeaders != null && responseHeaders.containsKey(TS_STREAM_NEXT)) {
            resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status, false);
            this.numStreams = "true".equals(responseHeaders.get(TS_STREAM_NEXT)) ? this.numStreams + 1 : -1;
        } else {
            resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, true);
        }
        if (contentType != null && contentType.length() > 0) {
            resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType);
        }
        copyHeaders(resp, responseHeaders);
        if (this.ctx != null) {
            this.writeToChannel(resp, body, streamStarted);
        } else if (this.responsePromise != null) {
            // With no client channel, the result goes to the promise. A CompletableFuture
            // completes only once, so later stream chunks have no effect here.
            this.responsePromise.complete(body);
        }
        if (this.numStreams <= 0) {
            this.recordMetrics(inferTime);
        }
    }

    /** Writes one inference result to the client channel, according to the current stream state. */
    private void writeToChannel(DefaultHttpResponse resp, byte[] body, boolean streamStarted) {
        if (this.numStreams == 0) {
            // Non-streamed: a single full response carrying the whole body.
            ((DefaultFullHttpResponse) resp).content().writeBytes(body);
            NettyUtils.sendHttpResponse(this.ctx, resp, true);
        } else if (this.numStreams == -1) {
            // Final chunk. If no earlier chunk was sent, the headers are still pending.
            if (!streamStarted) {
                NettyUtils.sendHttpResponse(this.ctx, resp, true);
            }
            this.ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(body)));
            this.ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);
        } else if (this.numStreams > 0) {
            // Intermediate chunk. The first one also carries the response headers.
            if (this.numStreams == 1) {
                NettyUtils.sendHttpResponse(this.ctx, resp, true);
            }
            this.ctx.writeAndFlush(new DefaultHttpContent(Unpooled.wrappedBuffer(body)));
        }
    }

    /**
     * Records the metrics for a completed response.
     *
     * <p>Inference and queue latency are recorded in microseconds, queue time in milliseconds.
     *
     * @param inferTime nanoseconds elapsed between {@code getBegin()} and the response
     */
    private void recordMetrics(long inferTime) {
        long queuedNanos = this.getScheduled() - this.getBegin();
        updateMetric(this.inferenceLatencyMetric, this.latencyMetricDimensionValues, (double) inferTime / 1000.0, "ts_inference_latency_microseconds");
        updateMetric(this.queueLatencyMetric, this.latencyMetricDimensionValues, (double) queuedNanos / 1000.0, "ts_queue_latency_microseconds");
        logger.debug("Waiting time ns: {}, Backend time ns: {}", queuedNanos, System.nanoTime() - this.getScheduled());
        double queueTime = TimeUnit.MILLISECONDS.convert(queuedNanos, TimeUnit.NANOSECONDS);
        updateMetric(this.queueTimeMetric, this.queueTimeMetricDimensionValues, queueTime, "QueueTime");
    }

    /**
     * Updates a frontend metric on a best-effort basis.
     *
     * <p>A missing metric is skipped. A failed update is logged and never propagated to the caller.
     */
    private static void updateMetric(IMetric metric, List<String> dimensionValues, double value, String metricName) {
        if (metric == null) {
            return;
        }
        try {
            metric.addOrUpdate(dimensionValues, value);
        } catch (Exception e) {
            logger.error("Failed to update frontend metric {}: ", metricName, e);
        }
    }

    /** Builds a status from the code, using the custom reason phrase when one is given. */
    private static HttpResponseStatus toStatus(int statusCode, String statusPhrase) {
        return statusPhrase == null ? HttpResponseStatus.valueOf(statusCode) : new HttpResponseStatus(statusCode, statusPhrase);
    }

    /** Copies backend-provided headers onto {@code resp}, overwriting existing values; null is a no-op. */
    private static void copyHeaders(DefaultHttpResponse resp, Map<String, String> responseHeaders) {
        if (responseHeaders == null) {
            return;
        }
        for (Map.Entry<String, String> e : responseHeaders.entrySet()) {
            resp.headers().set(e.getKey(), e.getValue());
        }
    }

    /**
     * Reports an error to the caller.
     *
     * <p>On a client channel, status 413 is sent as 507. Without a channel, the response
     * promise (if set) is completed exceptionally.
     */
    @Override
    public void sendError(int status, String error) {
        if (this.ctx != null) {
            status = status == 413 ? 507 : status;
            NettyUtils.sendError(this.ctx, HttpResponseStatus.valueOf(status), new InternalServerException(error));
        } else if (this.responsePromise != null) {
            this.responsePromise.completeExceptionally(new InternalServerException(error));
        }
        if (this.getCmd() == WorkerCommands.PREDICT) {
            logger.debug("Waiting time ns: {}, Inference time ns: {}", this.getScheduled() - this.getBegin(), System.nanoTime() - this.getBegin());
        }
    }

    public CompletableFuture<byte[]> getResponsePromise() {
        return this.responsePromise;
    }

    public void setResponsePromise(CompletableFuture<byte[]> responsePromise) {
        this.responsePromise = responsePromise;
    }

    /**
     * Returns whether the job can still deliver its result.
     *
     * <p>A job with a client channel is open while that channel is open. A job without a channel
     * context has no client connection that could close, so it is reported as open instead of
     * failing with a NullPointerException.
     */
    @Override
    public boolean isOpen() {
        if (this.ctx == null) {
            return true;
        }
        Channel c = this.ctx.channel();
        return c.isOpen();
    }
}

