/*
 * Decompiled with CFR 0.152.
 */
package org.ray.runtime.task;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.id.ActorId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.actor.LocalModeRayActor;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.generated.Common;
import org.ray.runtime.object.LocalModeObjectStore;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.LocalModeTaskExecutor;
import org.ray.runtime.task.TaskExecutor;
import org.ray.runtime.task.TaskSubmitter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TaskSubmitter} used in local (single-process) mode.
 *
 * <p>Tasks are turned into {@code Common.TaskSpec}s and executed in-process:
 * <ul>
 *   <li>normal tasks run on a fixed-size thread pool;</li>
 *   <li>each actor gets its own single-thread executor, created when its actor-creation task
 *       becomes runnable, so that its tasks run one after another on that thread.</li>
 * </ul>
 * A task whose object dependencies are not yet in the {@link LocalModeObjectStore} is parked in
 * {@code waitingTasks} and submitted from {@link #onObjectPut(ObjectId)} once all of them are
 * ready. {@code taskAndObjectLock} guards {@code waitingTasks}; {@code taskExecutorLock} guards
 * the executor pools; {@code actorTaskExecutorServices} is guarded by its own monitor.
 */
public class LocalModeTaskSubmitter implements TaskSubmitter {
    private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeTaskSubmitter.class);

    /** Number of random bytes in a generated task id. */
    private static final int TASK_ID_LENGTH = 14;
    /** Object id size: the task id followed by a 4-byte return index. */
    private static final int OBJECT_ID_LENGTH = 20;

    /** Unready object id -> tasks that are waiting (at least) on that object. */
    private final Map<ObjectId, Set<Common.TaskSpec>> waitingTasks = new HashMap<>();
    private final Object taskAndObjectLock = new Object();
    private final RayDevRuntime runtime;
    private final LocalModeObjectStore objectStore;
    private final Map<ActorId, ExecutorService> actorTaskExecutorServices;
    private final ExecutorService normalTaskExecutorService;
    /** Pool of reusable executors for normal tasks. */
    private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
    /** Executor bound to each actor, created by its actor-creation task. */
    private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
    private final Object taskExecutorLock = new Object();
    /** Executor of the task currently running on this thread, if any. */
    private final ThreadLocal<TaskExecutor> currentTaskExecutor = new ThreadLocal<>();

    /**
     * @param runtime       the owning dev runtime
     * @param objectStore   store used to resolve arguments and hold return values
     * @param numberThreads size of the thread pool that runs normal (non-actor) tasks
     */
    public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore, int numberThreads) {
        this.runtime = runtime;
        this.objectStore = objectStore;
        this.normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads);
        this.actorTaskExecutorServices = new HashMap<>();
    }

    /**
     * Notifies the submitter that {@code id} has been put into the object store.
     *
     * <p>Every task that was waiting on {@code id} and now has all dependencies ready is
     * submitted. Tasks that still have unready dependencies remain registered under those.
     */
    public void onObjectPut(ObjectId id) {
        synchronized (taskAndObjectLock) {
            Set<Common.TaskSpec> tasks = waitingTasks.remove(id);
            if (tasks == null) {
                return;
            }
            for (Common.TaskSpec task : tasks) {
                if (!getUnreadyObjects(task).isEmpty()) {
                    continue;
                }
                // The task may also be registered under other dependencies that became ready
                // before their own callbacks ran. Unregister it everywhere so those callbacks
                // don't submit it a second time.
                removeFromWaitingTasks(task);
                submitTaskSpec(task);
            }
        }
    }

    /**
     * Returns the executor running a task on the calling thread. Returns {@code null} when
     * the thread is not currently executing a task submitted here.
     */
    public TaskExecutor getCurrentTaskExecutor() {
        return currentTaskExecutor.get();
    }

    /**
     * Picks the executor for {@code task} and binds it to the calling thread.
     *
     * <ul>
     *   <li>Actor tasks reuse the executor registered by their actor's creation task.</li>
     *   <li>Actor-creation tasks register a new executor for the actor.</li>
     *   <li>Normal tasks take an idle executor, or create one if none is idle.</li>
     * </ul>
     */
    private TaskExecutor getTaskExecutor(Common.TaskSpec task) {
        TaskExecutor taskExecutor;
        synchronized (taskExecutorLock) {
            if (task.getType() == Common.TaskType.ACTOR_TASK) {
                taskExecutor = actorTaskExecutors.get(getActorId(task));
            } else if (task.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
                taskExecutor = new LocalModeTaskExecutor(runtime);
                actorTaskExecutors.put(getActorId(task), taskExecutor);
            } else {
                taskExecutor = idleTaskExecutors.isEmpty()
                    ? new LocalModeTaskExecutor(runtime)
                    : idleTaskExecutors.pop();
            }
        }
        currentTaskExecutor.set(taskExecutor);
        return taskExecutor;
    }

    /**
     * Unbinds {@code worker} from the calling thread. For normal tasks it also returns the
     * executor to the idle pool; actor executors stay bound to their actor.
     */
    private void returnTaskExecutor(TaskExecutor worker, Common.TaskSpec taskSpec) {
        currentTaskExecutor.remove();
        synchronized (taskExecutorLock) {
            if (taskSpec.getType() == Common.TaskType.NORMAL_TASK) {
                idleTaskExecutors.push(worker);
            }
        }
    }

    /**
     * Returns every object id {@code taskSpec} depends on. These are all by-reference
     * arguments, plus, for actor tasks, the dummy object of the previous task on the same
     * actor (which orders actor tasks).
     */
    private static Set<ObjectId> getDependencies(Common.TaskSpec taskSpec) {
        Set<ObjectId> dependencies = new HashSet<>();
        for (Common.TaskArg arg : taskSpec.getArgsList()) {
            for (ByteString idByteString : arg.getObjectIdsList()) {
                dependencies.add(new ObjectId(idByteString.toByteArray()));
            }
        }
        if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
            dependencies.add(new ObjectId(
                taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray()));
        }
        return dependencies;
    }

    /** Returns the dependencies of {@code taskSpec} that are not yet ready in the object store. */
    private Set<ObjectId> getUnreadyObjects(Common.TaskSpec taskSpec) {
        Set<ObjectId> unreadyObjects = new HashSet<>();
        for (ObjectId id : getDependencies(taskSpec)) {
            if (!objectStore.isObjectReady(id)) {
                unreadyObjects.add(id);
            }
        }
        return unreadyObjects;
    }

    /**
     * Removes {@code task} from the waiting set of each of its dependencies and drops sets that
     * become empty. Caller must hold {@code taskAndObjectLock}.
     */
    private void removeFromWaitingTasks(Common.TaskSpec task) {
        for (ObjectId dependency : getDependencies(task)) {
            Set<Common.TaskSpec> waiting = waitingTasks.get(dependency);
            if (waiting != null) {
                waiting.remove(task);
                if (waiting.isEmpty()) {
                    waitingTasks.remove(dependency);
                }
            }
        }
    }

    /** Converts a submitted argument into its protobuf form (by reference or by value). */
    private static Common.TaskArg toTaskArg(FunctionArg arg) {
        if (arg.id != null) {
            return Common.TaskArg.newBuilder()
                .addObjectIds(ByteString.copyFrom(arg.id.getBytes()))
                .build();
        }
        ByteString metadata = arg.value.metadata != null
            ? ByteString.copyFrom(arg.value.metadata)
            : ByteString.EMPTY;
        return Common.TaskArg.newBuilder()
            .setData(ByteString.copyFrom(arg.value.data))
            .setMetadata(metadata)
            .build();
    }

    /**
     * Creates a task-spec builder pre-filled with a random task id, the job id, the Java
     * function descriptor and the arguments.
     *
     * @throws IllegalStateException if the descriptor has fewer than 3 components
     *     (class name, function name, signature)
     */
    private Common.TaskSpec.Builder getTaskSpecBuilder(Common.TaskType taskType, FunctionDescriptor functionDescriptor, List<FunctionArg> args) {
        byte[] taskIdBytes = new byte[TASK_ID_LENGTH];
        new Random().nextBytes(taskIdBytes);
        List<String> functionDescriptorList = functionDescriptor.toList();
        Preconditions.checkState(functionDescriptorList.size() >= 3);
        Common.JavaFunctionDescriptor.Builder javaFunctionDescriptor = Common.JavaFunctionDescriptor.newBuilder()
            .setClassName(functionDescriptorList.get(0))
            .setFunctionName(functionDescriptorList.get(1))
            .setSignature(functionDescriptorList.get(2));
        return Common.TaskSpec.newBuilder()
            .setType(taskType)
            .setLanguage(Common.Language.JAVA)
            .setJobId(ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes()))
            .setTaskId(ByteString.copyFrom(taskIdBytes))
            .setFunctionDescriptor(Common.FunctionDescriptor.newBuilder()
                .setJavaFunctionDescriptor(javaFunctionDescriptor))
            .addAllArgs(args.stream()
                .map(LocalModeTaskSubmitter::toTaskArg)
                .collect(Collectors.toList()));
    }

    /**
     * Submits a normal task.
     *
     * @throws IllegalStateException if {@code numReturns > 1}
     */
    @Override
    public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns, CallOptions options) {
        Preconditions.checkState(numReturns <= 1);
        Common.TaskSpec taskSpec = getTaskSpecBuilder(Common.TaskType.NORMAL_TASK, functionDescriptor, args)
            .setNumReturns(numReturns)
            .build();
        submitTaskSpec(taskSpec);
        return getReturnIds(taskSpec);
    }

    /**
     * Submits an actor-creation task for a fresh random actor id. The actor handle records
     * the creation task's return id, so the first actor task waits until creation finishes.
     */
    @Override
    public RayActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args, ActorCreationOptions options) {
        ActorId actorId = ActorId.fromRandom();
        Common.TaskSpec taskSpec = getTaskSpecBuilder(Common.TaskType.ACTOR_CREATION_TASK, functionDescriptor, args)
            .setNumReturns(1L)
            .setActorCreationTaskSpec(Common.ActorCreationTaskSpec.newBuilder()
                .setActorId(ByteString.copyFrom(actorId.toByteBuffer()))
                .build())
            .build();
        submitTaskSpec(taskSpec);
        return new LocalModeRayActor(actorId, getReturnIds(taskSpec).get(0));
    }

    /**
     * Submits a task on {@code actor}.
     *
     * <p>The task gets one extra, hidden return: a dummy object. It is swapped into the actor
     * handle, and the previous dummy becomes this task's dependency, so tasks on one actor run
     * in submission order.
     *
     * @throws IllegalStateException if {@code numReturns > 1}
     */
    @Override
    public List<ObjectId> submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns, CallOptions options) {
        Preconditions.checkState(numReturns <= 1);
        Common.TaskSpec.Builder builder = getTaskSpecBuilder(Common.TaskType.ACTOR_TASK, functionDescriptor, args);
        List<ObjectId> returnIds = getReturnIds(TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1);
        ObjectId dummyObjectId = returnIds.get(returnIds.size() - 1);
        ObjectId previousDummyObjectId = ((LocalModeRayActor) actor).exchangePreviousActorTaskDummyObjectId(dummyObjectId);
        Common.TaskSpec taskSpec = builder
            .setNumReturns(numReturns + 1)
            .setActorTaskSpec(Common.ActorTaskSpec.newBuilder()
                .setActorId(ByteString.copyFrom(actor.getId().getBytes()))
                .setPreviousActorTaskDummyObjectId(ByteString.copyFrom(previousDummyObjectId.getBytes()))
                .build())
            .build();
        submitTaskSpec(taskSpec);
        if (numReturns == 0) {
            return ImmutableList.of();
        }
        return ImmutableList.of(returnIds.get(0));
    }

    /** Initiates an orderly shutdown of every actor executor and of the normal-task pool. */
    public void shutdown() {
        synchronized (actorTaskExecutorServices) {
            for (ExecutorService executorService : actorTaskExecutorServices.values()) {
                executorService.shutdown();
            }
        }
        normalTaskExecutorService.shutdown();
    }

    /**
     * Returns the actor id carried by an actor-creation or actor task spec. Returns
     * {@code null} for any other task type.
     */
    public static ActorId getActorId(Common.TaskSpec taskSpec) {
        ByteString actorId = null;
        if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
            actorId = taskSpec.getActorCreationTaskSpec().getActorId();
        } else if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
            actorId = taskSpec.getActorTaskSpec().getActorId();
        }
        if (actorId == null) {
            return null;
        }
        return ActorId.fromBytes(actorId.toByteArray());
    }

    /**
     * Submits {@code taskSpec} for execution if all its dependencies are ready. Otherwise it is
     * registered in {@code waitingTasks} under each unready dependency and resubmitted later by
     * {@link #onObjectPut(ObjectId)}.
     */
    private void submitTaskSpec(Common.TaskSpec taskSpec) {
        LOGGER.debug("Submitting task: {}.", taskSpec);
        synchronized (taskAndObjectLock) {
            Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
            if (!unreadyObjects.isEmpty()) {
                for (ObjectId id : unreadyObjects) {
                    waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(taskSpec);
                }
                return;
            }
            Runnable runnable = () -> executeTask(taskSpec);
            if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
                ExecutorService actorExecutorService = Executors.newSingleThreadExecutor();
                synchronized (actorTaskExecutorServices) {
                    actorTaskExecutorServices.put(getActorId(taskSpec), actorExecutorService);
                }
                actorExecutorService.submit(runnable);
            } else if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
                synchronized (actorTaskExecutorServices) {
                    actorTaskExecutorServices.get(getActorId(taskSpec)).submit(runnable);
                }
            } else {
                normalTaskExecutorService.submit(runnable);
            }
        }
    }

    /**
     * Runs {@code taskSpec} on the calling thread.
     *
     * <p>Resolves its arguments from the object store and invokes the function. It then stores
     * one object per return id; missing trailing returns (such as an actor task's dummy return)
     * are filled with a placeholder. The executor is always handed back, and the worker
     * context's current task is always reset.
     */
    private void executeTask(Common.TaskSpec taskSpec) {
        TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
        try {
            List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
                .map(arg -> arg.id != null
                    ? objectStore.getRaw(Collections.singletonList(arg.id), -1L).get(0)
                    : arg.value)
                .collect(Collectors.toList());
            LocalModeWorkerContext workerContext = (LocalModeWorkerContext) runtime.getWorkerContext();
            workerContext.setCurrentTask(taskSpec);
            List<NativeRayObject> returnObjects;
            try {
                returnObjects = taskExecutor.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
            } finally {
                // Reset even on failure so the worker context doesn't keep a stale task.
                workerContext.setCurrentTask(null);
            }
            List<ObjectId> returnIds = getReturnIds(taskSpec);
            for (int i = 0; i < returnIds.size(); i++) {
                NativeRayObject putObject = i < returnObjects.size()
                    ? returnObjects.get(i)
                    : new NativeRayObject(new byte[]{1}, null);
                objectStore.putRaw(putObject, returnIds.get(i));
            }
        } catch (RuntimeException | Error e) {
            // The Future returned by ExecutorService.submit is never inspected, so without
            // this the failure would vanish silently.
            LOGGER.error("Failed to execute task: {}.", taskSpec, e);
            throw e;
        } finally {
            returnTaskExecutor(taskExecutor, taskSpec);
        }
    }

    /**
     * Builds a {@link JavaFunctionDescriptor} from the task spec.
     *
     * @throws RuntimeException if the spec does not hold a Java function descriptor
     */
    private static JavaFunctionDescriptor getJavaFunctionDescriptor(Common.TaskSpec taskSpec) {
        Common.FunctionDescriptor functionDescriptor = taskSpec.getFunctionDescriptor();
        if (functionDescriptor.getFunctionDescriptorCase() != Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
            throw new RuntimeException("Can't build non java function descriptor");
        }
        Common.JavaFunctionDescriptor javaDescriptor = functionDescriptor.getJavaFunctionDescriptor();
        return new JavaFunctionDescriptor(javaDescriptor.getClassName(), javaDescriptor.getFunctionName(), javaDescriptor.getSignature());
    }

    /**
     * Decodes the task's protobuf args. An arg with object ids becomes a by-reference arg
     * (only its first id is used); any other arg becomes a by-value arg built from its
     * data and metadata.
     */
    private static List<FunctionArg> getFunctionArgs(Common.TaskSpec taskSpec) {
        List<FunctionArg> functionArgs = new ArrayList<>(taskSpec.getArgsCount());
        for (Common.TaskArg arg : taskSpec.getArgsList()) {
            if (arg.getObjectIdsCount() > 0) {
                functionArgs.add(FunctionArg.passByReference(new ObjectId(arg.getObjectIds(0).toByteArray())));
            } else {
                functionArgs.add(FunctionArg.passByValue(new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
            }
        }
        return functionArgs;
    }

    /** Returns the return object ids of {@code taskSpec}, one per declared return. */
    private static List<ObjectId> getReturnIds(Common.TaskSpec taskSpec) {
        return getReturnIds(TaskId.fromBytes(taskSpec.getTaskId().toByteArray()), taskSpec.getNumReturns());
    }

    /**
     * Derives {@code numReturns} return object ids from {@code taskId}. Each id is the task id
     * bytes followed by the 1-based return index as a big-endian int.
     */
    private static List<ObjectId> getReturnIds(TaskId taskId, long numReturns) {
        List<ObjectId> returnIds = new ArrayList<>();
        for (int i = 0; i < numReturns; i++) {
            ByteBuffer buffer = ByteBuffer.allocate(OBJECT_ID_LENGTH)
                .put(taskId.getBytes())
                .putInt(TASK_ID_LENGTH, i + 1);
            // Explicit cast: on Java 8, Buffer.position(int) returns Buffer, not ByteBuffer.
            returnIds.add(ObjectId.fromByteBuffer((ByteBuffer) buffer.position(0)));
        }
        return returnIds;
    }
}

