/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.index.full.NDIndexFullTake;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.JniUtils;
import java.util.Stack;

/**
 * PyTorch implementation of {@code NDArrayIndexer}.
 *
 * <p>Every operation converts its array arguments to {@code PtNDArray}s through the owning
 * {@code PtNDManager} and delegates the indexing work to native calls in {@code JniUtils}.
 */
public class PtNDArrayIndexer extends NDArrayIndexer {

    /** Manager used to convert inputs to PyTorch arrays and to create new arrays. */
    private final PtNDManager manager;

    /**
     * Creates an indexer bound to the given manager.
     *
     * @param manager the PyTorch manager used for conversions and newly created arrays
     */
    PtNDArrayIndexer(PtNDManager manager) {
        this.manager = manager;
    }

    /**
     * Picks elements of {@code array} along {@code fullPick.getAxis()} at the positions given
     * by {@code fullPick.getIndices()}, via {@code JniUtils.pick}.
     *
     * @param array the array to pick from
     * @param fullPick the pick index (indices array plus axis)
     * @return the picked values
     */
    public NDArray get(NDArray array, NDIndexFullPick fullPick) {
        return JniUtils.pick(
                this.manager.from(array),
                this.manager.from(fullPick.getIndices()),
                fullPick.getAxis());
    }

    /**
     * Gathers elements of {@code array} at the positions in {@code fullTake.getIndices()}, via
     * {@code JniUtils.take}.
     *
     * @param array the array to take from
     * @param fullTake the take index holding the indices array
     * @return the gathered values
     */
    public NDArray get(NDArray array, NDIndexFullTake fullTake) {
        return JniUtils.take(
                this.manager.from(array), this.manager.from(fullTake.getIndices()), this.manager);
    }

