/*
 * Decompiled with CFR 0.152.
 */
package org.ray.streaming.runtime.worker.tasks;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.id.ActorId;
import org.ray.streaming.api.collector.Collector;
import org.ray.streaming.api.partition.Partition;
import org.ray.streaming.runtime.core.collector.OutputCollector;
import org.ray.streaming.runtime.core.graph.ExecutionEdge;
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
import org.ray.streaming.runtime.core.graph.ExecutionNode;
import org.ray.streaming.runtime.core.processor.Processor;
import org.ray.streaming.runtime.transfer.ChannelID;
import org.ray.streaming.runtime.transfer.DataReader;
import org.ray.streaming.runtime.transfer.DataWriter;
import org.ray.streaming.runtime.worker.JobWorker;
import org.ray.streaming.runtime.worker.context.RayRuntimeContext;
import org.ray.streaming.util.Config;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for a streaming task that runs a {@link Processor} inside a {@link JobWorker}.
 *
 * <p>Construction wires the task into the worker's execution graph:
 * <ul>
 *   <li>one {@link DataWriter} is created per output edge whose target node has at least one task;</li>
 *   <li>a single {@link DataReader} is created over all input channels, only if there are any;</li>
 *   <li>the processor is opened with one {@link OutputCollector} per created writer;</li>
 *   <li>a JVM shutdown hook is registered that calls {@link #cancelTask()}.</li>
 * </ul>
 * The task body ({@link Runnable#run()}, left to subclasses) runs on a daemon thread that is
 * created in the constructor and started by {@link #start()}.
 */
public abstract class StreamTask implements Runnable {
    private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);

    /** Id of this task; combined with peer task ids to derive channel ids. */
    protected int taskId;
    /** Operator logic; opened during construction with this task's output collectors. */
    protected Processor processor;
    /** Hosting worker; supplies the config, execution graph, execution node and execution task. */
    protected JobWorker worker;
    /** Reader over all input channels, or {@code null} when no upstream task was found. */
    protected DataReader reader;
    /** One writer per output edge that has at least one downstream task. */
    private Map<ExecutionEdge, DataWriter> writers;
    /** Daemon thread that executes {@link #run()} (wrapped by {@code Ray.wrapRunnable}). */
    private Thread thread;

    /**
     * Creates the task, sets up its channels and opens the processor.
     *
     * <p>NOTE: channel setup (including registration of a shutdown hook that captures
     * {@code this}) happens before any subclass constructor body runs, so subclass fields are
     * not yet initialized at that point.
     *
     * @param taskId    id of this task in the execution graph
     * @param processor operator logic to run
     * @param worker    worker hosting this task
     */
    public StreamTask(int taskId, Processor processor, JobWorker worker) {
        this.taskId = taskId;
        this.processor = processor;
        this.worker = worker;
        prepareTask();
        this.thread = new Thread(Ray.wrapRunnable(this),
                getClass().getName() + "-" + System.currentTimeMillis());
        this.thread.setDaemon(true);
    }

    /**
     * Sets up the task in this order:
     * <ol>
     *   <li>build the channel config;</li>
     *   <li>create the writers, then the reader;</li>
     *   <li>open the processor;</li>
     *   <li>register the shutdown hook.</li>
     * </ol>
     */
    private void prepareTask() {
        Map<String, String> channelConf = buildChannelConf();
        ExecutionGraph executionGraph = worker.getExecutionGraph();
        ExecutionNode executionNode = worker.getExecutionNode();

        List<Collector> collectors = createWriters(executionGraph, executionNode, channelConf);
        createReader(executionGraph, executionNode, channelConf);

        RayRuntimeContext runtimeContext = new RayRuntimeContext(
                worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism());
        processor.open(collectors, runtimeContext);
        registerShutdownHook();
    }

    /**
     * Builds the channel configuration from the worker config.
     *
     * <p>Every worker config entry is copied with its value passed through
     * {@code String.valueOf}. The following keys are then set:
     * <ul>
     *   <li>"channel_size", defaulting to {@code Config.CHANNEL_SIZE_DEFAULT};</li>
     *   <li>"streaming.task_job_id", set to the current Ray job id;</li>
     *   <li>"channel_type", defaulting to "memory_channel".</li>
     * </ul>
     *
     * @return a new, mutable configuration map
     */
    private Map<String, String> buildChannelConf() {
        Map<String, String> workerConf = worker.getConfig();
        Map<String, String> conf = new HashMap<>();
        workerConf.forEach((key, value) -> conf.put(key, String.valueOf(value)));
        conf.put("channel_size", workerConf.getOrDefault("channel_size", Config.CHANNEL_SIZE_DEFAULT));
        conf.put("streaming.task_job_id", Ray.getRuntimeContext().getCurrentJobId().toString());
        conf.put("channel_type", workerConf.getOrDefault("channel_type", "memory_channel"));
        return conf;
    }

    /**
     * Creates one {@link DataWriter} per output edge that has at least one downstream task.
     * Each writer is stored in {@link #writers} keyed by its edge, and an {@link OutputCollector}
     * is built for it.
     *
     * @return the collectors, one per created writer, in output-edge order
     */
    private List<Collector> createWriters(ExecutionGraph executionGraph, ExecutionNode executionNode,
                                          Map<String, String> channelConf) {
        writers = new HashMap<>();
        List<Collector> collectors = new ArrayList<>();
        for (ExecutionEdge edge : executionNode.getOutputEdges()) {
            // Channel id -> actor id of every task of the edge's target node.
            Map<String, ActorId> targetActorIds = new HashMap<>();
            Map<Integer, RayActor<JobWorker>> taskId2Worker =
                    executionGraph.getTaskId2WorkerByNodeId(edge.getTargetNodeId());
            taskId2Worker.forEach((targetTaskId, targetActor) -> targetActorIds.put(
                    ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime()),
                    targetActor.getId()));
            if (targetActorIds.isEmpty()) {
                continue;
            }
            // NOTE(review): channel order follows HashMap iteration over channel-id strings. If the
            // partition indexes into this list, confirm the order is consistent across upstream tasks.
            List<String> channelIds = new ArrayList<>(targetActorIds.size());
            List<ActorId> toActorIds = new ArrayList<>(targetActorIds.size());
            splitEntries(targetActorIds, channelIds, toActorIds);

            DataWriter writer = new DataWriter(channelIds, toActorIds, channelConf);
            LOG.info("Create DataWriter succeed.");
            writers.put(edge, writer);
            Partition partition = edge.getPartition();
            collectors.add(new OutputCollector(channelIds, writer, partition));
        }
        return collectors;
    }

    /**
     * Creates {@link #reader} over the channels from every task of every input edge's source node.
     * Leaves {@link #reader} untouched when no upstream task is found.
     */
    private void createReader(ExecutionGraph executionGraph, ExecutionNode executionNode,
                              Map<String, String> channelConf) {
        // Channel id -> actor id of every upstream task, merged across all input edges.
        Map<String, ActorId> sourceActorIds = new HashMap<>();
        for (ExecutionEdge edge : executionNode.getInputsEdges()) {
            Map<Integer, RayActor<JobWorker>> taskId2Worker =
                    executionGraph.getTaskId2WorkerByNodeId(edge.getSrcNodeId());
            taskId2Worker.forEach((srcTaskId, srcActor) -> sourceActorIds.put(
                    ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime()),
                    srcActor.getId()));
        }
        if (sourceActorIds.isEmpty()) {
            return;
        }
        List<String> channelIds = new ArrayList<>(sourceActorIds.size());
        List<ActorId> fromActorIds = new ArrayList<>(sourceActorIds.size());
        splitEntries(sourceActorIds, channelIds, fromActorIds);
        LOG.info("Register queue consumer, queues {}.", channelIds);
        reader = new DataReader(channelIds, fromActorIds, channelConf);
    }

    /**
     * Appends each entry's key to {@code keys} and its value to {@code values}.
     * Entries are taken in the source map's iteration order, so the two lists stay index-aligned.
     */
    private static void splitEntries(Map<String, ActorId> source, List<String> keys,
                                     List<ActorId> values) {
        for (Map.Entry<String, ActorId> entry : source.entrySet()) {
            keys.add(entry.getKey());
            values.add(entry.getValue());
        }
    }

    /**
     * Registers a JVM shutdown hook that calls {@link #cancelTask()}. Failures are logged rather
     * than propagated, because no caller can handle them during JVM shutdown.
     */
    private void registerShutdownHook() {
        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
            try {
                cancelTask();
            } catch (Exception e) {
                LOG.error("Failed to cancel task {} during JVM shutdown.", taskId, e);
            }
        }));
    }

    /**
     * Subclass hook for task-specific initialization. This class never calls it;
     * NOTE(review): presumably invoked from the subclass's {@code run()} — confirm.
     *
     * @throws Exception on initialization failure
     */
    protected abstract void init() throws Exception;

    /**
     * Stops the task. Called from the JVM shutdown hook registered during construction, which
     * logs any exception thrown.
     *
     * @throws Exception on cancellation failure
     */
    protected abstract void cancelTask() throws Exception;

    /** Starts the daemon thread that executes {@link #run()}, then logs the start. */
    public void start() {
        thread.start();
        LOG.info("started {}-{}", getClass().getSimpleName(), taskId);
    }
}

