/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.load;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardRequestType;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.load.LoadModelInput;
import org.opensearch.ml.common.transport.load.LoadModelNodeRequest;
import org.opensearch.ml.common.transport.load.LoadModelNodeResponse;
import org.opensearch.ml.common.transport.load.LoadModelNodesRequest;
import org.opensearch.ml.common.transport.load.LoadModelNodesResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/**
 * Node-level transport action for {@code cluster:admin/opensearch/ml/load_model_on_nodes}.
 *
 * <p>Each targeted node starts loading the requested model through {@link MLModelManager}. It returns
 * a {@link LoadModelNodeResponse} that only reports the model as {@code "received"}. The outcome of the
 * load is reported separately: when the load listener completes, this node sends a
 * {@code LOAD_MODEL_DONE} forward request to the coordinating node. On failure, that request carries
 * the root-cause message of the error.
 */
public class TransportLoadModelOnNodeAction
    extends TransportNodesAction<LoadModelNodesRequest, LoadModelNodesResponse, LoadModelNodeRequest, LoadModelNodeResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportLoadModelOnNodeAction.class);

    /** Internal action used to report load results back to the coordinating node. */
    private static final String FORWARD_ACTION_NAME = "cluster:admin/opensearch/mlinternal/forward";

    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    MLCircuitBreakerService mlCircuitBreakerService;
    MLStats mlStats;

    @Inject
    public TransportLoadModelOnNodeAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ModelHelper modelHelper,
        MLTaskManager mlTaskManager,
        MLModelManager mlModelManager,
        ClusterService clusterService,
        ThreadPool threadPool,
        Client client,
        NamedXContentRegistry xContentRegistry,
        MLCircuitBreakerService mlCircuitBreakerService,
        MLStats mlStats,
        Settings settings
    ) {
        super(
            "cluster:admin/opensearch/ml/load_model_on_nodes",
            threadPool,
            clusterService,
            transportService,
            actionFilters,
            LoadModelNodesRequest::new,
            LoadModelNodeRequest::new,
            "management",
            LoadModelNodeResponse.class
        );
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlCircuitBreakerService = mlCircuitBreakerService;
        this.mlStats = mlStats;
    }

    @Override
    protected LoadModelNodesResponse newResponse(
        LoadModelNodesRequest nodesRequest,
        List<LoadModelNodeResponse> responses,
        List<FailedNodeException> failures
    ) {
        return new LoadModelNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    @Override
    protected LoadModelNodeRequest newNodeRequest(LoadModelNodesRequest request) {
        return new LoadModelNodeRequest(request);
    }

    @Override
    protected LoadModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new LoadModelNodeResponse(in);
    }

    @Override
    protected LoadModelNodeResponse nodeOperation(LoadModelNodeRequest request) {
        return this.createLoadModelNodeResponse(request.getLoadModelNodesRequest());
    }

    /**
     * Starts loading the model described by {@code loadModelNodesRequest} on this node.
     *
     * <p>The returned response only states that the request was received. The load result is sent to
     * the coordinating node as a {@code LOAD_MODEL_DONE} forward request once the load listener
     * completes.
     *
     * @param loadModelNodesRequest the nodes-level request carrying the {@link LoadModelInput}
     * @return a response for the local node mapping the model id to {@code "received"}
     */
    private LoadModelNodeResponse createLoadModelNodeResponse(LoadModelNodesRequest loadModelNodesRequest) {
        LoadModelInput loadModelInput = loadModelNodesRequest.getLoadModelInput();
        String modelId = loadModelInput.getModelId();
        String taskId = loadModelInput.getTaskId();
        String coordinatingNodeId = loadModelInput.getCoordinatingNodeId();
        MLTask mlTask = loadModelInput.getMlTask();
        String modelContentHash = loadModelInput.getModelContentHash();

        HashMap<String, String> modelLoadStatus = new HashMap<>();
        modelLoadStatus.put(modelId, "received");

        String localNodeId = this.clusterService.localNode().getId();

        // Only logs whether the LOAD_MODEL_DONE report reached the coordinator; it does not affect the node response.
        ActionListener<MLForwardResponse> taskDoneListener = ActionListener.wrap(
            res -> log.info("load model task done " + taskId),
            ex -> MLExceptionUtils.logException("Load model task failed: " + taskId, ex, log)
        );

        ActionListener<String> loadListener = ActionListener.wrap(r -> {
            MLForwardInput mlForwardInput = MLForwardInput
                .builder()
                .requestType(MLForwardRequestType.LOAD_MODEL_DONE)
                .taskId(taskId)
                .modelId(modelId)
                .workerNodeId(this.clusterService.localNode().getId())
                .build();
            this.sendLoadModelDoneMessage(coordinatingNodeId, mlForwardInput, taskDoneListener);
        }, e -> {
            MLForwardInput mlForwardInput = MLForwardInput
                .builder()
                .requestType(MLForwardRequestType.LOAD_MODEL_DONE)
                .taskId(taskId)
                .modelId(modelId)
                .workerNodeId(this.clusterService.localNode().getId())
                .error(MLExceptionUtils.getRootCauseMessage(e))
                .build();
            this.sendLoadModelDoneMessage(coordinatingNodeId, mlForwardInput, taskDoneListener);
        });

        this.loadModel(modelId, modelContentHash, mlTask.getFunctionName(), localNodeId, coordinatingNodeId, mlTask, loadListener);
        return new LoadModelNodeResponse(this.clusterService.localNode(), modelLoadStatus);
    }

    /**
     * Sends a {@code LOAD_MODEL_DONE} forward request to the coordinating node.
     *
     * <p>If the coordinating node is not in the current cluster state, {@code listener} is failed.
     * Without this check, the transport layer would be handed a {@code null} node.
     *
     * @param coordinatingNodeId id of the node that should receive the report
     * @param mlForwardInput the already-built report payload
     * @param listener notified with the coordinator's response or the delivery failure
     */
    private void sendLoadModelDoneMessage(
        String coordinatingNodeId,
        MLForwardInput mlForwardInput,
        ActionListener<MLForwardResponse> listener
    ) {
        DiscoveryNode coordinatingNode = this.getNodeById(coordinatingNodeId);
        if (coordinatingNode == null) {
            listener.onFailure(new IllegalStateException("Coordinating node " + coordinatingNodeId + " not found in cluster state"));
            return;
        }
        MLForwardRequest loadModelDoneMessage = new MLForwardRequest(mlForwardInput);
        this.transportService
            .sendRequest(
                coordinatingNode,
                FORWARD_ACTION_NAME,
                loadModelDoneMessage,
                new ActionListenerResponseHandler<MLForwardResponse>(listener, MLForwardResponse::new)
            );
    }

    /**
     * Looks up a node in the current cluster state.
     *
     * @param nodeId the node id; may be {@code null}
     * @return the matching node, or {@code null} if {@code nodeId} is null or not in the cluster state
     */
    private DiscoveryNode getNodeById(String nodeId) {
        if (nodeId == null) {
            return null;
        }
        DiscoveryNodes nodes = this.clusterService.state().getNodes();
        return nodes.get(nodeId);
    }

    /**
     * Delegates the model load to {@link MLModelManager}.
     *
     * <p>When this node is not the coordinating node, the task is removed from the local
     * {@link MLTaskManager} before {@code listener} is notified. Presumably the coordinator owns the
     * task's lifecycle (TODO confirm). The same wrapped listener is used when
     * {@code mlModelManager.loadModel} throws synchronously, so the cleanup also runs on that path.
     *
     * @param modelId id of the model to load
     * @param modelContentHash expected hash of the model content, passed through to the manager
     * @param functionName the model's function name
     * @param localNodeId id of this node
     * @param coordinatingNodeId id of the node that dispatched the load
     * @param mlTask the load task
     * @param listener notified with the load result
     */
    private void loadModel(
        String modelId,
        String modelContentHash,
        FunctionName functionName,
        String localNodeId,
        String coordinatingNodeId,
        MLTask mlTask,
        ActionListener<String> listener
    ) {
        // localNodeId is never null (it comes from localNode().getId()), so this comparison cannot NPE.
        ActionListener<String> cleanupThenNotify = ActionListener.runBefore(listener, () -> {
            if (!localNodeId.equals(coordinatingNodeId)) {
                this.mlTaskManager.remove(mlTask.getTaskId());
            }
        });
        try {
            log.debug("start loading model {}", modelId);
            this.mlModelManager.loadModel(modelId, modelContentHash, functionName, mlTask, cleanupThenNotify);
        } catch (Exception e) {
            MLExceptionUtils.logException("Failed to load model " + modelId, e, log);
            cleanupThenNotify.onFailure(e);
        }
    }
}

