/*
 * Decompiled with CFR 0.152.
 */
package org.ray.runtime.functionmanager;

import com.google.common.base.Strings;
import java.io.File;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.objectweb.asm.Type;
import org.ray.api.function.RayFunc;
import org.ray.api.id.JobId;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.util.LambdaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Resolves {@link RayFunction} instances for jobs.
 *
 * <p>Each {@link JobId} gets its own {@link JobFunctionTable}, created on first use. When a job
 * resource path was configured, the table's class loader is a {@link URLClassLoader} built from
 * {@code <jobResourcePath>/<jobId>/}; otherwise the class loader of this class is used.
 */
public class FunctionManager {
    private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class);

    /** Method name under which constructors are registered in a function table. */
    static final String CONSTRUCTOR_NAME = "<init>";

    /**
     * Per-thread cache from a lambda's class to the descriptor extracted from it, so the
     * serialized lambda only has to be inspected once per class and thread.
     */
    private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>> RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);

    /** Function tables that have been created so far, keyed by job. */
    private final ConcurrentMap<JobId, JobFunctionTable> jobFunctionTables = new ConcurrentHashMap<>();

    /** Root directory of per-job resources; {@code null} or empty means "no job-specific resources". */
    private final String jobResourcePath;

    /**
     * Creates a function manager.
     *
     * @param jobResourcePath directory containing one sub-directory per job id, or
     *     {@code null}/empty to load every job's functions with this class's own class loader
     */
    public FunctionManager(String jobResourcePath) {
        this.jobResourcePath = jobResourcePath;
    }

    /**
     * Resolves the function that the given lambda refers to.
     *
     * @param jobId the job whose function table is consulted
     * @param func the lambda identifying the target method or constructor
     * @return the matching function
     * @throws RuntimeException if the function cannot be found or loaded
     */
    public RayFunction getFunction(JobId jobId, RayFunc func) {
        WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor> cache = RAY_FUNC_CACHE.get();
        JavaFunctionDescriptor functionDescriptor = cache.get(func.getClass());
        if (functionDescriptor == null) {
            SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
            // The implementation class is reported in internal form (slashes); convert it to a
            // binary class name.
            String className = serializedLambda.getImplClass().replace('/', '.');
            functionDescriptor = new JavaFunctionDescriptor(
                className,
                serializedLambda.getImplMethodName(),
                serializedLambda.getImplMethodSignature());
            cache.put(func.getClass(), functionDescriptor);
        }
        return this.getFunction(jobId, functionDescriptor);
    }

    /**
     * Resolves the function identified by a descriptor, creating the job's function table on
     * first use.
     *
     * @param jobId the job whose function table is consulted
     * @param functionDescriptor class name, method name and signature of the target
     * @return the matching function
     * @throws RuntimeException if the function cannot be found or loaded
     */
    public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) {
        // Try a plain read first; only a miss goes through computeIfAbsent, which creates the
        // table at most once per job without locking the whole manager.
        JobFunctionTable jobFunctionTable = this.jobFunctionTables.get(jobId);
        if (jobFunctionTable == null) {
            jobFunctionTable = this.jobFunctionTables.computeIfAbsent(jobId, this::createJobFunctionTable);
        }
        return jobFunctionTable.getFunction(functionDescriptor);
    }

    /**
     * Builds the function table for a job, choosing the class loader as described on the class.
     *
     * @param jobId the job to create a table for
     * @return a new, empty function table
     */
    private JobFunctionTable createJobFunctionTable(JobId jobId) {
        ClassLoader classLoader;
        if (Strings.isNullOrEmpty(this.jobResourcePath)) {
            classLoader = this.getClass().getClassLoader();
        } else {
            File resourceDir = new File(this.jobResourcePath + "/" + jobId.toString() + "/");
            // Copy into a list we own instead of mutating the collection returned by the library.
            // Order is kept: every jar found below the directory first, then the directory itself.
            List<File> classpathEntries = new ArrayList<>(
                FileUtils.listFiles(resourceDir, new RegexFileFilter(".*\\.jar"), DirectoryFileFilter.DIRECTORY));
            classpathEntries.add(resourceDir);
            List<URL> urlList = classpathEntries.stream()
                .map(FunctionManager::toUrl)
                .collect(Collectors.toList());
            classLoader = new URLClassLoader(urlList.toArray(new URL[0]));
            LOGGER.debug("Resource loaded for job {} from path {}.", jobId, resourceDir.getAbsolutePath());
        }
        return new JobFunctionTable(classLoader);
    }

    /** Converts a file to a URL, rethrowing the checked exception as a {@link RuntimeException}. */
    private static URL toUrl(File file) {
        try {
            return file.toURI().toURL();
        } catch (MalformedURLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Functions of a single job, loaded lazily class by class through the job's class loader.
     */
    static class JobFunctionTable {
        /** Class loader used to load the job's classes. */
        final ClassLoader classLoader;

        /**
         * Loaded functions: class name -> ((method name, signature) -> function). Every method
         * name is additionally registered under an empty signature; that entry holds
         * {@code null} when the name is overloaded.
         */
        ConcurrentMap<String, Map<Pair<String, String>, RayFunction>> functions;

        JobFunctionTable(ClassLoader classLoader) {
            this.classLoader = classLoader;
            this.functions = new ConcurrentHashMap<>();
        }

        /**
         * Looks up a function, loading the functions of its class on first access.
         *
         * @param descriptor class name, method name and signature (possibly empty) of the target
         * @return the matching function, never {@code null}
         * @throws RuntimeException if the class cannot be loaded, no such function exists, or
         *     the signature is empty while the method name is overloaded
         */
        RayFunction getFunction(JavaFunctionDescriptor descriptor) {
            Map<Pair<String, String>, RayFunction> classFunctions = this.functions.get(descriptor.className);
            if (classFunctions == null) {
                // Deliberately a re-entrant monitor rather than computeIfAbsent: loading
                // initializes the class, and should that initialization call back into this
                // table on the same thread (NOTE(review): not verified that it can), a monitor
                // tolerates it whereas a recursive computeIfAbsent is not permitted.
                synchronized (this) {
                    classFunctions = this.functions.get(descriptor.className);
                    if (classFunctions == null) {
                        classFunctions = this.loadFunctionsForClass(descriptor.className);
                        this.functions.put(descriptor.className, classFunctions);
                    }
                }
            }
            Pair<String, String> key = ImmutablePair.of(descriptor.name, descriptor.signature);
            RayFunction func = classFunctions.get(key);
            if (func != null) {
                return func;
            }
            // A key that is present but maps to null is the "overloaded" marker written by
            // loadFunctionsForClass for empty-signature lookups.
            if (classFunctions.containsKey(key)) {
                throw new RuntimeException(String.format("RayFunction %s is overloaded, the signature can't be empty.", descriptor.toString()));
            }
            throw new RuntimeException(String.format("RayFunction %s not found", descriptor.toString()));
        }

        /**
         * Loads a class and indexes its declared methods and public constructors.
         *
         * <p>Each executable is stored under (name, signature) and also under (name, ""). If
         * more than one executable shares a name, the (name, "") entry is set to {@code null}
         * to mark it as ambiguous.
         *
         * @param className binary name of the class to load
         * @return the function index for the class
         * @throws RuntimeException if loading or inspecting the class fails
         */
        Map<Pair<String, String>, RayFunction> loadFunctionsForClass(String className) {
            Map<Pair<String, String>, RayFunction> map = new HashMap<>();
            try {
                Class<?> clazz = Class.forName(className, true, this.classLoader);
                List<Executable> executables = new ArrayList<>();
                executables.addAll(Arrays.asList(clazz.getDeclaredMethods()));
                executables.addAll(Arrays.asList(clazz.getConstructors()));
                for (Executable executable : executables) {
                    executable.setAccessible(true);
                    String methodName;
                    Type type;
                    if (executable instanceof Method) {
                        methodName = executable.getName();
                        type = Type.getType((Method) executable);
                    } else {
                        methodName = FunctionManager.CONSTRUCTOR_NAME;
                        type = Type.getType((Constructor<?>) executable);
                    }
                    String signature = type.getDescriptor();
                    RayFunction rayFunction = new RayFunction(executable, this.classLoader, new JavaFunctionDescriptor(className, methodName, signature));
                    map.put(ImmutablePair.of(methodName, signature), rayFunction);
                    // Second entry for lookups without a signature: the first executable with
                    // this name claims it; any later one downgrades it to the null marker.
                    Pair<String, String> emptyDescriptor = ImmutablePair.of(methodName, "");
                    if (map.containsKey(emptyDescriptor)) {
                        map.put(emptyDescriptor, null);
                    } else {
                        map.put(emptyDescriptor, rayFunction);
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException("Failed to load functions from class " + className, e);
            }
            return map;
        }
    }
}

