/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.ml.mms.http;

import com.amazonaws.ml.mms.archive.ModelNotFoundException;
import com.amazonaws.ml.mms.http.BadRequestException;
import com.amazonaws.ml.mms.http.HttpRequestHandler;
import com.amazonaws.ml.mms.http.ResourceNotFoundException;
import com.amazonaws.ml.mms.http.ServiceUnavailableException;
import com.amazonaws.ml.mms.openapi.OpenApiUtils;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.amazonaws.ml.mms.util.messages.InputParameter;
import com.amazonaws.ml.mms.util.messages.RequestInput;
import com.amazonaws.ml.mms.util.messages.WorkerCommands;
import com.amazonaws.ml.mms.wlm.Job;
import com.amazonaws.ml.mms.wlm.Model;
import com.amazonaws.ml.mms.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InferenceRequestHandler
extends HttpRequestHandler {
    private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);

    @Override
    protected void handleRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException {
        switch (segments[1]) {
            case "ping": {
                ModelManager.getInstance().workerStatus(ctx);
                break;
            }
            case "api-description": {
                this.handleApiDescription(ctx);
                break;
            }
            case "invocations": {
                this.handleInvocations(ctx, req, decoder);
                break;
            }
            case "predictions": {
                this.handlePredictions(ctx, req, segments);
                break;
            }
            default: {
                this.handleLegacyPredict(ctx, req, decoder, segments);
            }
        }
    }

    @Override
    protected void handleApiDescription(ChannelHandlerContext ctx) {
        NettyUtils.sendJsonResponse(ctx, OpenApiUtils.listInferenceApis());
    }

    private void handlePredictions(ChannelHandlerContext ctx, FullHttpRequest req, String[] segments) throws ModelNotFoundException {
        if (segments.length < 3) {
            throw new ResourceNotFoundException();
        }
        this.predict(ctx, req, null, segments[2]);
    }

    private void handleInvocations(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) throws ModelNotFoundException {
        String modelName = NettyUtils.getParameter(decoder, "model_name", null);
        this.predict(ctx, req, decoder, modelName);
    }

    private void handleLegacyPredict(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException {
        if (segments.length < 3 || !"predict".equals(segments[2])) {
            throw new ResourceNotFoundException();
        }
        this.predict(ctx, req, decoder, segments[1]);
    }

    private void predict(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String modelName) throws ModelNotFoundException {
        RequestInput input = InferenceRequestHandler.parseRequest(ctx, req, decoder);
        if (modelName == null && (modelName = input.getStringParameter("model_name")) == null) {
            throw new BadRequestException("Parameter model_name is required.");
        }
        if (HttpMethod.OPTIONS.equals(req.method())) {
            ModelManager modelManager = ModelManager.getInstance();
            Model model = modelManager.getModels().get(modelName);
            if (model == null) {
                throw new ModelNotFoundException("Model not found: " + modelName);
            }
            String resp = OpenApiUtils.getModelApi(model);
            NettyUtils.sendJsonResponse(ctx, resp);
            return;
        }
        Job job = new Job(ctx, modelName, WorkerCommands.PREDICT, input);
        if (!ModelManager.getInstance().addJob(job)) {
            throw new ServiceUnavailableException("No worker is available to serve request: " + modelName);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static RequestInput parseRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) {
        String requestId = NettyUtils.getRequestId(ctx.channel());
        RequestInput inputData = new RequestInput(requestId);
        if (decoder != null) {
            for (Map.Entry<String, List<String>> entry : decoder.parameters().entrySet()) {
                String key = entry.getKey();
                for (String value : entry.getValue()) {
                    inputData.addParameter(new InputParameter(key, value));
                }
            }
        }
        CharSequence contentType = HttpUtil.getMimeType(req);
        if (HttpPostRequestDecoder.isMultipart(req) || HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.contentEqualsIgnoreCase(contentType)) {
            DefaultHttpDataFactory factory = new DefaultHttpDataFactory(6553500L);
            HttpPostRequestDecoder form = new HttpPostRequestDecoder(factory, req);
            try {
                while (form.hasNext()) {
                    inputData.addParameter(NettyUtils.getFormData(form.next()));
                }
            }
            catch (HttpPostRequestDecoder.EndOfDataDecoderException ignore) {
                logger.trace("End of multipart items.");
            }
            finally {
                form.cleanFiles();
                form.destroy();
            }
        } else {
            byte[] content = NettyUtils.getBytes(req.content());
            inputData.addParameter(new InputParameter("body", content, contentType));
        }
        return inputData;
    }
}

