/*
 * Decompiled with CFR 0.152.
 */
package io.ray.runtime.task;

import com.google.common.base.Preconditions;
import io.ray.api.exception.RayTaskException;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.generated.Common;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectSerializer;
import io.ray.runtime.task.ArgumentsBuilder;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class TaskExecutor<T extends ActorContext> {
    private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
    protected final RayRuntimeInternal runtime;
    private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap();

    TaskExecutor(RayRuntimeInternal runtime) {
        this.runtime = runtime;
    }

    protected abstract T createActorContext();

    T getActorContext() {
        return (T)((ActorContext)this.actorContextMap.get(this.runtime.getWorkerContext().getCurrentWorkerId()));
    }

    void setActorContext(T actorContext) {
        if (actorContext == null) {
            return;
        }
        this.actorContextMap.put(this.runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<NativeRayObject> argsBytes) {
        this.runtime.setIsContextSet(true);
        JobId jobId = this.runtime.getWorkerContext().getCurrentJobId();
        Common.TaskType taskType = this.runtime.getWorkerContext().getCurrentTaskType();
        TaskId taskId = this.runtime.getWorkerContext().getCurrentTaskId();
        LOGGER.debug("Executing task {}", (Object)taskId);
        Object actorContext = null;
        if (taskType == Common.TaskType.ACTOR_CREATION_TASK) {
            actorContext = this.createActorContext();
            this.setActorContext(actorContext);
        } else if (taskType == Common.TaskType.ACTOR_TASK) {
            actorContext = this.getActorContext();
            Preconditions.checkNotNull(actorContext);
        }
        ArrayList<NativeRayObject> returnObjects = new ArrayList<NativeRayObject>();
        ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
        JavaFunctionDescriptor functionDescriptor = this.parseFunctionDescriptor(rayFunctionInfo);
        RayFunction rayFunction = null;
        try {
            Object result;
            rayFunction = this.runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
            Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
            this.runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
            Object actor = null;
            if (taskType == Common.TaskType.ACTOR_TASK) {
                if (((ActorContext)actorContext).actorCreationException != null) {
                    throw ((ActorContext)actorContext).actorCreationException;
                }
                actor = ((ActorContext)actorContext).currentActor;
            }
            Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.executable.getParameterTypes());
            try {
                result = !rayFunction.isConstructor() ? rayFunction.getMethod().invoke(actor, args) : rayFunction.getConstructor().newInstance(args);
            }
            catch (InvocationTargetException e) {
                if (e.getCause() != null) {
                    throw e.getCause();
                }
                throw e;
            }
            if (taskType != Common.TaskType.ACTOR_CREATION_TASK) {
                if (taskType == Common.TaskType.ACTOR_TASK) {
                    this.maybeSaveCheckpoint(actor, this.runtime.getWorkerContext().getCurrentActorId());
                }
                if (rayFunction.hasReturn()) {
                    returnObjects.add(ObjectSerializer.serialize(result));
                }
            } else {
                this.maybeLoadCheckpoint(result, this.runtime.getWorkerContext().getCurrentActorId());
                ((ActorContext)actorContext).currentActor = result;
            }
            LOGGER.debug("Finished executing task {}", (Object)taskId);
        }
        catch (Throwable e) {
            LOGGER.error("Error executing task " + taskId, e);
            if (taskType != Common.TaskType.ACTOR_CREATION_TASK) {
                boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
                boolean isCrossLanguage = functionDescriptor.signature.equals("");
                if (hasReturn || isCrossLanguage) {
                    returnObjects.add(ObjectSerializer.serialize(new RayTaskException("Error executing task " + taskId, e)));
                }
            } else {
                ((ActorContext)actorContext).actorCreationException = e;
            }
        }
        finally {
            Thread.currentThread().setContextClassLoader(oldLoader);
            this.runtime.getWorkerContext().setCurrentClassLoader(null);
            this.runtime.setIsContextSet(false);
        }
        return returnObjects;
    }

    private JavaFunctionDescriptor parseFunctionDescriptor(List<String> rayFunctionInfo) {
        Preconditions.checkState(rayFunctionInfo != null && rayFunctionInfo.size() == 3);
        return new JavaFunctionDescriptor(rayFunctionInfo.get(0), rayFunctionInfo.get(1), rayFunctionInfo.get(2));
    }

    protected abstract void maybeSaveCheckpoint(Object var1, ActorId var2);

    protected abstract void maybeLoadCheckpoint(Object var1, ActorId var2);

    static class ActorContext {
        Object currentActor = null;
        Throwable actorCreationException = null;

        ActorContext() {
        }
    }
}

