/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.http.api.rest;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.archive.workflow.WorkflowException;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.HttpRequestHandlerChain;
import org.pytorch.serve.http.ResourceNotFoundException;
import org.pytorch.serve.http.StatusResponse;
import org.pytorch.serve.metrics.IMetric;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.openapi.OpenApiUtils;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.util.messages.InputParameter;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Handler in the {@link HttpRequestHandlerChain} for inference-API requests.
 *
 * <p>Routes handled here:
 *
 * <ul>
 *   <li>plugin-registered custom endpoints (keys of {@code endpointMap});
 *   <li>{@code /ping} health check;
 *   <li>{@code /predictions/{model}[/{version}]} and {@code /explanations/{model}[/{version}]};
 *   <li>{@code /invocations} (model taken from the {@code model_name} query/form parameter) and
 *       {@code /models/{model}/invoke};
 *   <li>legacy {@code /{model}/predict} and {@code /{model}/{version}/predict};
 *   <li>{@code /v1/models/{model}:predict|:explain} and {@code /v2/models/{model}/infer|explain}.
 * </ul>
 *
 * <p>Any other request is forwarded to the next handler in the chain.
 */
public class InferenceRequestHandler
extends HttpRequestHandlerChain {
    private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);

    /**
     * @param ep custom endpoints, keyed by the first path segment they are served under
     */
    public InferenceRequestHandler(Map<String, ModelServerEndpoint> ep) {
        this.endpointMap = ep;
    }

    @Override
    public void handleRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException {
        if (this.isInferenceReq(segments)) {
            if (segments.length < 2) {
                // isInferenceReq() accepts an empty path, but every route below dispatches on
                // segments[1]; reject explicitly instead of failing with an index exception.
                throw new ResourceNotFoundException();
            }
            if (this.endpointMap.get(segments[1]) != null) {
                this.handleCustomEndpoint(ctx, req, segments, decoder);
                return;
            }
            switch (segments[1]) {
                case "ping": {
                    // The health verdict is computed inside the callback handed to
                    // getTorchServeHealth, which is responsible for invoking it.
                    Runnable r = () -> {
                        boolean isHealthy = ApiUtils.isModelHealthy();
                        int code = 200;
                        String response = "Healthy";
                        if (!isHealthy) {
                            response = "Unhealthy";
                            code = 500;
                        }
                        NettyUtils.sendJsonResponse(ctx, new StatusResponse(response, code));
                    };
                    ApiUtils.getTorchServeHealth(r);
                    break;
                }
                case "models":
                case "invocations": {
                    this.validatePredictionsEndpoint(segments);
                    this.handleInvocations(ctx, req, decoder, segments);
                    break;
                }
                case "predictions": {
                    this.handlePredictions(ctx, req, segments, false);
                    break;
                }
                case "explanations": {
                    this.handlePredictions(ctx, req, segments, true);
                    break;
                }
                default: {
                    this.handleLegacyPredict(ctx, req, decoder, segments);
                    break;
                }
            }
        } else if (this.isKFV1InferenceReq(segments)) {
            if (segments[3].contains(":predict")) {
                this.handleKFV1Predictions(ctx, req, segments, false);
            } else if (segments[3].contains(":explain")) {
                this.handleKFV1Predictions(ctx, req, segments, true);
            }
        } else if (this.isKFV2InferenceReq(segments)) {
            if (segments[4].equals("infer")) {
                this.handleKFV2Predictions(ctx, req, segments, false);
            } else if (segments[4].equals("explain")) {
                this.handleKFV2Predictions(ctx, req, segments, true);
            }
        } else {
            this.chain.handleRequest(ctx, req, decoder, segments);
        }
    }

    /**
     * Returns whether the path should be handled by the first branch of {@link #handleRequest}
     * (ping, predictions, explanations, invocations, custom endpoints, legacy predict).
     */
    private boolean isInferenceReq(String[] segments) {
        return segments.length == 0
                || segments.length >= 2
                        && (segments[1].equals("ping")
                                || segments[1].equals("predictions")
                                || segments[1].equals("explanations")
                                || segments[1].equals("api-description")
                                || segments[1].equals("invocations")
                                || this.endpointMap.containsKey(segments[1]))
                || segments.length == 4 && segments[1].equals("models")
                || segments.length == 3 && segments[2].equals("predict")
                || segments.length == 4 && segments[3].equals("predict");
    }

    /** Matches {@code /v1/models/{model}:predict} and {@code /v1/models/{model}:explain}. */
    private boolean isKFV1InferenceReq(String[] segments) {
        return segments.length == 4
                && "v1".equals(segments[1])
                && "models".equals(segments[2])
                && (segments[3].contains(":predict") || segments[3].contains(":explain"));
    }

    /** Matches {@code /v2/models/{model}/infer} and {@code /v2/models/{model}/explain}. */
    private boolean isKFV2InferenceReq(String[] segments) {
        return segments.length == 5
                && "v2".equals(segments[1])
                && "models".equals(segments[2])
                && (segments[4].equals("infer") || segments[4].equals("explain"));
    }

    /**
     * Accepts only {@code /invocations} and {@code /models/{model}/invoke}.
     *
     * @throws ResourceNotFoundException for any other shape
     */
    private void validatePredictionsEndpoint(String[] segments) {
        if (segments.length == 2 && "invocations".equals(segments[1])) {
            return;
        }
        if (segments.length == 4 && "models".equals(segments[1]) && "invoke".equals(segments[3])) {
            return;
        }
        throw new ResourceNotFoundException();
    }

    /**
     * Sets the {@code explain} request header to {@code "True"} or {@code "False"}, replacing any
     * existing value (including one supplied by the client) so exactly one value is forwarded.
     */
    private static void setExplainHeader(FullHttpRequest req, boolean explain) {
        req.headers().set("explain", explain ? "True" : "False");
    }

    /**
     * Handles {@code /predictions/{model}[/{version}]} and {@code /explanations/...}.
     *
     * @throws ResourceNotFoundException if the model name segment is missing
     */
    private void handlePredictions(ChannelHandlerContext ctx, FullHttpRequest req, String[] segments, boolean explain) throws ModelNotFoundException, ModelVersionNotFoundException {
        if (segments.length < 3) {
            throw new ResourceNotFoundException();
        }
        String modelVersion = segments.length == 4 ? segments[3] : null;
        setExplainHeader(req, explain);
        this.predict(ctx, req, null, segments[2], modelVersion);
    }

    /** Handles {@code /v1/models/{model}:predict|:explain}; the model name precedes the ':'. */
    private void handleKFV1Predictions(ChannelHandlerContext ctx, FullHttpRequest req, String[] segments, boolean explain) throws ModelNotFoundException, ModelVersionNotFoundException {
        String modelName = segments[3].split(":")[0];
        setExplainHeader(req, explain);
        this.predict(ctx, req, null, modelName, null);
    }

    /** Handles {@code /v2/models/{model}/infer|explain} against the default model version. */
    private void handleKFV2Predictions(ChannelHandlerContext ctx, FullHttpRequest req, String[] segments, boolean explain) throws ModelNotFoundException, ModelVersionNotFoundException {
        String modelName = segments[3].split(":")[0];
        setExplainHeader(req, explain);
        this.predict(ctx, req, null, modelName, null);
    }

    /**
     * Handles {@code /invocations} and {@code /models/{model}/invoke}. For {@code /invocations}
     * the model comes from the {@code model_name} query parameter; if that is absent or empty
     * and exactly one startup model is registered, that model is used.
     */
    private void handleInvocations(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException, ModelVersionNotFoundException {
        String modelName = "invocations".equals(segments[1])
                ? NettyUtils.getParameter(decoder, "model_name", null)
                : segments[2];
        if ((modelName == null || modelName.isEmpty()) && ModelManager.getInstance().getStartupModels().size() == 1) {
            modelName = ModelManager.getInstance().getStartupModels().iterator().next();
        }
        this.predict(ctx, req, decoder, modelName, null);
    }

    /**
     * Handles legacy {@code /{model}/predict} and {@code /{model}/{version}/predict}.
     *
     * @throws ResourceNotFoundException if the path matches neither form
     */
    private void handleLegacyPredict(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException, ModelVersionNotFoundException {
        String modelVersion = null;
        if (segments.length == 4 && "predict".equals(segments[3])) {
            modelVersion = segments[2];
        } else if (segments.length < 3 || !"predict".equals(segments[2])) {
            throw new ResourceNotFoundException();
        }
        this.predict(ctx, req, decoder, segments[1], modelVersion);
    }

    /**
     * Parses the request and submits an inference job for the resolved model.
     *
     * <p>If {@code modelName} is null it is read from the {@code model_name} request parameter.
     * An {@code OPTIONS} request is answered with the model's OpenAPI description instead of
     * running inference.
     *
     * @param decoder query decoder, or null when query parameters should not be added as inputs
     * @param modelVersion requested version, or null for the default
     * @throws BadRequestException if no model name can be determined
     * @throws ModelNotFoundException if the model is not registered
     */
    private void predict(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String modelName, String modelVersion) throws ModelNotFoundException, ModelVersionNotFoundException {
        RequestInput input = InferenceRequestHandler.parseRequest(ctx, req, decoder);
        if (modelName == null) {
            modelName = input.getStringParameter("model_name");
            if (modelName == null) {
                throw new BadRequestException("Parameter model_name is required.");
            }
        }
        ModelManager modelManager = ModelManager.getInstance();
        Model model = modelManager.getModel(modelName, modelVersion);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        input.setClientExpireTS(model.getClientTimeoutInMills());
        if (HttpMethod.OPTIONS.equals(req.method())) {
            String resp = OpenApiUtils.getModelApi(model);
            NettyUtils.sendJsonResponse(ctx, resp);
            return;
        }
        IMetric inferenceRequestsTotalMetric = MetricCache.getInstance().getMetricFrontend("ts_inference_requests_total");
        if (inferenceRequestsTotalMetric != null) {
            List<String> inferenceRequestsTotalMetricDimensionValues = Arrays.asList(
                    modelName,
                    modelVersion == null ? "default" : modelVersion,
                    ConfigManager.getInstance().getHostName());
            try {
                inferenceRequestsTotalMetric.addOrUpdate(inferenceRequestsTotalMetricDimensionValues, 1.0);
            }
            catch (Exception e) {
                // Metrics are best-effort: a failure here must not block the inference request.
                logger.error("Failed to update frontend metric ts_inference_requests_total: ", e);
            }
        }
        ApiUtils.addRESTInferenceJob(ctx, modelName, modelVersion, input);
    }

    /**
     * Builds a {@link RequestInput} from the request.
     *
     * <p>Query parameters (when {@code decoder} is non-null) and all request headers are copied
     * in. Multipart or {@code application/x-www-form-urlencoded} bodies are split into one
     * parameter per form field; any other body becomes a single {@code "body"} parameter tagged
     * with the request's MIME type.
     *
     * <p>NOTE(review): the decompiler flagged this method ("Removed try catching itself"); verify
     * the exception handling against the original source.
     */
    private static RequestInput parseRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) {
        String requestId = NettyUtils.getRequestId(ctx.channel());
        RequestInput inputData = new RequestInput(requestId);
        if (decoder != null) {
            for (Map.Entry<String, List<String>> entry : decoder.parameters().entrySet()) {
                String key = entry.getKey();
                for (String value : entry.getValue()) {
                    inputData.addParameter(new InputParameter(key, value));
                }
            }
        }
        CharSequence contentType = HttpUtil.getMimeType(req);
        for (Map.Entry<String, String> header : req.headers().entries()) {
            inputData.updateHeaders(header.getKey(), header.getValue());
        }
        if (HttpPostRequestDecoder.isMultipart(req) || HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.contentEqualsIgnoreCase(contentType)) {
            DefaultHttpDataFactory factory = new DefaultHttpDataFactory(ConfigManager.getInstance().getMaxRequestSize());
            HttpPostRequestDecoder form = new HttpPostRequestDecoder(factory, req);
            try {
                while (form.hasNext()) {
                    inputData.addParameter(NettyUtils.getFormData(form.next()));
                }
            }
            catch (HttpPostRequestDecoder.EndOfDataDecoderException ignore) {
                // Signals the end of the form items; not an error.
                logger.trace("End of multipart items.");
            }
            finally {
                form.cleanFiles();
                form.destroy();
            }
        } else {
            byte[] content = NettyUtils.getBytes(req.content());
            inputData.addParameter(new InputParameter("body", content, contentType));
        }
        return inputData;
    }
}

