/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.grpcimpl;

import com.google.gson.Gson;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.grpc.openinference.GRPCInferenceServiceGrpc;
import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.InternalServerException;
import org.pytorch.serve.job.GRPCJob;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.messages.InputParameter;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class OpenInferenceProtocolImpl
extends GRPCInferenceServiceGrpc.GRPCInferenceServiceImplBase {
    private static final Logger logger = LoggerFactory.getLogger(OpenInferenceProtocolImpl.class);

    @Override
    public void serverLive(OpenInferenceGrpc.ServerLiveRequest request, StreamObserver<OpenInferenceGrpc.ServerLiveResponse> responseObserver) {
        ((ServerCallStreamObserver)responseObserver).setOnCancelHandler(() -> {
            logger.warn("grpc client call already cancelled");
            responseObserver.onError(Status.CANCELLED.withDescription("call already cancelled").asRuntimeException());
        });
        OpenInferenceGrpc.ServerLiveResponse readyResponse = OpenInferenceGrpc.ServerLiveResponse.newBuilder().setLive(true).build();
        responseObserver.onNext(readyResponse);
        responseObserver.onCompleted();
    }

    @Override
    public void serverReady(OpenInferenceGrpc.ServerReadyRequest request, StreamObserver<OpenInferenceGrpc.ServerReadyResponse> responseObserver) {
        ((ServerCallStreamObserver)responseObserver).setOnCancelHandler(() -> {
            logger.warn("grpc client call already cancelled");
            responseObserver.onError(Status.CANCELLED.withDescription("call already cancelled").asRuntimeException());
        });
        OpenInferenceGrpc.ServerReadyResponse readyResponse = OpenInferenceGrpc.ServerReadyResponse.newBuilder().setReady(true).build();
        responseObserver.onNext(readyResponse);
        responseObserver.onCompleted();
    }

    private void sendErrorResponse(StreamObserver<?> responseObserver, Status internal, Exception e, String string) {
        responseObserver.onError(internal.withDescription(e.getMessage()).augmentDescription(string == null ? e.getClass().getCanonicalName() : string).withCause(e).asRuntimeException());
    }

    @Override
    public void modelReady(OpenInferenceGrpc.ModelReadyRequest request, StreamObserver<OpenInferenceGrpc.ModelReadyResponse> responseObserver) {
        ((ServerCallStreamObserver)responseObserver).setOnCancelHandler(() -> {
            logger.warn("grpc client call already cancelled");
            responseObserver.onError(Status.CANCELLED.withDescription("call already cancelled").asRuntimeException());
        });
        String modelName = request.getName();
        String modelVersion = request.getVersion();
        ModelManager modelManager = ModelManager.getInstance();
        boolean isModelReady = false;
        if (modelName == null || "".equals(modelName)) {
            BadRequestException e = new BadRequestException("Parameter name is required.");
            this.sendErrorResponse(responseObserver, Status.INTERNAL, e, "BadRequestException.()");
            return;
        }
        if (modelVersion == null || "".equals(modelVersion)) {
            modelVersion = null;
        }
        try {
            Model model = modelManager.getModel(modelName, modelVersion);
            if (model == null) {
                throw new ModelNotFoundException("Model not found: " + modelName);
            }
            int numScaled = model.getMinWorkers();
            int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName());
            isModelReady = numHealthy >= numScaled;
            OpenInferenceGrpc.ModelReadyResponse modelReadyResponse = OpenInferenceGrpc.ModelReadyResponse.newBuilder().setReady(isModelReady).build();
            responseObserver.onNext(modelReadyResponse);
            responseObserver.onCompleted();
        }
        catch (ModelNotFoundException | ModelVersionNotFoundException e) {
            this.sendErrorResponse(responseObserver, Status.NOT_FOUND, e, null);
        }
    }

    @Override
    public void modelMetadata(OpenInferenceGrpc.ModelMetadataRequest request, StreamObserver<OpenInferenceGrpc.ModelMetadataResponse> responseObserver) {
        ((ServerCallStreamObserver)responseObserver).setOnCancelHandler(() -> {
            logger.warn("grpc client call already cancelled");
            responseObserver.onError(Status.CANCELLED.withDescription("call already cancelled").asRuntimeException());
        });
        String modelName = request.getName();
        String modelVersion = request.getVersion();
        ModelManager modelManager = ModelManager.getInstance();
        OpenInferenceGrpc.ModelMetadataResponse.Builder response = OpenInferenceGrpc.ModelMetadataResponse.newBuilder();
        ArrayList inputs = new ArrayList();
        ArrayList outputs = new ArrayList();
        ArrayList<String> versions = new ArrayList<String>();
        if (modelName == null || "".equals(modelName)) {
            BadRequestException e = new BadRequestException("Parameter model_name is required.");
            this.sendErrorResponse(responseObserver, Status.INTERNAL, e, "BadRequestException.()");
            return;
        }
        if (modelVersion == null || "".equals(modelVersion)) {
            modelVersion = null;
        }
        try {
            Model model = modelManager.getModel(modelName, modelVersion);
            if (model == null) {
                throw new ModelNotFoundException("Model not found: " + modelName);
            }
            modelManager.getAllModelVersions(modelName).forEach(entry -> versions.add((String)entry.getKey()));
            response.setName(modelName);
            response.addAllVersions(versions);
            response.setPlatform("");
            response.addAllInputs(inputs);
            response.addAllOutputs(outputs);
            responseObserver.onNext(response.build());
            responseObserver.onCompleted();
        }
        catch (ModelNotFoundException | ModelVersionNotFoundException e) {
            this.sendErrorResponse(responseObserver, Status.NOT_FOUND, e, null);
        }
    }

    @Override
    public void modelInfer(OpenInferenceGrpc.ModelInferRequest request, StreamObserver<OpenInferenceGrpc.ModelInferResponse> responseObserver) {
        ((ServerCallStreamObserver)responseObserver).setOnCancelHandler(() -> {
            logger.warn("grpc client call already cancelled");
            responseObserver.onError(Status.CANCELLED.withDescription("call already cancelled").asRuntimeException());
        });
        String modelName = request.getModelName();
        String modelVersion = request.getModelVersion();
        String contentsType = "application/json";
        Gson gson = new Gson();
        HashMap<String, Object> modelInferMap = new HashMap<String, Object>();
        ArrayList<HashMap<String, Object>> inferInputs = new ArrayList<HashMap<String, Object>>();
        String requestId = UUID.randomUUID().toString();
        RequestInput inputData = new RequestInput(requestId);
        modelInferMap.put("id", request.getId());
        modelInferMap.put("model_name", request.getModelName());
        for (OpenInferenceGrpc.ModelInferRequest.InferInputTensor entry : request.getInputsList()) {
            HashMap<String, Object> inferInputMap = new HashMap<String, Object>();
            inferInputMap.put("name", entry.getName());
            inferInputMap.put("shape", entry.getShapeList());
            inferInputMap.put("datatype", entry.getDatatype());
            OpenInferenceProtocolImpl.setInputContents(entry, inferInputMap);
            inferInputs.add(inferInputMap);
        }
        modelInferMap.put("inputs", inferInputs);
        String jsonString = gson.toJson(modelInferMap);
        byte[] byteArray = jsonString.getBytes(StandardCharsets.UTF_8);
        if (modelName == null || "".equals(modelName)) {
            BadRequestException e = new BadRequestException("Parameter model_name is required.");
            this.sendErrorResponse(responseObserver, Status.INTERNAL, e, "BadRequestException.()");
            return;
        }
        if (modelVersion == null || "".equals(modelVersion)) {
            modelVersion = null;
        }
        try {
            ModelManager modelManager = ModelManager.getInstance();
            inputData.addParameter(new InputParameter("body", byteArray, contentsType));
            GRPCJob job = new GRPCJob(responseObserver, modelName, modelVersion, inputData, WorkerCommands.OIPPREDICT);
            if (!modelManager.addJob(job)) {
                String responseMessage = ApiUtils.getStreamingInferenceErrorResponseMessage(modelName, modelVersion);
                InternalServerException e = new InternalServerException(responseMessage);
                this.sendErrorResponse(responseObserver, Status.INTERNAL, e, "InternalServerException.()");
            }
        }
        catch (ModelNotFoundException | ModelVersionNotFoundException e) {
            this.sendErrorResponse(responseObserver, Status.INTERNAL, e, null);
        }
    }

    private static void setInputContents(OpenInferenceGrpc.ModelInferRequest.InferInputTensor inferInputTensor, Map<String, Object> inferInputMap) {
        switch (inferInputTensor.getDatatype()) {
            case "BYTES": {
                List<ByteString> byteStrings = inferInputTensor.getContents().getBytesContentsList();
                ArrayList<String> base64Strings = new ArrayList<String>();
                for (ByteString byteString : byteStrings) {
                    String base64String = Base64.getEncoder().encodeToString(byteString.toByteArray());
                    base64Strings.add(base64String);
                }
                inferInputMap.put("data", base64Strings);
                break;
            }
            case "FP32": {
                List<Float> fp32Contents = inferInputTensor.getContents().getFp32ContentsList();
                inferInputMap.put("data", fp32Contents);
                break;
            }
            case "FP64": {
                List<Double> fp64ContentList = inferInputTensor.getContents().getFp64ContentsList();
                inferInputMap.put("data", fp64ContentList);
                break;
            }
            case "INT8": 
            case "INT16": 
            case "INT32": {
                List<Integer> int32Contents = inferInputTensor.getContents().getIntContentsList();
                inferInputMap.put("data", int32Contents);
                break;
            }
            case "INT64": {
                List<Long> int64Contents = inferInputTensor.getContents().getInt64ContentsList();
                inferInputMap.put("data", int64Contents);
                break;
            }
            case "UINT8": 
            case "UINT16": 
            case "UINT32": {
                List<Integer> uint32Contents = inferInputTensor.getContents().getUintContentsList();
                inferInputMap.put("data", uint32Contents);
                break;
            }
            case "UINT64": {
                List<Long> uint64Contents = inferInputTensor.getContents().getUint64ContentsList();
                inferInputMap.put("data", uint64Contents);
                break;
            }
            case "BOOL": {
                List<Boolean> boolContents = inferInputTensor.getContents().getBoolContentsList();
                inferInputMap.put("data", boolContents);
                break;
            }
        }
    }
}