    /**
     * Slices {@code array} with the min/max/step bounds of {@code fullSlice}, then squeezes the
     * axes listed by {@code fullSlice.getToSqueeze()}.
     *
     * @param array the array to slice
     * @param fullSlice the fully resolved slice
     * @return the sliced and squeezed array
     */
    public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
        // The unsqueezed slice is only an intermediate and is closed once the squeezed result
        // exists. NOTE(review): this relies on squeeze() returning a new array, not `sliced`.
        try (PtNDArray sliced =
                JniUtils.index(
                        this.manager.from(array),
                        fullSlice.getMin(),
                        fullSlice.getMax(),
                        fullSlice.getStep(),
                        this.manager)) {
            return sliced.squeeze(fullSlice.getToSqueeze());
        }
    }

    /**
     * Applies a general (advanced) index to {@code array} via {@code JniUtils.indexAdv}.
     *
     * <p>An empty (rank 0) index on a scalar returns a copy of the scalar. This is
     * {@code duplicate()} when the scalar already belongs to this indexer's manager, otherwise a
     * new array created by that manager. An empty index on a non-scalar calls
     * {@code index.addAllDim()} before indexing. Note that this mutates the caller's
     * {@code NDIndex}.
     *
     * @param array the array to index; a non-PyTorch array is first copied into a PtNDArray
     * @param index the index to apply
     * @return the indexed result
     */
    public NDArray get(NDArray array, NDIndex index) {
        if (index.getRank() == 0) {
            if (array.getShape().isScalar()) {
                return array.getManager() == this.manager
                        ? array.duplicate()
                        : this.manager.create(
                                array.toByteBuffer(), array.getShape(), array.getDataType());
            }
            // Presumably expands the empty index to cover every dimension. Mutates `index`.
            index.addAllDim();
        }
        // NOTE(review): the null check only helps for rank > 0; the branch above already
        // dereferences `array` when the index is empty.
        if (array == null || array instanceof PtNDArray) {
            return JniUtils.indexAdv((PtNDArray) array, index, this.manager);
        }
        PtNDArray converted =
                this.manager.create(array.toByteBuffer(), array.getShape(), array.getDataType());
        return JniUtils.indexAdv(converted, index, this.manager);
    }

    /**
     * Writes {@code data} into the positions of {@code array} selected by {@code index}, via
     * {@code JniUtils.indexAdvPut}.
     *
     * @param array the array to write into
     * @param index the positions to write
     * @param data a {@code Number} or an {@code NDArray} holding the values to write
     * @throws IllegalArgumentException if {@code data} is neither a Number nor an NDArray
     */
    public void set(NDArray array, NDIndex index, Object data) {
        // NOTE(review): when `array` is not a PtNDArray, the write lands in this copy and is
        // never copied back into `array`. Confirm whether non-PyTorch arrays can reach here.
        PtNDArray ptArray =
                array instanceof PtNDArray
                        ? (PtNDArray) array
                        : this.manager.create(
                                array.toByteBuffer(), array.getShape(), array.getDataType());
        if (data instanceof Number) {
            // The scalar is only needed for the native put, so release it right away instead
            // of leaving it attached to the manager.
            try (PtNDArray scalar = (PtNDArray) this.manager.create((Number) data)) {
                JniUtils.indexAdvPut(ptArray, index, scalar);
            }
        } else if (data instanceof NDArray) {
            NDArray source = (NDArray) data;
            // Convert through the manager like the other methods do, instead of a hard cast
            // that fails for arrays from other engines.
            PtNDArray ptValue = this.manager.from(source);
            try {
                JniUtils.indexAdvPut(ptArray, index, ptValue);
            } finally {
                // Only close a conversion copy made here; never the caller's own array.
                if (ptValue != source) {
                    ptValue.close();
                }
            }
        } else {
            throw new IllegalArgumentException(
                    "The type of value to assign cannot be other than NDArray and Number.");
        }
    }

    /**
     * Writes {@code value} into the region of {@code array} described by {@code fullSlice}.
     *
     * <p>{@code value} is moved to the target device, reshaped to a trailing part of the slice
     * shape and broadcast to the full slice shape before {@code JniUtils.indexSet} is called.
     * {@code value} itself is never closed. Every intermediate created here is closed, including
     * when an exception is thrown.
     *
     * @param array the array to write into
     * @param fullSlice the fully resolved destination slice
     * @param value the values to write
     */
    public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
        Stack<NDArray> prepared = new Stack<NDArray>();
        prepared.push(value);
        try {
            prepared.push(prepared.peek().toDevice(array.getDevice(), false));
            // Drop leading dimensions of the slice shape until it holds no more elements than
            // `value`. For example, a (2, 10, 1) target with a (10) value reshapes the value to
            // (10, 1), which then broadcasts back to (2, 10, 1).
            Shape targetShape = fullSlice.getShape();
            while (targetShape.size() > value.size()) {
                targetShape = targetShape.slice(1);
            }
            prepared.push(prepared.peek().reshape(targetShape));
            prepared.push(prepared.peek().broadcast(fullSlice.getShape()));
            JniUtils.indexSet(
                    this.manager.from(array),
                    this.manager.from(prepared.peek()),
                    fullSlice.getMin(),
                    fullSlice.getMax(),
                    fullSlice.getStep());
        } finally {
            // Runs on failure too, so the intermediates no longer leak if reshape, broadcast
            // or indexSet throws. toDevice may return `value` itself, which is skipped here.
            for (NDArray temporary : prepared) {
                if (temporary != value) {
                    temporary.close();
                }
            }
        }
    }

    /**
     * Writes {@code value} into the positions of {@code array} where the boolean mask from
     * {@code indices} is set, via {@code JniUtils.booleanMaskSet}.
     *
     * @param array the array to write into
     * @param indices the boolean index holding the mask
     * @param value the values to write
     */
    public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
        // NOTE(review): this closes the array returned by indices.getIndex() afterwards. If that
        // is the caller's own mask array it cannot be reused; confirm this is intended.
        try (NDArray mask = indices.getIndex()) {
            JniUtils.booleanMaskSet(
                    this.manager.from(array), this.manager.from(value), this.manager.from(mask));
        }
    }

    /**
     * Fills the region of {@code array} described by {@code fullSlice} with {@code value}.
     *
     * @param array the array to write into
     * @param fullSlice the fully resolved destination slice
     * @param value the scalar to write
     */
    public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
        // The wrapping array is a temporary. The NDArray overload never closes the value it is
        // given, so close it here once the assignment is done instead of leaking it.
        try (NDArray wrapped = array.getManager().create(value)) {
            this.set(array, fullSlice, wrapped);
        }
    }
}

