package org.apache.sysml.udf.lib;

import java.io.IOException;
import java.util.Iterator;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;

/* loaded from: input_file:org/apache/sysml/udf/lib/CumSumProd.class */
/**
 * Package function computing a row-wise cumulative sum-product recurrence.
 *
 * <p>Function inputs (by position):
 * <ol>
 *   <li>{@code X} - the input matrix;</li>
 *   <li>{@code C} - the coefficient matrix, either of the same dimensions as {@code X} or a
 *       single column with the same number of rows (then broadcast across all columns);</li>
 *   <li>{@code start} - a scalar (parsed as a double) acting as the "previous row" value for
 *       the first row processed;</li>
 *   <li>{@code isReverse} - a scalar (parsed as a boolean); when true rows are processed from
 *       the last row to the first.</li>
 * </ol>
 *
 * <p>In forward mode the single output {@code Y} satisfies
 * {@code Y[0,j] = X[0,j] + C[0,j] * start} and {@code Y[i,j] = X[i,j] + C[i,j] * Y[i-1,j]}.
 * Reverse mode mirrors this, starting at the last row and using {@code Y[i+1,j]}.
 * When {@code C} is a column vector, {@code C[i,j]} denotes {@code C[i,0]}.
 */
public class CumSumProd extends PackageFunction {
    private static final long serialVersionUID = -7883258699548686065L;
    private Matrix ret;
    private MatrixBlock retMB;
    private MatrixBlock X;
    private MatrixBlock C;
    private double start;
    private boolean isReverse;
    int numRetRows;
    int numRetCols;
    /** Row-major values of the output block; rows are updated in place. */
    double[] denseBlock;

    @Override // org.apache.sysml.udf.PackageFunction
    public int getNumFunctionOutputs() {
        return 1;
    }

    @Override // org.apache.sysml.udf.PackageFunction
    public FunctionParameter getFunctionOutput(int i) {
        if (i == 0) {
            return this.ret;
        }
        throw new RuntimeException("CumSumProd produces only one output");
    }

