/*
 * Decompiled with CFR 0.152.
 */
package io.ray.runtime.task;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import io.ray.api.BaseActor;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.actor.LocalModeRayActor;
import io.ray.runtime.context.LocalModeWorkerContext;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.generated.Common;
import io.ray.runtime.object.LocalModeObjectStore;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.task.FunctionArg;
import io.ray.runtime.task.LocalModeTaskExecutor;
import io.ray.runtime.task.TaskExecutor;
import io.ray.runtime.task.TaskSubmitter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TaskSubmitter} that runs tasks inside the current process.
 *
 * <p>Normal tasks are handed to a shared cached thread pool; every actor gets its own
 * single-thread executor, created when its actor creation task is dispatched. A task whose
 * object arguments are not yet ready in the {@link LocalModeObjectStore} is parked in
 * {@code waitingTasks} and dispatched from {@link #onObjectPut(ObjectId)} once nothing it
 * depends on is missing.
 *
 * <p>Locking: {@code waitingTasks} is only touched while holding {@code taskAndObjectLock};
 * {@code actorTaskExecutorServices} is only touched while synchronized on the map itself.
 */
public class LocalModeTaskSubmitter implements TaskSubmitter {
    private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeTaskSubmitter.class);

    /**
     * Number of random bytes generated for a task id, and therefore the offset at which the
     * return index is written inside a return object id.
     */
    private static final int TASK_ID_LENGTH = 14;

    /** Size in bytes of the buffer a return object id is built from. */
    private static final int OBJECT_ID_LENGTH = 20;

    /** Parked tasks, keyed by the id of an object they are still waiting for. */
    private final Map<ObjectId, Set<Common.TaskSpec>> waitingTasks = new HashMap<>();
    /** Guards {@code waitingTasks} and makes "check readiness, then park or dispatch" atomic. */
    private final Object taskAndObjectLock = new Object();
    private final RayRuntimeInternal runtime;
    private final TaskExecutor taskExecutor;
    private final LocalModeObjectStore objectStore;
    /** One single-thread executor per actor; access is synchronized on the map. */
    private final Map<ActorId, ExecutorService> actorTaskExecutorServices;
    private final ExecutorService normalTaskExecutorService;
    /** Actor contexts captured after each actor creation task has executed. */
    private final Map<ActorId, TaskExecutor.ActorContext> actorContexts = new ConcurrentHashMap<>();

    /**
     * Creates a submitter backed by a fresh cached thread pool for normal tasks.
     *
     * @param runtime runtime used for the job id and the worker context
     * @param taskExecutor executor that actually invokes the task function
     * @param objectStore store queried for argument readiness and written with task results
     */
    public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) {
        this.runtime = runtime;
        this.taskExecutor = taskExecutor;
        this.objectStore = objectStore;
        this.normalTaskExecutorService = Executors.newCachedThreadPool();
        this.actorTaskExecutorServices = new HashMap<>();
    }

    /**
     * Dispatches the parked tasks that were waiting on {@code id} and have no other unready
     * dependency left. Tasks that still miss other objects stay parked under those ids.
     *
     * <p>NOTE(review): the caller is not visible in this file; this is expected to be invoked
     * after the object has become ready in the object store.
     *
     * @param id id of the object that has just become available
     */
    public void onObjectPut(ObjectId id) {
        synchronized (this.taskAndObjectLock) {
            Set<Common.TaskSpec> tasks = this.waitingTasks.remove(id);
            if (tasks == null) {
                return;
            }
            for (Common.TaskSpec task : tasks) {
                if (this.getUnreadyObjects(task).isEmpty()) {
                    this.submitTaskSpec(task);
                }
            }
        }
    }

    /**
     * Collects the ids of all objects the task depends on that the object store does not
     * report as ready: by-reference arguments and, for actor tasks, the dummy object of the
     * previous task on the same actor.
     */
    private Set<ObjectId> getUnreadyObjects(Common.TaskSpec taskSpec) {
        Set<ObjectId> unreadyObjects = new HashSet<>();
        for (Common.TaskArg arg : taskSpec.getArgsList()) {
            for (ByteString idByteString : arg.getObjectIdsList()) {
                ObjectId id = new ObjectId(idByteString.toByteArray());
                if (!this.objectStore.isObjectReady(id)) {
                    unreadyObjects.add(id);
                }
            }
        }
        if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
            // Actor tasks are chained: each one waits for the dummy object of its predecessor.
            ObjectId dummyObjectId = new ObjectId(taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray());
            if (!this.objectStore.isObjectReady(dummyObjectId)) {
                unreadyObjects.add(dummyObjectId);
            }
        }
        return unreadyObjects;
    }

    /**
     * Starts a Java task spec with a random task id, the job id from the runtime config, the
     * function descriptor (class name, function name, signature) and the converted arguments.
     *
     * @throws IllegalStateException if the descriptor list has fewer than three entries
     */
    private Common.TaskSpec.Builder getTaskSpecBuilder(Common.TaskType taskType, FunctionDescriptor functionDescriptor, List<FunctionArg> args) {
        byte[] taskIdBytes = new byte[TASK_ID_LENGTH];
        new Random().nextBytes(taskIdBytes);
        List<String> functionDescriptorList = functionDescriptor.toList();
        Preconditions.checkState(functionDescriptorList.size() >= 3);
        return Common.TaskSpec.newBuilder()
                .setType(taskType)
                .setLanguage(Common.Language.JAVA)
                .setJobId(ByteString.copyFrom(this.runtime.getRayConfig().getJobId().getBytes()))
                .setTaskId(ByteString.copyFrom(taskIdBytes))
                .setFunctionDescriptor(Common.FunctionDescriptor.newBuilder()
                        .setJavaFunctionDescriptor(Common.JavaFunctionDescriptor.newBuilder()
                                .setClassName(functionDescriptorList.get(0))
                                .setFunctionName(functionDescriptorList.get(1))
                                .setSignature(functionDescriptorList.get(2))))
                .addAllArgs(args.stream().map(LocalModeTaskSubmitter::toTaskArg).collect(Collectors.toList()));
    }

    /**
     * Converts one argument: a by-reference argument becomes a task arg carrying the object
     * id, a by-value argument carries its data and its metadata (empty when absent).
     */
    private static Common.TaskArg toTaskArg(FunctionArg arg) {
        if (arg.id != null) {
            return Common.TaskArg.newBuilder()
                    .addObjectIds(ByteString.copyFrom(arg.id.getBytes()))
                    .build();
        }
        return Common.TaskArg.newBuilder()
                .setData(ByteString.copyFrom(arg.value.data))
                .setMetadata(arg.value.metadata != null ? ByteString.copyFrom(arg.value.metadata) : ByteString.EMPTY)
                .build();
    }

    /**
     * Submits a normal task.
     *
     * @return the ids of the task's return objects ({@code numReturns} entries)
     * @throws IllegalStateException if {@code numReturns} is greater than one
     */
    @Override
    public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns, CallOptions options) {
        Preconditions.checkState(numReturns <= 1);
        Common.TaskSpec taskSpec = this.getTaskSpecBuilder(Common.TaskType.NORMAL_TASK, functionDescriptor, args)
                .setNumReturns(numReturns)
                .build();
        this.submitTaskSpec(taskSpec);
        return getReturnIds(taskSpec);
    }

    /**
     * Submits an actor creation task for a freshly generated actor id.
     *
     * @return a local-mode actor handle built from the actor id and the creation task's
     *     single return object id
     */
    @Override
    public BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args, ActorCreationOptions options) {
        ActorId actorId = ActorId.fromRandom();
        Common.TaskSpec taskSpec = this.getTaskSpecBuilder(Common.TaskType.ACTOR_CREATION_TASK, functionDescriptor, args)
                .setNumReturns(1L)
                .setActorCreationTaskSpec(Common.ActorCreationTaskSpec.newBuilder()
                        .setActorId(ByteString.copyFrom(actorId.toByteBuffer()))
                        .build())
                .build();
        this.submitTaskSpec(taskSpec);
        return new LocalModeRayActor(actorId, getReturnIds(taskSpec).get(0));
    }

    /**
     * Submits a task on an existing actor.
     *
     * <p>The task is given one return id more than requested. That last id is the task's
     * dummy object: it is swapped into the actor handle, and the id previously stored there is
     * recorded in the spec so this task waits for its predecessor.
     *
     * @param actor must be a {@link LocalModeRayActor}
     * @return an empty list when {@code numReturns} is zero, otherwise the first return id
     * @throws IllegalStateException if {@code numReturns} is greater than one
     */
    @Override
    public List<ObjectId> submitActorTask(BaseActor actor, FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns, CallOptions options) {
        Preconditions.checkState(numReturns <= 1);
        Common.TaskSpec.Builder builder = this.getTaskSpecBuilder(Common.TaskType.ACTOR_TASK, functionDescriptor, args);
        List<ObjectId> returnIds = getReturnIds(TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1);
        ObjectId dummyObjectId = returnIds.get(returnIds.size() - 1);
        Common.TaskSpec taskSpec = builder
                .setNumReturns(numReturns + 1)
                .setActorTaskSpec(Common.ActorTaskSpec.newBuilder()
                        .setActorId(ByteString.copyFrom(actor.getId().getBytes()))
                        .setPreviousActorTaskDummyObjectId(ByteString.copyFrom(((LocalModeRayActor)actor).exchangePreviousActorTaskDummyObjectId(dummyObjectId).getBytes()))
                        .build())
                .build();
        this.submitTaskSpec(taskSpec);
        if (numReturns == 0) {
            return ImmutableList.of();
        }
        return ImmutableList.of(returnIds.get(0));
    }

    /**
     * Initiates shutdown of every actor executor and of the normal-task executor. Uses
     * {@link ExecutorService#shutdown()}, so already submitted tasks are not cancelled.
     */
    public void shutdown() {
        synchronized (this.actorTaskExecutorServices) {
            for (ExecutorService actorExecutor : this.actorTaskExecutorServices.values()) {
                actorExecutor.shutdown();
            }
        }
        this.normalTaskExecutorService.shutdown();
    }

    /**
     * Extracts the actor id from an actor creation task or an actor task.
     *
     * @return the actor id, or {@code null} for any other task type
     */
    public static ActorId getActorId(Common.TaskSpec taskSpec) {
        ByteString actorIdBytes = null;
        if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
            actorIdBytes = taskSpec.getActorCreationTaskSpec().getActorId();
        } else if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
            actorIdBytes = taskSpec.getActorTaskSpec().getActorId();
        }
        return actorIdBytes == null ? null : ActorId.fromBytes(actorIdBytes.toByteArray());
    }

    /**
     * Dispatches the task if everything it depends on is ready; otherwise parks it under the
     * id of every object it is still waiting for.
     *
     * <p>A rejection by an executor that has been shut down is logged and ignored. Any other
     * rejection is logged as an error, because the task will never run and nothing will put
     * its return objects.
     */
    private void submitTaskSpec(Common.TaskSpec taskSpec) {
        LOGGER.debug("Submitting task: {}.", taskSpec);
        synchronized (this.taskAndObjectLock) {
            Set<ObjectId> unreadyObjects = this.getUnreadyObjects(taskSpec);
            if (!unreadyObjects.isEmpty()) {
                for (ObjectId id : unreadyObjects) {
                    this.waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(taskSpec);
                }
                return;
            }
            ExecutorService executorService = this.selectExecutorService(taskSpec);
            Runnable runnable = () -> this.runTask(taskSpec);
            try {
                executorService.submit(runnable);
            } catch (RejectedExecutionException e) {
                if (executorService.isShutdown()) {
                    LOGGER.warn("Ignore task submission due to the ExecutorService is shutdown. Task: {}", taskSpec);
                } else {
                    LOGGER.error("Task submission was rejected by an ExecutorService that is not shut down. Task: {}", taskSpec, e);
                }
            }
        }
    }

    /**
     * Picks the executor for a ready task: a new single-thread executor (registered under the
     * actor id) for an actor creation task, the actor's registered executor for an actor task,
     * and the shared pool for everything else.
     *
     * @throws NullPointerException if an actor task refers to an actor with no registered
     *     executor
     */
    private ExecutorService selectExecutorService(Common.TaskSpec taskSpec) {
        Common.TaskType type = taskSpec.getType();
        if (type == Common.TaskType.ACTOR_CREATION_TASK) {
            ExecutorService created = Executors.newSingleThreadExecutor();
            synchronized (this.actorTaskExecutorServices) {
                this.actorTaskExecutorServices.put(getActorId(taskSpec), created);
            }
            return created;
        }
        if (type == Common.TaskType.ACTOR_TASK) {
            ActorId actorId = getActorId(taskSpec);
            ExecutorService registered;
            synchronized (this.actorTaskExecutorServices) {
                registered = this.actorTaskExecutorServices.get(actorId);
            }
            // Fail with a message instead of a bare NPE at submit() time.
            return Preconditions.checkNotNull(registered, "No executor service registered for actor %s", actorId);
        }
        return this.normalTaskExecutorService;
    }

    /** Executes the task on the worker thread; any exception is logged and ends the process. */
    private void runTask(Common.TaskSpec taskSpec) {
        try {
            this.executeTask(taskSpec);
        } catch (Exception ex) {
            LOGGER.error("Unexpected exception when executing a task.", ex);
            System.exit(-1);
        }
    }

    /**
     * Runs one task: resolves its arguments, sets up the worker context, invokes the task
     * executor and stores the results under the task's return ids.
     *
     * @throws NullPointerException if an actor task has no recorded actor context
     */
    private void executeTask(Common.TaskSpec taskSpec) {
        TaskExecutor.ActorContext actorContext = null;
        if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
            actorContext = this.actorContexts.get(getActorId(taskSpec));
            Preconditions.checkNotNull(actorContext);
        }
        this.taskExecutor.setActorContext(actorContext);
        // By-reference arguments are fetched from the object store; by-value ones are used as is.
        List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
                .map(arg -> arg.id != null ? this.objectStore.getRaw(Collections.singletonList(arg.id), -1L).get(0) : arg.value)
                .collect(Collectors.toList());
        this.runtime.setIsContextSet(true);
        ((LocalModeWorkerContext)this.runtime.getWorkerContext()).setCurrentTask(taskSpec);
        // Actor tasks reuse the worker id stored in the actor context; other tasks get a random one.
        UniqueId workerId = actorContext != null
                ? ((LocalModeTaskExecutor.LocalActorContext)actorContext).getWorkerId()
                : UniqueId.randomId();
        ((LocalModeWorkerContext)this.runtime.getWorkerContext()).setCurrentWorkerId(workerId);
        List<NativeRayObject> returnObjects = this.taskExecutor.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
        if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
            // Recorded before the results are stored below.
            // NOTE(review): presumably so that an actor task released by putRaw already finds
            // the context - confirm against LocalModeObjectStore.
            this.actorContexts.put(getActorId(taskSpec), (TaskExecutor.ActorContext)this.taskExecutor.getActorContext());
        }
        // NOTE(review): the flag is raised again before the worker context is touched;
        // presumably execute() clears it and getWorkerContext() needs it set - confirm.
        this.runtime.setIsContextSet(true);
        ((LocalModeWorkerContext)this.runtime.getWorkerContext()).setCurrentTask(null);
        this.runtime.setIsContextSet(false);
        List<ObjectId> returnIds = getReturnIds(taskSpec);
        for (int i = 0; i < returnIds.size(); ++i) {
            // Return ids without a matching result (e.g. the extra dummy id of an actor task)
            // get a one-byte placeholder so that tasks waiting on them can proceed.
            NativeRayObject putObject = i >= returnObjects.size()
                    ? new NativeRayObject(new byte[]{1}, null)
                    : returnObjects.get(i);
            this.objectStore.putRaw(putObject, returnIds.get(i));
        }
    }

    /**
     * Rebuilds the Java function descriptor stored in the task spec.
     *
     * @throws RuntimeException if the spec does not carry a Java function descriptor
     */
    private static JavaFunctionDescriptor getJavaFunctionDescriptor(Common.TaskSpec taskSpec) {
        Common.FunctionDescriptor functionDescriptor = taskSpec.getFunctionDescriptor();
        if (functionDescriptor.getFunctionDescriptorCase() != Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
            throw new RuntimeException("Can't build non java function descriptor");
        }
        Common.JavaFunctionDescriptor javaDescriptor = functionDescriptor.getJavaFunctionDescriptor();
        return new JavaFunctionDescriptor(javaDescriptor.getClassName(), javaDescriptor.getFunctionName(), javaDescriptor.getSignature());
    }

    /**
     * Converts the spec's arguments back to {@link FunctionArg}s: an argument with at least
     * one object id is passed by reference (first id only), any other by value.
     */
    private static List<FunctionArg> getFunctionArgs(Common.TaskSpec taskSpec) {
        List<FunctionArg> functionArgs = new ArrayList<>();
        for (Common.TaskArg arg : taskSpec.getArgsList()) {
            if (arg.getObjectIdsCount() > 0) {
                functionArgs.add(FunctionArg.passByReference(new ObjectId(arg.getObjectIds(0).toByteArray())));
            } else {
                functionArgs.add(FunctionArg.passByValue(new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
            }
        }
        return functionArgs;
    }

    /** Returns the return object ids derived from the spec's task id and return count. */
    private static List<ObjectId> getReturnIds(Common.TaskSpec taskSpec) {
        return getReturnIds(TaskId.fromBytes(taskSpec.getTaskId().toByteArray()), taskSpec.getNumReturns());
    }

    /**
     * Derives {@code numReturns} object ids from a task id. Each id is built from a
     * {@value #OBJECT_ID_LENGTH}-byte buffer holding the task id bytes followed, at offset
     * {@value #TASK_ID_LENGTH}, by the 1-based return index as an int.
     */
    private static List<ObjectId> getReturnIds(TaskId taskId, long numReturns) {
        List<ObjectId> returnIds = new ArrayList<>();
        for (int i = 0; i < numReturns; ++i) {
            // Built statement by statement so a ByteBuffer is passed on every JDK:
            // before Java 9, position(int) is declared on Buffer and returns Buffer.
            ByteBuffer buffer = ByteBuffer.allocate(OBJECT_ID_LENGTH);
            buffer.put(taskId.getBytes());
            buffer.putInt(TASK_ID_LENGTH, i + 1);
            buffer.position(0);
            returnIds.add(ObjectId.fromByteBuffer(buffer));
        }
        return returnIds;
    }
}

