/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import javax.naming.LimitExceededException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.action.stats.MLStatsNodeResponse;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.stats.MLNodeLevelStat;

/**
 * Chooses the node on which an ML task should run.
 *
 * <p>Two dispatch policies are supported and selected through the
 * {@code ML_COMMONS_TASK_DISPATCH_POLICY} setting:
 * <ul>
 *   <li>{@code round_robin} - cycles through the candidate nodes using a shared cursor;</li>
 *   <li>{@code least_load} - queries node level stats and picks the node with the fewest
 *       executing tasks, breaking ties by the lowest JVM heap usage.</li>
 * </ul>
 *
 * <p>Both the policy and the per-node task limit are refreshed through cluster settings
 * update consumers registered in the constructor.
 */
public class MLTaskDispatcher {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskDispatcher.class);

    /** Nodes reporting a JVM heap usage stat at or above this value are not eligible for new tasks. */
    private static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = (short)85;
    /** Value of the dispatch policy setting that selects round-robin dispatching. */
    private static final String ROUND_ROBIN = "round_robin";
    /** Value of the dispatch policy setting that selects least-load dispatching. */
    private static final String LEAST_LOAD = "least_load";
    /** Error raised when the node helper reports no eligible node at all. */
    private static final String NO_ELIGIBLE_NODE_MESSAGE = "No eligible node found to execute this request. It's best practice to provision ML nodes to serve your models. You can disable this setting to serve the model on your data node for development purposes by disabling the \"plugins.ml_commons.only_run_on_ml_node\" configuration using the _cluster/setting api";

    private final ClusterService clusterService;
    private final Client client;
    /** Shared round-robin cursor: index of the next node to hand out. */
    private final AtomicInteger nextNode;
    private volatile Integer maxMLBatchTaskPerNode;
    private volatile String dispatchPolicy;
    private final DiscoveryNodeHelper nodeHelper;

    /**
     * Creates a dispatcher and subscribes it to dynamic updates of the dispatch policy
     * and the per-node task limit.
     *
     * @param clusterService cluster service whose cluster settings deliver setting updates
     * @param client client used to execute the ML stats action
     * @param settings settings from which the initial policy and task limit are read
     * @param nodeHelper helper used to resolve eligible nodes and node ids
     */
    public MLTaskDispatcher(ClusterService clusterService, Client client, Settings settings, DiscoveryNodeHelper nodeHelper) {
        this.clusterService = clusterService;
        this.client = client;
        this.nodeHelper = nodeHelper;
        this.maxMLBatchTaskPerNode = MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE.get(settings);
        this.nextNode = new AtomicInteger(0);
        this.dispatchPolicy = MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY, it -> this.dispatchPolicy = it);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE, it -> this.maxMLBatchTaskPerNode = it);
    }

    /**
     * Selects a node among all eligible nodes according to the configured policy.
     *
     * @param actionListener receives the selected node, or the failure reported while selecting it
     * @throws IllegalArgumentException if the configured policy is not recognised
     * @throws MLResourceNotFoundException if there is no eligible node
     */
    public void dispatch(ActionListener<DiscoveryNode> actionListener) {
        // Read the volatile field once so a concurrent settings update cannot make
        // both comparisons fail for a policy value that is actually valid.
        final String policy = this.dispatchPolicy;
        if (ROUND_ROBIN.equals(policy)) {
            this.dispatchTaskWithRoundRobin(actionListener);
        } else if (LEAST_LOAD.equals(policy)) {
            this.dispatchTaskWithLeastLoad(actionListener);
        } else {
            throw new IllegalArgumentException("Unknown policy");
        }
    }

    /**
     * Selects a node among the given node ids according to the configured policy.
     *
     * @param nodeIds ids of the candidate nodes; must not be {@code null} or empty
     * @param actionListener receives the selected node, or the failure reported while selecting it
     * @throws IllegalArgumentException if {@code nodeIds} is {@code null} or empty, or if the
     *         configured policy is not recognised
     */
    public void dispatchPredictTask(String[] nodeIds, ActionListener<DiscoveryNode> actionListener) {
        if (nodeIds == null || nodeIds.length == 0) {
            throw new IllegalArgumentException("Model not loaded yet");
        }
        final String policy = this.dispatchPolicy;
        if (ROUND_ROBIN.equals(policy)) {
            // Round-robin works on the raw ids; resolve the chosen id to a node afterwards.
            this.dispatchTaskWithRoundRobin(
                nodeIds,
                ActionListener.wrap(nodeId -> actionListener.onResponse(this.nodeHelper.getNode(nodeId)), actionListener::onFailure)
            );
        } else if (LEAST_LOAD.equals(policy)) {
            this.dispatchTaskWithLeastLoad(nodeIds, actionListener);
        } else {
            throw new IllegalArgumentException("Unknown policy");
        }
    }

    /**
     * Hands the next element of {@code nodes} to the listener, advancing the shared cursor.
     *
     * <p>The cursor is read and advanced in one atomic step. When it has run past the end
     * of {@code nodes} the first element is returned and the cursor restarts at 1.
     *
     * @param nodes candidates to cycle through; callers guarantee it is non-empty
     * @param listener receives the chosen element
     */
    private <T> void dispatchTaskWithRoundRobin(T[] nodes, ActionListener<T> listener) {
        final int nodeCount = nodes.length;
        int index = this.nextNode.getAndUpdate(current -> current >= nodeCount ? 1 : current + 1);
        if (index >= nodeCount) {
            index = 0;
        }
        listener.onResponse(nodes[index]);
    }

    /**
     * Least-load dispatch restricted to the nodes identified by {@code nodeIds}.
     *
     * @param nodeIds ids of the candidate nodes
     * @param listener receives the selected node or the failure
     */
    private void dispatchTaskWithLeastLoad(String[] nodeIds, ActionListener<DiscoveryNode> listener) {
        DiscoveryNode[] nodes = this.nodeHelper.getNodes(nodeIds);
        this.dispatchTaskWithLeastLoad(nodes, listener);
    }

    /**
     * Requests the executing-task count and JVM heap usage stats of {@code nodes} and picks
     * the least loaded one.
     *
     * <p>Nodes are first filtered by heap usage, then by the per-node task limit. If either
     * filter leaves no candidate the listener is failed with a {@link LimitExceededException}.
     *
     * @param nodes candidate nodes
     * @param listener receives the selected node or the failure
     */
    private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener<DiscoveryNode> listener) {
        MLStatsNodesRequest statsRequest = new MLStatsNodesRequest(nodes);
        statsRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE));
        this.client.execute(MLStatsNodesAction.INSTANCE, statsRequest, ActionListener.wrap(mlStatsResponse -> {
            // Step 1: drop nodes under JVM heap pressure.
            List<MLStatsNodeResponse> belowHeapThreshold = mlStatsResponse
                .getNodes()
                .stream()
                .filter(stat -> longStat(stat, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD)
                .collect(Collectors.toList());
            if (belowHeapThreshold.isEmpty()) {
                String errorMessage = "All nodes' memory usage exceeds limitation " + DEFAULT_JVM_HEAP_USAGE_THRESHOLD + ". No eligible node available to run ml jobs ";
                log.warn(errorMessage);
                listener.onFailure(new LimitExceededException(errorMessage));
                return;
            }

            // Step 2: drop nodes that already run the maximum number of tasks.
            final long taskLimit = this.maxMLBatchTaskPerNode.intValue();
            List<MLStatsNodeResponse> belowTaskLimit = belowHeapThreshold
                .stream()
                .filter(stat -> longStat(stat, MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT) < taskLimit)
                .collect(Collectors.toList());
            if (belowTaskLimit.isEmpty()) {
                String errorMessage = "All nodes' executing ML task count reach limitation.";
                log.warn(errorMessage);
                listener.onFailure(new LimitExceededException(errorMessage));
                return;
            }

            // Step 3: the minimum by (task count, heap usage) is the least loaded node.
            // The list is known to be non-empty here, so the fallback is purely defensive.
            DiscoveryNode target = belowTaskLimit
                .stream()
                .min(MLTaskDispatcher::compareLoad)
                .map(MLStatsNodeResponse::getNode)
                .orElseThrow(() -> new IllegalStateException("No candidate node left after filtering"));
            listener.onResponse(target);
        }, exception -> {
            log.error("Failed to get node's task stats", exception);
            listener.onFailure(exception);
        }));
    }

    /**
     * Least-load dispatch over every eligible node reported by the node helper.
     *
     * @param listener receives the selected node or the failure
     * @throws MLResourceNotFoundException if there is no eligible node
     */
    private void dispatchTaskWithLeastLoad(ActionListener<DiscoveryNode> listener) {
        DiscoveryNode[] eligibleNodes = this.nodeHelper.getEligibleNodes();
        // Mirror the round-robin path and fail fast with a descriptive error. Without this
        // guard an empty node set would presumably come back as an empty stats response and
        // be reported as the unrelated "memory usage exceeds limitation" failure.
        if (eligibleNodes == null || eligibleNodes.length == 0) {
            throw new MLResourceNotFoundException(NO_ELIGIBLE_NODE_MESSAGE);
        }
        this.dispatchTaskWithLeastLoad(eligibleNodes, listener);
    }

    /**
     * Round-robin dispatch over every eligible node reported by the node helper.
     *
     * @param listener receives the selected node
     * @throws MLResourceNotFoundException if there is no eligible node
     */
    private void dispatchTaskWithRoundRobin(ActionListener<DiscoveryNode> listener) {
        DiscoveryNode[] eligibleNodes = this.nodeHelper.getEligibleNodes();
        if (eligibleNodes == null || eligibleNodes.length == 0) {
            throw new MLResourceNotFoundException(NO_ELIGIBLE_NODE_MESSAGE);
        }
        this.dispatchTaskWithRoundRobin(eligibleNodes, listener);
    }

    /** Reads a node level stat that is reported as a {@link Long}. */
    private static long longStat(MLStatsNodeResponse response, MLNodeLevelStat stat) {
        return (Long)response.getNodeLevelStat(stat);
    }

    /** Orders stats responses by executing task count, then by JVM heap usage. */
    private static int compareLoad(MLStatsNodeResponse left, MLStatsNodeResponse right) {
        int byTaskCount = Long.compare(
            longStat(left, MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT),
            longStat(right, MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)
        );
        if (byTaskCount != 0) {
            return byTaskCount;
        }
        return Long.compare(
            longStat(left, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE),
            longStat(right, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)
        );
    }
}

