/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.model;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.Files;
import java.io.File;
import java.nio.file.Path;
import java.time.Instant;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentParserUtils;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.load.MLLoadModelAction;
import org.opensearch.ml.common.transport.load.MLLoadModelRequest;
import org.opensearch.ml.common.transport.upload.MLUploadInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.profile.MLModelProfile;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.threadpool.ThreadPool;

/**
 * Manages the lifecycle of ML models on this node: uploading model files into the model index as
 * base64-encoded chunks, loading models into the local model cache, and unloading them again.
 *
 * <p>Per-node limits (max loaded models, max concurrent upload/load tasks) are read from the node
 * settings at construction time and kept up to date through cluster-settings update consumers.
 */
public class MLModelManager {
    @Generated
    private static final Logger log = LogManager.getLogger(MLModelManager.class);

    /** Timeout in milliseconds used when updating ML task documents. */
    public static final int TIMEOUT_IN_MILLIS = 5000;
    /** Exclusive upper bound on an uploaded model file's size: 0x100000000 bytes (4 GiB). */
    public static final long MODEL_FILE_SIZE_LIMIT = 0x100000000L;
    /** Model states for which {@link #updateModel(String, Map, ActionListener)} retries on version conflicts. */
    public static final ImmutableSet<MLModelState> MODEL_DONE_STATES = ImmutableSet.of(
        MLModelState.TRAINED,
        MLModelState.UPLOADED,
        MLModelState.LOADED,
        MLModelState.PARTIALLY_LOADED,
        MLModelState.LOAD_FAILED,
        MLModelState.UNLOADED
    );

    /** Index holding both model meta documents and model chunk documents. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";
    /** Thread pool used to continue upload work after async responses arrive. */
    private static final String UPLOAD_THREAD_POOL = "opensearch_ml_upload";
    /** Thread pool used to continue load work after async responses arrive. */
    private static final String LOAD_THREAD_POOL = "opensearch_ml_load";

    private final Client client;
    private final ClusterService clusterService;
    private final ThreadPool threadPool;
    private final NamedXContentRegistry xContentRegistry;
    private final ModelHelper modelHelper;
    private final MLModelCacheHelper modelCacheHelper;
    private final MLStats mlStats;
    private final MLCircuitBreakerService mlCircuitBreakerService;
    private final MLIndicesHandler mlIndicesHandler;
    private final MLTaskManager mlTaskManager;
    private final MLEngine mlEngine;

    // Volatile because the settings-update consumers below may write them from another thread.
    private volatile Integer maxModelPerNode;
    private volatile Integer maxUploadTasksPerNode;
    private volatile Integer maxLoadTasksPerNode;

    public MLModelManager(ClusterService clusterService, Client client, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, ModelHelper modelHelper, Settings settings, MLStats mlStats, MLCircuitBreakerService mlCircuitBreakerService, MLIndicesHandler mlIndicesHandler, MLTaskManager mlTaskManager, MLModelCacheHelper modelCacheHelper, MLEngine mlEngine) {
        this.client = client;
        this.threadPool = threadPool;
        this.xContentRegistry = xContentRegistry;
        this.modelHelper = modelHelper;
        this.clusterService = clusterService;
        this.modelCacheHelper = modelCacheHelper;
        this.mlStats = mlStats;
        this.mlCircuitBreakerService = mlCircuitBreakerService;
        this.mlIndicesHandler = mlIndicesHandler;
        this.mlTaskManager = mlTaskManager;
        this.mlEngine = mlEngine;

        this.maxModelPerNode = MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE, it -> maxModelPerNode = it);