    /**
     * Acquires the inputs, computes the recurrence and publishes the result as output 0.
     *
     * <p>The read locks on {@code X} and {@code C} are always released, even when validation
     * or computation fails.
     *
     * @throws RuntimeException if the inputs have incompatible dimensions, or wrapping any
     *     caching / I/O / runtime error raised while reading inputs or writing the output
     */
    @Override // org.apache.sysml.udf.PackageFunction
    public void execute() {
        Matrix xParam = (Matrix) getFunctionInput(0);
        Matrix cParam = (Matrix) getFunctionInput(1);
        try {
            this.X = xParam.getMatrixObject().acquireRead();
            try {
                this.C = cParam.getMatrixObject().acquireRead();
                try {
                    computeOutput();
                } finally {
                    // Release in reverse acquisition order, also on failure (avoids lock leaks).
                    cParam.getMatrixObject().release();
                }
            } finally {
                xParam.getMatrixObject().release();
            }
            this.retMB.recomputeNonZeros();
            try {
                this.retMB.examSparsity();
                this.ret.setMatrixDoubleArray(this.retMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
            } catch (IOException e) {
                throw new RuntimeException("Error while executing CumSumProd", e);
            } catch (DMLRuntimeException e) {
                throw new RuntimeException("Error while executing CumSumProd", e);
            }
        } catch (CacheException e) {
            throw new RuntimeException("Error while executing CumSumProd", e);
        }
    }

    /** Validates inputs, allocates the output and fills it with the recurrence result. */
    private void computeOutput() {
        if (this.X.getNumRows() != this.C.getNumRows()) {
            throw new RuntimeException("Number of rows of X and C should match");
        }
        if (this.X.getNumColumns() != this.C.getNumColumns() && this.C.getNumColumns() != 1) {
            throw new RuntimeException("Incorrect Number of columns of X and C (Expected C to be of same dimension or a vector)");
        }
        this.start = Double.parseDouble(((Scalar) getFunctionInput(2)).getValue());
        this.isReverse = Boolean.parseBoolean(((Scalar) getFunctionInput(3)).getValue());
        this.numRetRows = this.X.getNumRows();
        this.numRetCols = this.X.getNumColumns();
        allocateOutput();
        this.denseBlock = this.retMB.getDenseBlockValues();
        if (this.numRetRows == 0 || this.numRetCols == 0) {
            // Nothing to compute; the (empty) output block is still published by execute().
            return;
        }
        copyXIntoOutput();
        if (this.isReverse) {
            addCNConstant(this.numRetRows - 1, this.start);
            for (int row = this.numRetRows - 2; row >= 0; row--) {
                addC(row, false);
            }
        } else {
            addCNConstant(0, this.start);
            for (int row = 1; row < this.numRetRows; row++) {
                addC(row, true);
            }
        }
    }

    /** Copies all values of X into the (zero-initialised) dense output block. */
    private void copyXIntoOutput() {
        if (this.X.isInSparseFormat()) {
            Iterator<IJV> cells = this.X.getSparseBlockIterator();
            while (cells.hasNext()) {
                IJV cell = cells.next();
                this.denseBlock[cell.getI() * this.numRetCols + cell.getJ()] = cell.getV();
            }
        } else if (this.X.getDenseBlock() != null) {
            System.arraycopy(this.X.getDenseBlockValues(), 0, this.denseBlock, 0, this.denseBlock.length);
        }
    }

    /** True when C is a single column that must be broadcast across all output columns. */
    private boolean isCBroadcast() {
        return this.C.getNumColumns() != this.numRetCols;
    }

    /**
     * Adds {@code C[i,:] * d} to output row {@code i}; used for the first processed row.
     *
     * @param i output row to update
     * @param d scalar multiplier (the {@code start} value)
     */
    private void addCNConstant(int i, double d) {
        boolean broadcast = isCBroadcast();
        int rowStart = i * this.numRetCols;
        if (this.C.isInSparseFormat()) {
            Iterator<IJV> cells = this.C.getSparseBlockIterator(i, i + 1);
            while (cells.hasNext()) {
                IJV cell = cells.next();
                double scaled = cell.getV() * d;
                if (broadcast) {
                    for (int j = 0; j < this.numRetCols; j++) {
                        this.denseBlock[rowStart + j] += scaled;
                    }
                } else {
                    this.denseBlock[rowStart + cell.getJ()] += scaled;
                }
            }
            return;
        }
        double[] cValues = this.C.getDenseBlockValues();
        if (cValues == null) {
            return; // empty dense C contributes nothing
        }
        for (int j = 0; j < this.numRetCols; j++) {
            double c = broadcast ? cValues[i] : cValues[rowStart + j];
            this.denseBlock[rowStart + j] += c * d;
        }
    }

    /**
     * Adds {@code C[i,:] * Y[prev,:]} element-wise to output row {@code i}, where
     * {@code prev} is the neighbouring row that was already finalised.
     *
     * @param i output row to update
     * @param z true for forward mode ({@code prev = i - 1}), false for reverse ({@code prev = i + 1})
     */
    private void addC(int i, boolean z) {
        boolean broadcast = isCBroadcast();
        int rowStart = i * this.numRetCols;
        int prevRowStart = z ? rowStart - this.numRetCols : rowStart + this.numRetCols;
        if (this.C.isInSparseFormat()) {
            Iterator<IJV> cells = this.C.getSparseBlockIterator(i, i + 1);
            while (cells.hasNext()) {
                IJV cell = cells.next();
                double c = cell.getV();
                if (broadcast) {
                    // Each column j is scaled by the previous row's column j (not column 0).
                    for (int j = 0; j < this.numRetCols; j++) {
                        this.denseBlock[rowStart + j] += c * this.denseBlock[prevRowStart + j];
                    }
                } else {
                    int j = cell.getJ();
                    this.denseBlock[rowStart + j] += c * this.denseBlock[prevRowStart + j];
                }
            }
            return;
        }
        double[] cValues = this.C.getDenseBlockValues();
        if (cValues == null) {
            return; // empty dense C contributes nothing
        }
        for (int j = 0; j < this.numRetCols; j++) {
            double c = broadcast ? cValues[i] : cValues[rowStart + j];
            this.denseBlock[rowStart + j] += c * this.denseBlock[prevRowStart + j];
        }
    }

    /** Creates the output {@link Matrix} handle and its dense backing block. */
    private void allocateOutput() {
        this.ret = new Matrix(createOutputFilePathAndName("TMP"), this.numRetRows, this.numRetCols, Matrix.ValueType.Double);
        this.retMB = new MatrixBlock(this.numRetRows, this.numRetCols, false);
        this.retMB.allocateDenseBlock();
    }
}
