/*
 * Decompiled with CFR 0.152.
 */
package io.ray.serve.router;

import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorMethod;
import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Metrics;
import io.ray.runtime.metric.TagKey;
import io.ray.serve.exception.RayServeException;
import io.ray.serve.generated.ActorNameList;
import io.ray.serve.metrics.RayServeMetrics;
import io.ray.serve.replica.RayServeWrappedReplica;
import io.ray.serve.router.Query;
import io.ray.serve.util.CollectionUtil;
import io.ray.shaded.com.google.common.collect.ImmutableMap;
import io.ray.shaded.com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ReplicaSet {
    private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);

    /** Maximum number of polls while waiting for the first replica-list update. */
    private static final int MAX_WAIT_LOOPS = 50;

    /** Pause between polls while waiting for the first replica-list update, in microseconds. */
    private static final long WAIT_INTERVAL_MICROS = 20L;

    /** Replica actor name -> object refs of queries tracked for that replica. */
    private final Map<String, Set<ObjectRef<Object>>> inFlightQueries;

    /** Replica actor name -> actor handle used to dispatch queries to it. */
    private final Map<String, BaseActorHandle> allActorHandles;

    /** Number of queries currently inside {@link #assignReplica(Query)}. */
    private final AtomicInteger numQueuedQueries = new AtomicInteger();

    /** Gauge reporting {@link #numQueuedQueries}; registered through RayServeMetrics. */
    private Gauge numQueuedQueriesGauge;

    /**
     * Set once the replica list has been pulled at least once.
     *
     * <p>Written under the monitor in {@link #updateWorkerReplicas(Object)} but read without it
     * while waiting in {@link #tryAssignReplica(Query)}; {@code volatile} guarantees the reader
     * observes the write.
     */
    private volatile boolean hasPullReplica = false;

    /**
     * Creates an empty replica set and registers the queued-queries gauge.
     *
     * @param deploymentName deployment name used as the gauge's "deployment" tag
     */
    public ReplicaSet(String deploymentName) {
        this.inFlightQueries = new ConcurrentHashMap<String, Set<ObjectRef<Object>>>();
        this.allActorHandles = new ConcurrentHashMap<String, BaseActorHandle>();
        RayServeMetrics.execute(() -> {
            this.numQueuedQueriesGauge =
                    (Gauge) ((Metrics.GaugeBuilder) ((Metrics.GaugeBuilder) ((Metrics.GaugeBuilder)
                            ((Metrics.GaugeBuilder) Metrics.gauge()
                                    .name(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getName()))
                                    .description(RayServeMetrics.SERVE_DEPLOYMENT_QUEUED_QUERIES.getDescription()))
                            .unit(""))
                            .tags(ImmutableMap.of("deployment", deploymentName)))
                            .register();
        });
    }

    /**
     * Synchronizes the known replicas with the given actor-name list.
     *
     * <p>Names not yet known are looked up in the "serve" namespace and added; names no longer
     * listed are dropped. Whether or not {@code actorSet} is {@code null}, this marks the replica
     * list as pulled so waiting callers stop polling.
     *
     * @param actorSet an {@code ActorNameList}, or {@code null} to only mark the list as pulled
     */
    public synchronized void updateWorkerReplicas(Object actorSet) {
        if (null != actorSet) {
            Set<String> actorNameSet = new HashSet<>(((ActorNameList) actorSet).getNamesList());
            Set<String> added = new HashSet<>(Sets.difference(actorNameSet, this.inFlightQueries.keySet()));
            Set<String> removed = new HashSet<>(Sets.difference(this.inFlightQueries.keySet(), actorNameSet));
            for (String name : added) {
                Optional<BaseActorHandle> handleOptional = Ray.getActor(name, "serve");
                if (handleOptional.isPresent()) {
                    this.allActorHandles.put(name, handleOptional.get());
                    this.inFlightQueries.put(name, Sets.newConcurrentHashSet());
                } else {
                    // Not tracked in inFlightQueries, so it is retried on the next update.
                    LOGGER.warn("Can not get actor handle. actor name is {}", name);
                }
            }
            for (String name : removed) {
                this.inFlightQueries.remove(name);
                this.allActorHandles.remove(name);
            }
            if (!added.isEmpty() || !removed.isEmpty()) {
                LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
            }
        }
        this.hasPullReplica = true;
    }

    /**
     * Sends the query to a randomly chosen replica.
     *
     * <p>The queued-queries gauge is incremented for the duration of the call and is always
     * decremented again, including when no replica is available.
     *
     * @param query the query to dispatch
     * @return the object ref of the replica's result
     * @throws RayServeException if no replica is known
     */
    public ObjectRef<Object> assignReplica(Query query) {
        String endpoint = query.getMetadata().getEndpoint();
        this.numQueuedQueries.incrementAndGet();
        this.reportQueuedQueries(endpoint);
        try {
            return this.tryAssignReplica(query);
        } finally {
            // Undo the increment even when tryAssignReplica throws; otherwise the gauge drifts up.
            this.numQueuedQueries.decrementAndGet();
            this.reportQueuedQueries(endpoint);
        }
    }

    /** Publishes the current queued-query count, tagged with the endpoint. */
    private void reportQueuedQueries(String endpoint) {
        RayServeMetrics.execute(() -> this.numQueuedQueriesGauge.update(
                this.numQueuedQueries.get(), ImmutableMap.of(new TagKey("endpoint"), endpoint)));
    }

    /** Picks a random replica and submits the query to it (Python or Java replica). */
    private ObjectRef<Object> tryAssignReplica(Query query) {
        this.awaitFirstReplicaPull();
        ArrayList<BaseActorHandle> handles = new ArrayList<BaseActorHandle>(this.allActorHandles.values());
        if (CollectionUtil.isEmpty(handles)) {
            throw new RayServeException("ReplicaSet found no replica.");
        }
        int randomIndex = RandomUtils.nextInt(0, handles.size());
        BaseActorHandle replica = handles.get(randomIndex);
        LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica);
        if (replica instanceof PyActorHandle) {
            return ((PyActorHandle) replica)
                    .task(PyActorMethod.of("handle_request_from_java"),
                            query.getMetadata().toByteArray(), query.getArgs())
                    .remote();
        }
        // Non-Python handles are assumed to wrap RayServeWrappedReplica; a typed handle is needed
        // for the unbound method reference to resolve.
        @SuppressWarnings("unchecked")
        ActorHandle<RayServeWrappedReplica> javaReplica = (ActorHandle<RayServeWrappedReplica>) replica;
        return javaReplica
                .task(RayServeWrappedReplica::handleRequest,
                        query.getMetadata().toByteArray(), query.getArgs())
                .remote();
    }

    /**
     * Briefly polls (bounded by MAX_WAIT_LOOPS) until the replica list has been pulled once.
     * On interruption, the interrupt flag is restored and waiting stops early.
     */
    private void awaitFirstReplicaPull() {
        for (int loopCount = 0; !this.hasPullReplica && loopCount < MAX_WAIT_LOOPS; ++loopCount) {
            try {
                TimeUnit.MICROSECONDS.sleep(WAIT_INTERVAL_MICROS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                LOGGER.warn("Interrupted while waiting for the replica list to be pulled.", e);
                return;
            }
        }
    }

    /**
     * Returns the live (mutable, concurrent) map of replica name to tracked query refs.
     *
     * @return the internal in-flight query map
     */
    public Map<String, Set<ObjectRef<Object>>> getInFlightQueries() {
        return this.inFlightQueries;
    }
}