        this.maxUploadTasksPerNode = MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE, it -> maxUploadTasksPerNode = it);

        this.maxLoadTasksPerNode = MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE, it -> maxLoadTasksPerNode = it);
    }

    /**
     * Starts uploading a model for the given task: from {@code uploadInput.getUrl()} when set,
     * otherwise as a prebuilt model whose config is downloaded first.
     *
     * <p>The upload itself continues asynchronously; failures are recorded on the task via
     * {@link #handleException}.
     *
     * @param uploadInput upload request
     * @param mlTask the task tracking this upload
     */
    public void uploadMLModel(MLUploadInput uploadInput, MLTask mlTask) {
        mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();

        // Throws (outside the try) if the circuit breaker is open or the running-task limit is hit.
        checkAndAddRunningTask(mlTask, maxUploadTasksPerNode);
        try {
            mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
            mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.UPLOAD, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
            if (uploadInput.getUrl() != null) {
                uploadModelFromUrl(uploadInput, mlTask);
            } else {
                uploadPrebuiltModel(uploadInput, mlTask);
            }
        } catch (Exception e) {
            // handleException records the failure metric itself; do not count it twice here.
            handleException(uploadInput.getFunctionName(), mlTask.getTaskId(), e);
        } finally {
            // Balance the increment above: the gauge tracks tasks currently executing.
            mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
        }
    }

    /**
     * Creates the model meta document (state UPLOADING) and then downloads/splits/indexes the model
     * file. Request and executing-task metrics are recorded by {@link #uploadMLModel}, the only
     * entry point that leads here.
     */
    private void uploadModelFromUrl(MLUploadInput uploadInput, MLTask mlTask) {
        String taskId = mlTask.getTaskId();
        FunctionName functionName = mlTask.getFunctionName();
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            String modelName = uploadInput.getModelName();
            String version = uploadInput.getVersion();
            Instant now = Instant.now();
            mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> {
                MLModel mlModelMeta = MLModel.builder()
                    .name(modelName)
                    .algorithm(functionName)
                    .version(version)
                    .description(uploadInput.getDescription())
                    .modelFormat(uploadInput.getModelFormat())
                    .modelState(MLModelState.UPLOADING)
                    .modelConfig(uploadInput.getModelConfig())
                    .createdTime(now)
                    .lastUpdateTime(now)
                    .build();
                IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX);
                indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(indexModelMetaRequest, threadedActionListener(UPLOAD_THREAD_POOL, ActionListener.wrap(modelMetaRes -> {
                    // The generated meta-document id becomes the model id used by every chunk.
                    String modelId = modelMetaRes.getId();
                    mlTask.setModelId(modelId);
                    log.info("create new model meta doc {} for upload task {}", modelId, taskId);
                    uploadModel(uploadInput, taskId, functionName, modelName, version, modelId);
                }, e -> {
                    log.error("Failed to index model meta doc", e);
                    handleException(functionName, taskId, e);
                })));
            }, e -> {
                log.error("Failed to init model index", e);
                handleException(functionName, taskId, e);
            }));
        } catch (Exception e) {
            MLExceptionUtils.logException("Failed to upload model", e, log);
            handleException(functionName, taskId, e);
        }
    }

    /**
     * Downloads the model file, splits it into chunk files and indexes each chunk as its own
     * document ({@code <modelId>_<chunkNumber>}). Once every chunk is indexed, the meta document is
     * marked UPLOADED. Any failure deletes the model and its chunks and marks the task FAILED.
     */
    private void uploadModel(MLUploadInput uploadInput, String taskId, FunctionName functionName, String modelName, String version, String modelId) {
        modelHelper.downloadAndSplit(modelId, modelName, version, uploadInput.getUrl(), ActionListener.wrap(result -> {
            Long modelSizeInBytes = (Long) result.get("model_size_in_bytes");
            if (modelSizeInBytes >= MODEL_FILE_SIZE_LIMIT) {
                throw new MLException("Model file size exceeds the limit of 4GB: " + modelSizeInBytes);
            }
            @SuppressWarnings("unchecked")
            List<String> chunkFiles = (List<String>) result.get("chunk_files");
            String hashValue = (String) result.get("model_file_hash");
            // A single permit serializes chunk indexing: each chunk's listener releases it.
            Semaphore semaphore = new Semaphore(1);
            AtomicInteger uploaded = new AtomicInteger(0);
            AtomicBoolean failedToUploadChunk = new AtomicBoolean(false);
            for (String name : chunkFiles) {
                // NOTE(review): the tryAcquire result is ignored, so after 10s without a release
                // the next chunk is sent anyway (chunks may then be indexed concurrently).
                semaphore.tryAcquire(10L, TimeUnit.SECONDS);
                if (failedToUploadChunk.get()) {
                    // The failed chunk's listener already marked the task FAILED and deleted the
                    // model; throwing here would repeat that cleanup and overwrite the task error.
                    log.debug("Stop uploading chunks of model {} after a chunk failed to be indexed", modelId);
                    return;
                }
                File file = new File(name);
                byte[] bytes = Files.toByteArray(file);
                // Chunk files are named by their chunk number.
                int chunkNum = Integer.parseInt(file.getName());
                Instant now = Instant.now();
                MLModel mlModel = MLModel.builder()
                    .modelId(modelId)
                    .name(modelName)
                    .algorithm(functionName)
                    .version(version)
                    .modelFormat(uploadInput.getModelFormat())
                    .chunkNumber(chunkNum)
                    .totalChunks(chunkFiles.size())
                    .content(Base64.getEncoder().encodeToString(bytes))
                    .createdTime(now)
                    .lastUpdateTime(now)
                    .build();
                IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
                String chunkId = getModelChunkId(modelId, chunkNum);
                indexRequest.id(chunkId);
                indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(indexRequest, ActionListener.wrap(r -> {
                    if (uploaded.incrementAndGet() == chunkFiles.size()) {
                        // Last chunk: this deletes the whole upload directory, including this file.
                        updateModelUploadStateAsDone(uploadInput, taskId, modelId, modelSizeInBytes, chunkFiles, hashValue);
                    } else {
                        FileUtils.deleteFileQuietly(file);
                    }
                    semaphore.release();
                }, e -> {
                    log.error("Failed to index model chunk " + chunkId, e);
                    failedToUploadChunk.set(true);
                    handleException(functionName, taskId, e);
                    FileUtils.deleteFileQuietly(file);
                    deleteModel(modelId);
                    semaphore.release();
                    FileUtils.deleteFileQuietly(mlEngine.getUploadModelPath(modelId));
                }));
            }
        }, e -> {
            log.error("Failed to index chunk file", e);
            FileUtils.deleteFileQuietly(mlEngine.getUploadModelPath(modelId));
            deleteModel(modelId);
            handleException(functionName, taskId, e);
        }));
    }

    /** Downloads the prebuilt model's config, then uploads it like a URL-based model. */
    private void uploadPrebuiltModel(MLUploadInput uploadInput, MLTask mlTask) {
        String taskId = mlTask.getTaskId();
        modelHelper.downloadPrebuiltModelConfig(taskId, uploadInput, ActionListener.wrap(mlUploadInput -> uploadModelFromUrl(mlUploadInput, mlTask), e -> {
            log.error("Failed to upload prebuilt model", e);
            handleException(uploadInput.getFunctionName(), taskId, e);
        }));
    }

    /** Wraps {@code listener} so its callbacks run on the named thread pool (without forced execution). */
    private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolName, ActionListener<T> listener) {
        return new ThreadedActionListener<>(log, threadPool, threadPoolName, listener, false);
    }

    /**
     * Fails fast if the ML circuit breaker is open, then registers {@code mlTask} as running,
     * subject to {@code runningTaskLimit}. Exceptions from either check propagate to the caller.
     */
    public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
        MLNodeUtils.checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
        mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
    }

    /**
     * Called once every chunk is indexed: removes the local upload files, marks the model UPLOADED
     * with its chunk count/hash/size, completes the task and optionally triggers loading.
     */
    private void updateModelUploadStateAsDone(MLUploadInput uploadInput, String taskId, String modelId, Long modelSizeInBytes, List<String> chunkFiles, String hashValue) {
        FunctionName functionName = uploadInput.getFunctionName();
        FileUtils.deleteFileQuietly(mlEngine.getUploadModelPath(modelId));
        Map<String, Object> updatedFields = ImmutableMap.of(
            "model_state", MLModelState.UPLOADED,
            "last_uploaded_time", Instant.now().toEpochMilli(),
            "total_chunks", chunkFiles.size(),
            "model_content_hash_value", hashValue,
            "model_content_size_in_bytes", modelSizeInBytes
        );
        log.info("Model uploaded successfully, model id: {}, task id: {}", modelId, taskId);
        updateModel(modelId, updatedFields, ActionListener.wrap(updateResponse -> {
            mlTaskManager.updateMLTask(taskId, ImmutableMap.<String, Object>of("state", MLTaskState.COMPLETED, "model_id", modelId), TIMEOUT_IN_MILLIS, true);
            if (uploadInput.isLoadModel()) {
                loadModelAfterUploading(uploadInput, modelId);
            }
        }, e -> {
            log.error("Failed to update model", e);
            handleException(functionName, taskId, e);
            deleteModel(modelId);
        }));
    }

    /** Sends a load-model action for a freshly uploaded model; the outcome is only logged. */
    private void loadModelAfterUploading(MLUploadInput uploadInput, String modelId) {
        String[] modelNodeIds = uploadInput.getModelNodeIds();
        log.debug("start loading model after uploading {} on nodes: {}", modelId, Arrays.toString(modelNodeIds));
        MLLoadModelRequest request = new MLLoadModelRequest(modelId, modelNodeIds, false, true);
        client.execute(MLLoadModelAction.INSTANCE, request, ActionListener.wrap(
            r -> log.debug("model loaded, response {}", r),
            e -> log.error("Failed to load model", e)
        ));
    }

    /**
     * Deletes the model meta document and, by a delete-by-query on {@code model_id}, its chunk
     * documents. Both requests are fire-and-forget: their results are not awaited.
     */
    private void deleteModel(String modelId) {
        DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
        deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        client.delete(deleteRequest);

        DeleteByQueryRequest deleteChunksRequest = new DeleteByQueryRequest(ML_MODEL_INDEX);
        deleteChunksRequest.setQuery(new TermQueryBuilder("model_id", modelId));
        deleteChunksRequest.setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
        deleteChunksRequest.setAbortOnVersionConflict(false);
        client.execute(DeleteByQueryAction.INSTANCE, deleteChunksRequest);
    }

    /** Records an upload failure metric and marks the task FAILED with the root-cause message. */
    private void handleException(FunctionName functionName, String taskId, Exception e) {
        mlStats.createCounterStatIfAbsent(functionName, ActionName.UPLOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
        Map<String, Object> updated = ImmutableMap.of("error", MLExceptionUtils.getRootCauseMessage(e), "state", MLTaskState.FAILED);
        mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true);
    }

    /**
     * Loads a model into this node's cache.
     *
     * <p>If the model is already loaded, only its target worker nodes are refreshed. Otherwise the
     * model is read from the index; non-DL models are loaded directly, while DL models are first
     * reassembled from their chunks and (when {@code modelContentHash} is given) verified against
     * that hash.
     *
     * @param modelId id of the model meta document
     * @param modelContentHash expected hash of the reassembled model file, or {@code null} to skip the check
     * @param functionName algorithm of the model, used for metrics
     * @param mlTask load task; its worker nodes become the model's target worker nodes
     * @param listener receives {@code "successful"} on success or the failure; on failure the model
     *     is removed from the cache and local file cache
     */
    public void loadModel(String modelId, String modelContentHash, FunctionName functionName, MLTask mlTask, ActionListener<String> listener) {
        mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        List<String> workerNodes = mlTask.getWorkerNodes();
        if (modelCacheHelper.isModelLoaded(modelId)) {
            if (workerNodes != null && !workerNodes.isEmpty()) {
                log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId);
                modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes);
            }
            listener.onResponse("successful");
            return;
        }
        if (modelCacheHelper.getLoadedModels().length >= maxModelPerNode) {
            listener.onFailure(new IllegalArgumentException("Exceed max model per node limit"));
            return;
        }
        modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, workerNodes);
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            checkAndAddRunningTask(mlTask, maxLoadTasksPerNode);
            getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> {
                if (!FunctionName.isDLModel(mlModel.getAlgorithm())) {
                    Predictable predictable = mlEngine.load(mlModel, null);
                    addPredictorToCache(modelId, functionName, predictable, listener);
                    return;
                }
                MLNodeUtils.checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
                retrieveModelChunks(mlModel, ActionListener.wrap(modelZipFile -> {
                    String hash = FileUtils.calculateFileHash(modelZipFile);
                    if (modelContentHash != null && !modelContentHash.equals(hash)) {
                        log.error("Model content hash can't match original hash value");
                        handleLoadModelException(modelId, functionName, listener, new IllegalArgumentException("model content changed"));
                        return;
                    }
                    log.debug("Model content matches original hash value, continue loading");
                    Map<String, Object> params = ImmutableMap.of("model_zip_file", modelZipFile, "model_helper", modelHelper, "ml_engine", mlEngine);
                    Predictable predictable = mlEngine.load(mlModel, params);
                    addPredictorToCache(modelId, functionName, predictable, listener);
                }, e -> {
                    log.error("Failed to retrieve model " + modelId, e);
                    handleLoadModelException(modelId, functionName, listener, e);
                }));
            }, e -> {
                log.error("Failed to load model " + modelId, e);
                handleLoadModelException(modelId, functionName, listener, e);
            })));
        } catch (Exception e) {
            handleLoadModelException(modelId, functionName, listener, e);
        }
    }

    /**
     * Puts a loaded predictor into the cache and marks the model LOADED. If caching fails, the
     * predictor is closed and the failure is handled like any other load failure. The listener is
     * notified outside the try so a throwing success callback is not reported as a cache failure.
     */
    private void addPredictorToCache(String modelId, FunctionName functionName, Predictable predictable, ActionListener<String> listener) {
        try {
            modelCacheHelper.setPredictor(modelId, predictable);
            mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
            modelCacheHelper.setModelState(modelId, MLModelState.LOADED);
        } catch (Exception e) {
            log.error("Failed to add predictor to cache", e);
            if (predictable != null) {
                predictable.close();
            }
            handleLoadModelException(modelId, functionName, listener, e);
            return;
        }
        listener.onResponse("successful");
    }

    /** Records a load failure metric, removes the model from the cache/file cache and fails the listener. */
    private void handleLoadModelException(String modelId, FunctionName functionName, ActionListener<String> listener, Exception e) {
        mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
        removeModel(modelId);
        listener.onFailure(e);
    }

    /** Fetches and parses the model document with the given id, including its full source. */
    public void getModel(String modelId, ActionListener<MLModel> listener) {
        getModel(modelId, null, null, listener);
    }

    /**
     * Fetches and parses the model document with the given id.
     *
     * @param includes source fields to include, or {@code null}
     * @param excludes source fields to exclude, or {@code null}
     * @param listener receives the parsed model (with its id set), an
     *     {@link MLResourceNotFoundException} if the document does not exist, or the get/parse failure
     */
    public void getModel(String modelId, String[] includes, String[] excludes, ActionListener<MLModel> listener) {
        FetchSourceContext fetchContext = new FetchSourceContext(true, includes, excludes);
        GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, modelId).fetchSourceContext(fetchContext);
        client.get(getRequest, ActionListener.wrap(r -> {
            if (r == null || !r.isExists()) {
                listener.onFailure(new MLResourceNotFoundException("Fail to find model"));
                return;
            }
            try (XContentParser parser = MLNodeUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
                XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                MLModel mlModel = MLModel.parse(parser);
                mlModel.setModelId(modelId);
                listener.onResponse(mlModel);
            } catch (Exception e) {
                log.error("Failed to parse ml model " + r.getId(), e);
                listener.onFailure(e);
            }
        }, listener::onFailure));
    }

    /**
     * Fetches every chunk of a DL model, writes each to the local load directory and merges them
     * into the model zip file, which is passed to {@code listener}.
     *
     * @throws InterruptedException if interrupted while waiting for the previous chunk
     * @throws MLException if a chunk before the last one failed to be retrieved
     */
    private void retrieveModelChunks(MLModel mlModelMeta, ActionListener<File> listener) throws InterruptedException {
        String modelId = mlModelMeta.getModelId();
        String modelName = mlModelMeta.getName();
        Integer totalChunks = mlModelMeta.getTotalChunks();
        if (totalChunks == null || totalChunks <= 0) {
            // Without this check a chunk-less model would never notify the listener.
            listener.onFailure(new MLResourceNotFoundException("Fail to find model chunks of model " + modelId));
            return;
        }
        int total = totalChunks;
        // A single permit serializes chunk retrieval: each chunk's listener releases it.
        Semaphore semaphore = new Semaphore(1);
        AtomicBoolean stopNow = new AtomicBoolean(false);
        String modelZip = mlEngine.getLoadModelZipPath(modelId, modelName);
        ConcurrentLinkedDeque<File> chunkFiles = new ConcurrentLinkedDeque<>();
        AtomicInteger retrievedChunks = new AtomicInteger(0);
        for (int i = 0; i < total; i++) {
            // NOTE(review): the tryAcquire result is ignored; after 10s the next chunk is requested anyway.
            semaphore.tryAcquire(10L, TimeUnit.SECONDS);
            if (stopNow.get()) {
                throw new MLException("Failed to load model");
            }
            String modelChunkId = getModelChunkId(modelId, i);
            int currentChunk = i;
            getModel(modelChunkId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(model -> {
                Path chunkPath = mlEngine.getLoadModelChunkPath(modelId, currentChunk);
                FileUtils.write(Base64.getDecoder().decode(model.getContent()), chunkPath.toString());
                chunkFiles.add(new File(chunkPath.toUri()));
                if (retrievedChunks.incrementAndGet() == total) {
                    File modelZipFile = new File(modelZip);
                    FileUtils.mergeFiles(chunkFiles, modelZipFile);
                    listener.onResponse(modelZipFile);
                }
                semaphore.release();
            }, e -> {
                stopNow.set(true);
                semaphore.release();
                log.error("Failed to retrieve model chunk " + modelChunkId, e);
                // Earlier-chunk failures are reported by the loop throwing on its next iteration;
                // only a failure after all other chunks succeeded must notify the listener here.
                if (retrievedChunks.get() == total - 1) {
                    listener.onFailure(new MLResourceNotFoundException("Fail to find model chunk " + modelChunkId));
                }
            })));
        }
    }

    /** Updates fields of the model meta document, logging (not propagating) the outcome. */
    public void updateModel(String modelId, Map<String, Object> updatedFields) {
        updateModel(modelId, updatedFields, ActionListener.wrap(response -> {
            if (response.status() == RestStatus.OK) {
                log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId);
            } else {
                log.error("Failed to update ML model {}, status: {}", modelId, response.status());
            }
        }, e -> log.error("Failed to update ML model: " + modelId, e)));
    }

    /**
     * Partially updates the model meta document, always stamping {@code last_updated_time}.
     * Updates that move the model into one of {@link #MODEL_DONE_STATES} retry up to 3 times on
     * version conflicts. The caller's thread context is restored before {@code listener} runs.
     *
     * @param updatedFields fields to update; must be non-empty (it is not modified)
     * @param listener receives the update response, or an {@link IllegalArgumentException} when
     *     {@code updatedFields} is null or empty
     */
    public void updateModel(String modelId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) {
        if (updatedFields == null || updatedFields.isEmpty()) {
            listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
            return;
        }
        Map<String, Object> newUpdatedFields = new HashMap<>(updatedFields);
        newUpdatedFields.put("last_updated_time", Instant.now().toEpochMilli());
        UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId);
        updateRequest.doc(newUpdatedFields);
        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        if (newUpdatedFields.containsKey("model_state") && MODEL_DONE_STATES.contains(newUpdatedFields.get("model_state"))) {
            updateRequest.retryOnConflict(3);
        }
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /** Returns the document id of a model chunk: {@code <modelId>_<chunkNumber>}. */
    public String getModelChunkId(String modelId, Integer chunkNumber) {
        return modelId + "_" + chunkNumber;
    }

    /** Adds each of {@code nodeIds} as a worker node of the model; a null array is ignored. */
    public void addModelWorkerNode(String modelId, String... nodeIds) {
        if (nodeIds == null) {
            return;
        }
        for (String nodeId : nodeIds) {
            modelCacheHelper.addWorkerNode(modelId, nodeId);
        }
    }

    /** Removes each of {@code nodeIds} from the model's worker nodes; a null array is ignored. */
    public void removeModelWorkerNode(String modelId, String... nodeIds) {
        if (nodeIds == null) {
            return;
        }
        for (String nodeId : nodeIds) {
            modelCacheHelper.removeWorkerNode(modelId, nodeId);
        }
    }

    /** Removes the given nodes from the worker nodes of every cached model. */
    public void removeWorkerNodes(Set<String> removedNodes) {
        modelCacheHelper.removeWorkerNodes(removedNodes);
    }

    /**
     * Unloads the given models from this node, or every locally loaded model when
     * {@code modelIds} is null or empty.
     *
     * @return per-model status: {@code "unloaded"} if it was loaded, {@code "not_found"} otherwise
     */
    public synchronized Map<String, String> unloadModel(String[] modelIds) {
        Map<String, String> modelUnloadStatus = new HashMap<>();
        if (modelIds != null && modelIds.length > 0) {
            log.debug("unload models {}", Arrays.toString(modelIds));
            for (String modelId : modelIds) {
                if (modelCacheHelper.isModelLoaded(modelId)) {
                    modelUnloadStatus.put(modelId, "unloaded");
                    recordModelUnloaded(modelId);
                } else {
                    modelUnloadStatus.put(modelId, "not_found");
                }
                // Removed even when not loaded, to clear any partial cache/file state.
                removeModel(modelId);
            }
        } else {
            String[] loadedModelIds = getLocalLoadedModels();
            log.debug("unload all models {}", Arrays.toString(loadedModelIds));
            for (String modelId : loadedModelIds) {
                modelUnloadStatus.put(modelId, "unloaded");
                recordModelUnloaded(modelId);
                removeModel(modelId);
            }
        }
        return modelUnloadStatus;
    }

    /** Updates unload metrics; must run before {@link #removeModel} since it reads the cached function name. */
    private void recordModelUnloaded(String modelId) {
        mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement();
        mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNLOAD, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
    }

    /** Drops the model from the in-memory cache and deletes its local file cache. */
    private void removeModel(String modelId) {
        modelCacheHelper.removeModel(modelId);
        modelHelper.deleteFileCache(modelId);
    }

    /** Returns the worker node ids recorded for the model. */
    public String[] getWorkerNodes(String modelId) {
        return modelCacheHelper.getWorkerNodes(modelId);
    }

    /** Returns the cached predictor for the model. */
    public Predictable getPredictor(String modelId) {
        return modelCacheHelper.getPredictor(modelId);
    }

    /** Returns the ids of all models present in the cache. */
    public String[] getAllModelIds() {
        return modelCacheHelper.getAllModels();
    }

    /** Returns the ids of models loaded on this node. */
    public String[] getLocalLoadedModels() {
        return modelCacheHelper.getLoadedModels();
    }

    /** Replaces the cached worker-node routing with {@code modelWorkerNodes}. */
    public synchronized void syncModelWorkerNodes(Map<String, Set<String>> modelWorkerNodes) {
        modelCacheHelper.syncWorkerNodes(modelWorkerNodes);
    }

    /** Clears the worker nodes of all cached models. */
    public void clearRoutingTable() {
        modelCacheHelper.clearWorkerNodes();
    }

    /** Returns the cached profile of the model. */
    public MLModelProfile getModelProfile(String modelId) {
        return modelCacheHelper.getModelProfile(modelId);
    }

    /**
     * Runs {@code supplier} and records its wall-clock duration in milliseconds as an inference
     * duration of the model. If the supplier throws, no duration is recorded.
     */
    public <T> T trackPredictDuration(String modelId, Supplier<T> supplier) {
        long start = System.nanoTime();
        T t = supplier.get();
        long end = System.nanoTime();
        double durationInMs = (end - start) / 1_000_000.0;
        modelCacheHelper.addModelInferenceDuration(modelId, durationInMs);
        return t;
    }

    /** Returns the function name (algorithm) cached for the model. */
    public FunctionName getModelFunctionName(String modelId) {
        return modelCacheHelper.getFunctionName(modelId);
    }

    /** Returns whether the model cache reports the model as running on this node. */
    public boolean isModelRunningOnNode(String modelId) {
        return modelCacheHelper.isModelRunningOnNode(modelId);
    }
}

