/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.job;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.rpc.ErrorInfo;
import com.google.rpc.Status;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.grpc.inference.PredictionResponse;
import org.pytorch.serve.grpc.management.ManagementResponse;
import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc;
import org.pytorch.serve.grpcimpl.ManagementImpl;
import org.pytorch.serve.http.messages.DescribeModelResponse;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.metrics.IMetric;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.GRPCUtils;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GRPCJob
extends Job {
    private static final Logger logger = LoggerFactory.getLogger(GRPCJob.class);
    private final IMetric queueTimeMetric;
    private final List<String> queueTimeMetricDimensionValues;
    private StreamObserver<PredictionResponse> predictionResponseObserver;
    private StreamObserver<ManagementResponse> managementResponseObserver;
    private StreamObserver<OpenInferenceGrpc.ModelInferResponse> modelInferResponseObserver;

    public GRPCJob(StreamObserver<PredictionResponse> predictionResponseObserver, String modelName, String version, WorkerCommands cmd, RequestInput input) {
        super(modelName, version, cmd, input);
        this.predictionResponseObserver = predictionResponseObserver;
        this.queueTimeMetric = MetricCache.getInstance().getMetricFrontend("QueueTime");
        this.queueTimeMetricDimensionValues = Arrays.asList("Host", ConfigManager.getInstance().getHostName());
    }

    public GRPCJob(StreamObserver<OpenInferenceGrpc.ModelInferResponse> modelInferResponseObserver, String modelName, String version, RequestInput input, WorkerCommands cmd) {
        super(modelName, version, cmd, input);
        this.modelInferResponseObserver = modelInferResponseObserver;
        this.queueTimeMetric = MetricCache.getInstance().getMetricFrontend("QueueTime");
        this.queueTimeMetricDimensionValues = Arrays.asList("Host", ConfigManager.getInstance().getHostName());
    }

    public GRPCJob(StreamObserver<ManagementResponse> managementResponseObserver, String modelName, String version, RequestInput input) {
        super(modelName, version, WorkerCommands.DESCRIBE, input);
        this.managementResponseObserver = managementResponseObserver;
        this.queueTimeMetric = MetricCache.getInstance().getMetricFrontend("QueueTime");
        this.queueTimeMetricDimensionValues = Arrays.asList("Host", ConfigManager.getInstance().getHostName());
    }

    private void cancelHandler(ServerCallStreamObserver<PredictionResponse> responseObserver) {
        if (responseObserver.isCancelled()) {
            logger.warn("grpc client call already cancelled, not able to send this response for requestId: {}", (Object)this.getPayload().getRequestId());
        }
    }

    private void logQueueTime() {
        logger.debug("Waiting time ns: {}, Backend time ns: {}", (Object)(this.getScheduled() - this.getBegin()), (Object)(System.nanoTime() - this.getScheduled()));
        double queueTime = TimeUnit.MILLISECONDS.convert(this.getScheduled() - this.getBegin(), TimeUnit.NANOSECONDS);
        if (this.queueTimeMetric != null) {
            try {
                this.queueTimeMetric.addOrUpdate(this.queueTimeMetricDimensionValues, queueTime);
            }
            catch (Exception e) {
                logger.error("Failed to update frontend metric QueueTime: ", e);
            }
        }
    }

    @Override
    public void response(byte[] body, CharSequence contentType, int statusCode, String statusPhrase, Map<String, String> responseHeaders) {
        ByteString output = ByteString.copyFrom(body);
        WorkerCommands cmd = this.getCmd();
        switch (cmd) {
            case PREDICT: 
            case STREAMPREDICT: 
            case STREAMPREDICT2: {
                ServerCallStreamObserver responseObserver = (ServerCallStreamObserver)this.predictionResponseObserver;
                this.cancelHandler(responseObserver);
                PredictionResponse reply = PredictionResponse.newBuilder().setPrediction(output).build();
                responseObserver.onNext(reply);
                if (cmd == WorkerCommands.PREDICT || cmd == WorkerCommands.STREAMPREDICT && responseHeaders.get("ts_stream_next").equals("false")) {
                    responseObserver.onCompleted();
                    this.logQueueTime();
                    break;
                }
                if (cmd != WorkerCommands.STREAMPREDICT2 || responseHeaders.get("ts_stream_next") != null && !responseHeaders.get("ts_stream_next").equals("false")) break;
                this.logQueueTime();
                break;
            }
            case DESCRIBE: {
                try {
                    ArrayList<DescribeModelResponse> respList = ApiUtils.getModelDescription(this.getModelName(), this.getModelVersion());
                    if (!output.isEmpty() && respList != null && respList.size() == 1) {
                        respList.get(0).setCustomizedMetadata(body);
                    }
                    String resp = JsonUtils.GSON_PRETTY.toJson(respList);
                    ManagementResponse mgmtReply = ManagementResponse.newBuilder().setMsg(resp).build();
                    this.managementResponseObserver.onNext(mgmtReply);
                    this.managementResponseObserver.onCompleted();
                }
                catch (ModelNotFoundException | ModelVersionNotFoundException e) {
                    ManagementImpl.sendErrorResponse(this.managementResponseObserver, io.grpc.Status.NOT_FOUND, e);
                }
                break;
            }
            case OIPPREDICT: {
                Gson gson = new Gson();
                String jsonResponse = output.toStringUtf8();
                JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class);
                if (((ServerCallStreamObserver)this.modelInferResponseObserver).isCancelled()) {
                    logger.warn("grpc client call already cancelled, not able to send this response for requestId: {}", (Object)this.getPayload().getRequestId());
                    return;
                }
                OpenInferenceGrpc.ModelInferResponse.Builder responseBuilder = OpenInferenceGrpc.ModelInferResponse.newBuilder();
                responseBuilder.setId(jsonObject.get("id").getAsString());
                responseBuilder.setModelName(jsonObject.get("model_name").getAsString());
                responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString());
                JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray();
                for (JsonElement element : jsonOutputs) {
                    OpenInferenceGrpc.ModelInferResponse.InferOutputTensor.Builder outputBuilder = OpenInferenceGrpc.ModelInferResponse.InferOutputTensor.newBuilder();
                    outputBuilder.setName(element.getAsJsonObject().get("name").getAsString());
                    outputBuilder.setDatatype(element.getAsJsonObject().get("datatype").getAsString());
                    JsonArray shapeArray = element.getAsJsonObject().get("shape").getAsJsonArray();
                    shapeArray.forEach(shapeElement -> outputBuilder.addShape(shapeElement.getAsLong()));
                    this.setOutputContents(element, outputBuilder);
                    responseBuilder.addOutputs(outputBuilder);
                }
                this.modelInferResponseObserver.onNext(responseBuilder.build());
                this.modelInferResponseObserver.onCompleted();
                break;
            }
        }
    }

    @Override
    public void sendError(int status, String error) {
        io.grpc.Status responseStatus = GRPCUtils.getGRPCStatusCode(status);
        WorkerCommands cmd = this.getCmd();
        switch (cmd) {
            case PREDICT: 
            case STREAMPREDICT: 
            case STREAMPREDICT2: {
                ServerCallStreamObserver responseObserver = (ServerCallStreamObserver)this.predictionResponseObserver;
                this.cancelHandler(responseObserver);
                if (cmd == WorkerCommands.PREDICT || cmd == WorkerCommands.STREAMPREDICT) {
                    responseObserver.onError(responseStatus.withDescription(error).augmentDescription("org.pytorch.serve.http.InternalServerException").asRuntimeException());
                    break;
                }
                if (cmd != WorkerCommands.STREAMPREDICT2) break;
                Status rpcStatus = Status.newBuilder().setCode(responseStatus.getCode().value()).setMessage(error).addDetails(Any.pack(ErrorInfo.newBuilder().setReason("org.pytorch.serve.http.InternalServerException").build())).build();
                responseObserver.onNext(PredictionResponse.newBuilder().setPrediction(null).setStatus(rpcStatus).build());
                break;
            }
            case DESCRIBE: {
                this.managementResponseObserver.onError(responseStatus.withDescription(error).augmentDescription("org.pytorch.serve.http.InternalServerException").asRuntimeException());
                break;
            }
            case OIPPREDICT: {
                this.modelInferResponseObserver.onError(responseStatus.withDescription(error).augmentDescription("org.pytorch.serve.http.InternalServerException").asRuntimeException());
                break;
            }
        }
    }

    @Override
    public boolean isOpen() {
        return ((ServerCallStreamObserver)this.predictionResponseObserver).isCancelled();
    }

    private void setOutputContents(JsonElement element, OpenInferenceGrpc.ModelInferResponse.InferOutputTensor.Builder outputBuilder) {
        String dataType = element.getAsJsonObject().get("datatype").getAsString();
        JsonArray jsonData = element.getAsJsonObject().get("data").getAsJsonArray();
        OpenInferenceGrpc.InferTensorContents.Builder inferTensorContents = OpenInferenceGrpc.InferTensorContents.newBuilder();
        switch (dataType) {
            case "INT8": 
            case "INT16": 
            case "INT32": {
                ArrayList int32Contents = new ArrayList();
                jsonData.forEach(data -> int32Contents.add(data.getAsInt()));
                inferTensorContents.addAllIntContents(int32Contents);
                break;
            }
            case "INT64": {
                ArrayList int64Contents = new ArrayList();
                jsonData.forEach(data -> int64Contents.add(data.getAsLong()));
                inferTensorContents.addAllInt64Contents(int64Contents);
                break;
            }
            case "BYTES": {
                ArrayList byteContents = new ArrayList();
                jsonData.forEach(data -> byteContents.add(ByteString.copyFromUtf8(data.toString())));
                inferTensorContents.addAllBytesContents(byteContents);
                break;
            }
            case "BOOL": {
                ArrayList boolContents = new ArrayList();
                jsonData.forEach(data -> boolContents.add(data.getAsBoolean()));
                inferTensorContents.addAllBoolContents(boolContents);
                break;
            }
            case "FP32": {
                ArrayList fp32Contents = new ArrayList();
                jsonData.forEach(data -> fp32Contents.add(Float.valueOf(data.getAsFloat())));
                inferTensorContents.addAllFp32Contents(fp32Contents);
                break;
            }
            case "FP64": {
                ArrayList fp64Contents = new ArrayList();
                jsonData.forEach(data -> fp64Contents.add(data.getAsDouble()));
                inferTensorContents.addAllFp64Contents(fp64Contents);
                break;
            }
            case "UINT8": 
            case "UINT16": 
            case "UINT32": {
                ArrayList uint32Contents = new ArrayList();
                jsonData.forEach(data -> uint32Contents.add(data.getAsInt()));
                inferTensorContents.addAllUintContents(uint32Contents);
                break;
            }
            case "UINT64": {
                ArrayList uint64Contents = new ArrayList();
                jsonData.forEach(data -> uint64Contents.add(data.getAsLong()));
                inferTensorContents.addAllUint64Contents(uint64Contents);
                break;
            }
        }
        outputBuilder.setContents(inferTensorContents);
    }
}

