/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.cluster;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Semaphore;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

public class MLSyncUpCron
implements Runnable {
    @Generated
    private static final Logger log = LogManager.getLogger(MLSyncUpCron.class);
    public static final int LOAD_MODEL_TASK_GRACE_TIME_IN_MS = 20000;
    private Client client;
    private ClusterService clusterService;
    private DiscoveryNodeHelper nodeHelper;
    private MLIndicesHandler mlIndicesHandler;
    @VisibleForTesting
    Semaphore updateModelStateSemaphore;

    public MLSyncUpCron(Client client, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler) {
        this.client = client;
        this.clusterService = clusterService;
        this.nodeHelper = nodeHelper;
        this.mlIndicesHandler = mlIndicesHandler;
        this.updateModelStateSemaphore = new Semaphore(1);
    }

    @Override
    public void run() {
        if (!this.clusterService.state().metadata().indices().containsKey((Object)".plugins-ml-model")) {
            return;
        }
        log.debug("ML sync job starts");
        DiscoveryNode[] allNodes = this.nodeHelper.getAllNodes();
        MLSyncUpInput gatherInfoInput = MLSyncUpInput.builder().getLoadedModels(true).build();
        MLSyncUpNodesRequest gatherInfoRequest = new MLSyncUpNodesRequest(allNodes, gatherInfoInput);
        this.client.execute((ActionType)MLSyncUpAction.INSTANCE, (ActionRequest)gatherInfoRequest, ActionListener.wrap(r -> {
            List responses = r.getNodes();
            HashMap<String, Set> modelWorkerNodes = new HashMap<String, Set>();
            HashMap<String, Set> runningLoadModelTasks = new HashMap<String, Set>();
            HashMap<String, Set> loadingModels = new HashMap<String, Set>();
            for (MLSyncUpNodeResponse mLSyncUpNodeResponse : responses) {
                String[] runningLoadModelTaskIds;
                String[] runningModelIds;
                String nodeId = mLSyncUpNodeResponse.getNode().getId();
                String[] loadedModelIds = mLSyncUpNodeResponse.getLoadedModelIds();
                if (loadedModelIds != null && loadedModelIds.length > 0) {
                    for (String modelId : loadedModelIds) {
                        Set workerNodes = modelWorkerNodes.computeIfAbsent(modelId, it -> new HashSet());
                        workerNodes.add(nodeId);
                    }
                }
                if ((runningModelIds = mLSyncUpNodeResponse.getRunningLoadModelIds()) != null && runningModelIds.length > 0) {
                    for (String modelId : runningModelIds) {
                        Set workerNodes = loadingModels.computeIfAbsent(modelId, it -> new HashSet());
                        workerNodes.add(nodeId);
                    }
                }
                if ((runningLoadModelTaskIds = mLSyncUpNodeResponse.getRunningLoadModelTaskIds()) == null || runningLoadModelTaskIds.length <= 0) continue;
                for (String taskId : runningLoadModelTaskIds) {
                    Set workerNodes = runningLoadModelTasks.computeIfAbsent(taskId, it -> new HashSet());
                    workerNodes.add(nodeId);
                }
            }
            for (Map.Entry entry : modelWorkerNodes.entrySet()) {
                String modelId = (String)entry.getKey();
                log.debug("will sync model worker nodes for model: {}: {}", (Object)modelId, (Object)((Set)entry.getValue()).toArray(new String[0]));
            }
            for (Map.Entry entry : runningLoadModelTasks.entrySet()) {
                log.debug("will sync running task: {}: {}", entry.getKey(), (Object)((Set)entry.getValue()).toArray(new String[0]));
            }
            MLSyncUpInput.MLSyncUpInputBuilder inputBuilder = MLSyncUpInput.builder().syncRunningLoadModelTasks(true).runningLoadModelTasks(runningLoadModelTasks);
            if (modelWorkerNodes.size() == 0) {
                log.debug("No loaded model found. Will clear model routing on all nodes");
                inputBuilder.clearRoutingTable(true);
            } else {
                inputBuilder.modelRoutingTable(modelWorkerNodes);
            }
            MLSyncUpInput mLSyncUpInput = inputBuilder.build();
            MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, mLSyncUpInput);
            this.client.execute((ActionType)MLSyncUpAction.INSTANCE, (ActionRequest)syncUpRequest, ActionListener.wrap(re -> log.debug("sync model routing job finished"), ex -> log.error("Failed to sync model routing", (Throwable)ex)));
            this.mlIndicesHandler.initModelIndexIfAbsent((ActionListener<Boolean>)ActionListener.wrap(res -> this.refreshModelState(modelWorkerNodes, loadingModels), e -> log.error("Failed to init model index", (Throwable)e)));
        }, e -> log.error("Failed to sync model routing", (Throwable)e)));
    }

    @VisibleForTesting
    void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Set<String>> loadingModels) {
        if (!this.updateModelStateSemaphore.tryAcquire()) {
            return;
        }
        try {
            SearchRequest searchRequest = new SearchRequest(new String[]{".plugins-ml-model"});
            BoolQueryBuilder queryBuilder = new BoolQueryBuilder();
            queryBuilder.filter((QueryBuilder)new TermsQueryBuilder("model_state", Arrays.asList(MLModelState.LOADING.name(), MLModelState.PARTIALLY_LOADED.name(), MLModelState.LOADED.name(), MLModelState.LOAD_FAILED.name())));
            SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
            sourceBuilder.query((QueryBuilder)queryBuilder);
            sourceBuilder.size(10000);
            sourceBuilder.fetchSource(new String[]{"model_state", "planning_worker_node_count", "last_updated_time", "current_worker_node_count"}, null);
            searchRequest.source(sourceBuilder);
            this.client.search(searchRequest, ActionListener.wrap(res -> {
                SearchHit[] hits = res.getHits().getHits();
                HashMap<String, MLModelState> newModelStates = new HashMap<String, MLModelState>();
                for (SearchHit hit : hits) {
                    int currentWorkerNodeCountInIndex;
                    int planningWorkerNodeCount;
                    Long lastUpdateTime;
                    Map sourceAsMap;
                    MLModelState state;
                    String modelId = hit.getId();
                    MLModelState mlModelState = this.getNewModelState(loadingModels, modelWorkerNodes, modelId, state = MLModelState.from((String)((String)(sourceAsMap = hit.getSourceAsMap()).get("model_state"))), lastUpdateTime = sourceAsMap.containsKey("last_updated_time") ? (Long)sourceAsMap.get("last_updated_time") : null, planningWorkerNodeCount = sourceAsMap.containsKey("planning_worker_node_count") ? (Integer)sourceAsMap.get("planning_worker_node_count") : 0, currentWorkerNodeCountInIndex = sourceAsMap.containsKey("current_worker_node_count") ? (Integer)sourceAsMap.get("current_worker_node_count") : 0);
                    if (mlModelState == null) continue;
                    newModelStates.put(modelId, mlModelState);
                }
                this.bulkUpdateModelState(modelWorkerNodes, newModelStates);
            }, e -> {
                this.updateModelStateSemaphore.release();
                log.error("Failed to search models", (Throwable)e);
            }));
        }
        catch (Exception e2) {
            this.updateModelStateSemaphore.release();
            log.error("Failed to refresh model state", (Throwable)e2);
        }
    }

    private MLModelState getNewModelState(Map<String, Set<String>> loadingModels, Map<String, Set<String>> modelWorkerNodes, String modelId, MLModelState state, Long lastUpdateTime, int planningWorkerNodeCount, int currentWorkerNodeCountInIndex) {
        int currentWorkerNodeCount;
        Set<String> loadTaskNodes = loadingModels.get(modelId);
        if (loadTaskNodes != null && loadTaskNodes.size() > 0 && state != MLModelState.LOADING) {
            return MLModelState.LOADING;
        }
        int n = currentWorkerNodeCount = modelWorkerNodes.containsKey(modelId) ? modelWorkerNodes.get(modelId).size() : 0;
        if (currentWorkerNodeCount == 0 && state != MLModelState.LOAD_FAILED && (state != MLModelState.LOADING || lastUpdateTime == null || lastUpdateTime + 20000L <= Instant.now().toEpochMilli())) {
            return MLModelState.LOAD_FAILED;
        }
        if (currentWorkerNodeCount > 0) {
            if (currentWorkerNodeCount < planningWorkerNodeCount && (state != MLModelState.PARTIALLY_LOADED || currentWorkerNodeCountInIndex != currentWorkerNodeCount)) {
                return MLModelState.PARTIALLY_LOADED;
            }
            if (planningWorkerNodeCount > 0 && currentWorkerNodeCount >= planningWorkerNodeCount && state != MLModelState.LOADED) {
                if (currentWorkerNodeCount > planningWorkerNodeCount) {
                    log.warn("Model {} loaded on more nodes [{}] than planning worker node [{}]", (Object)modelId, (Object)currentWorkerNodeCount, (Object)planningWorkerNodeCount);
                }
                return MLModelState.LOADED;
            }
        }
        return null;
    }

    private void bulkUpdateModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, MLModelState> newModelStates) {
        if (newModelStates.size() > 0) {
            BulkRequest bulkUpdateRequest = new BulkRequest();
            for (String modelId : newModelStates.keySet()) {
                UpdateRequest updateRequest = new UpdateRequest();
                Instant now = Instant.now();
                ImmutableMap.Builder builder = ImmutableMap.builder();
                builder.put((Object)"model_state", (Object)newModelStates.get(modelId).name()).put((Object)"last_updated_time", (Object)now.toEpochMilli());
                Set<String> workerNodes = modelWorkerNodes.get(modelId);
                int currentWorkNodeCount = workerNodes == null ? 0 : workerNodes.size();
                builder.put((Object)"current_worker_node_count", (Object)currentWorkNodeCount);
                ((UpdateRequest)updateRequest.index(".plugins-ml-model")).id(modelId).doc((Map)builder.build());
                bulkUpdateRequest.add(updateRequest);
            }
            log.info("Refresh model state: {}", newModelStates);
            this.client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> {
                this.updateModelStateSemaphore.release();
                log.debug("Refresh model state successfully");
            }, e -> {
                this.updateModelStateSemaphore.release();
                log.error("Failed to bulk update model state", (Throwable)e);
            }));
        } else {
            this.updateModelStateSemaphore.release();
        }
    }
}

