package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.util.CommonThreadPool;
import org.apache.sysml.runtime.util.ConvolutionUtils;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/data/LibMatrixDNN.class */
public class LibMatrixDNN {
    protected static final Log LOG = LogFactory.getLog(LibMatrixDNN.class.getName());

    // Invocation counters reported by appendStatistics() and cleared by resetStatistics().
    // The references are never reassigned (only get/set/addAndGet are used), hence final.
    private static final AtomicLong conv2dSparseCount = new AtomicLong(0);
    private static final AtomicLong conv2dDenseCount = new AtomicLong(0);
    private static final AtomicLong conv2dBwdFilterSparseCount = new AtomicLong(0);
    private static final AtomicLong conv2dBwdFilterDenseCount = new AtomicLong(0);
    private static final AtomicLong conv2dBwdDataSparseCount = new AtomicLong(0);
    private static final AtomicLong conv2dBwdDataDenseCount = new AtomicLong(0);
    // NOTE: the two im2col counters are reported and reset, but nothing in this class increments them.
    private static final AtomicLong im2colSparseCount = new AtomicLong(0);
    private static final AtomicLong im2colDenseCount = new AtomicLong(0);
    private static final AtomicLong maxPoolBwdSparseCount = new AtomicLong(0);
    private static final AtomicLong maxPoolBwdDenseCount = new AtomicLong(0);

    // Accumulated timers. appendStatistics() multiplies them by 1e-9 and labels the result "sec".
    // They are package-private and not updated in this class; presumably sibling classes in this
    // package add to them (TODO confirm), so they are deliberately left non-final here.
    static AtomicLong loopedConvMatMultTime = new AtomicLong(0);
    static AtomicLong loopedConvIm2ColTime = new AtomicLong(0);
    static AtomicLong loopedConvBwdFilterMatMultTime = new AtomicLong(0);
    static AtomicLong loopedConvBwdFilterIm2ColTime = new AtomicLong(0);
    static AtomicLong loopedConvBwdDataMatMultTime = new AtomicLong(0);
    static AtomicLong loopedConvBwdDataCol2ImTime = new AtomicLong(0);

    /* loaded from: input_file:org/apache/sysml/runtime/matrix/data/LibMatrixDNN$PoolingType.class */
    /**
     * Pooling variant selector passed to {@code LibMatrixDNN.pooling} and
     * {@code LibMatrixDNN.poolingBackward}, which forward it to the worker factories.
     */
    public enum PoolingType {
        // Variant for which poolingBackward additionally validates the shape of its first operand.
        MAX,
        // NOTE(review): presumably average pooling — confirm against LibMatrixDNNPooling.
        AVG
    }

    /**
     * Appends three report lines (dense counters, sparse counters, accumulated timers) to the
     * given builder. Does nothing unless {@code DMLScript.FINEGRAINED_STATISTICS} is enabled.
     *
     * @param sb builder that receives the report text
     */
    public static void appendStatistics(StringBuilder sb) {
        if (!DMLScript.FINEGRAINED_STATISTICS) {
            return;
        }

        sb.append("LibMatrixDNN dense count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dDenseCount.get()).append("/")
            .append(conv2dBwdFilterDenseCount.get()).append("/")
            .append(conv2dBwdDataDenseCount.get()).append("/")
            .append(im2colDenseCount.get()).append("/")
            .append(maxPoolBwdDenseCount.get())
            .append(".\n");

        sb.append("LibMatrixDNN sparse count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dSparseCount.get()).append("/")
            .append(conv2dBwdFilterSparseCount.get()).append("/")
            .append(conv2dBwdDataSparseCount.get()).append("/")
            .append(im2colSparseCount.get()).append("/")
            .append(maxPoolBwdSparseCount.get())
            .append(".\n");

        // Timers are listed in the order announced by the header text.
        final AtomicLong[] timers = {
            loopedConvIm2ColTime, loopedConvMatMultTime,
            loopedConvBwdFilterIm2ColTime, loopedConvBwdFilterMatMultTime,
            loopedConvBwdDataCol2ImTime, loopedConvBwdDataMatMultTime
        };
        sb.append("LibMatrixDNN conv(im2col/matmult), bwdF (im2col/matmult), bwdD (col2im/matmult) time:\t");
        for (int t = 0; t < timers.length; t++) {
            if (t > 0) {
                sb.append("/");
            }
            // Stored value is scaled by 1e-9 before being printed with three decimals.
            sb.append(String.format("%.3f", timers[t].get() * 1.0E-9d));
        }
        sb.append(" sec.\n");
    }

    /**
     * Sets every counter and timer reported by {@code appendStatistics} back to zero.
     * Each value is reset individually; the reset as a whole is not atomic.
     */
    public static void resetStatistics() {
        final AtomicLong[] all = {
            conv2dDenseCount, conv2dBwdFilterDenseCount, conv2dBwdDataDenseCount,
            im2colDenseCount, maxPoolBwdDenseCount,
            conv2dSparseCount, conv2dBwdFilterSparseCount, conv2dBwdDataSparseCount,
            im2colSparseCount, maxPoolBwdSparseCount,
            loopedConvIm2ColTime, loopedConvMatMultTime,
            loopedConvBwdFilterMatMultTime, loopedConvBwdFilterIm2ColTime,
            loopedConvBwdDataMatMultTime, loopedConvBwdDataCol2ImTime
        };
        for (AtomicLong stat : all) {
            stat.set(0L);
        }
    }

    /**
     * Runs conv2d: validates and binds the operands, densifies a sparse bias if one is set on
     * the parameters, executes the workers and stores their summed result as the output's
     * non-zero count.
     *
     * @param matrixBlock           input data (validated against N and C*H*W)
     * @param matrixBlock2          filter (validated against K and C*R*S)
     * @param matrixBlock3          output block
     * @param convolutionParameters operation parameters; input1/input2/output are overwritten
     * @throws DMLRuntimeException if validation fails or a worker fails
     */
    public static void conv2d(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        checkInputsConv2d(matrixBlock, matrixBlock2, matrixBlock3, convolutionParameters);

        if (convolutionParameters.bias != null) {
            if (convolutionParameters.bias.isInSparseFormat()) {
                convolutionParameters.bias.sparseToDense();
            }
        }

        final long nnz = execute(LibMatrixDNNConv2d.getConv2dWorkers(convolutionParameters), convolutionParameters);
        matrixBlock3.setNonZeros(nnz);
        matrixBlock3.examSparsity();
    }

    /**
     * Runs conv2d_backward_data: validates and binds the operands, executes the workers and
     * stores their summed result as the output's non-zero count.
     *
     * @param matrixBlock           filter (validated against K and C*R*S)
     * @param matrixBlock2          input errors (validated against N and K*P*Q)
     * @param matrixBlock3          output block
     * @param convolutionParameters operation parameters; input1/input2/output are overwritten
     * @throws DMLRuntimeException if validation fails or a worker fails
     */
    public static void conv2dBackwardData(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        checkInputsConv2dBackwardData(matrixBlock, matrixBlock2, matrixBlock3, convolutionParameters);

        final long nnz = execute(LibMatrixDNNConv2d.getConv2dBackwardDataWorkers(convolutionParameters), convolutionParameters);
        matrixBlock3.setNonZeros(nnz);
        matrixBlock3.examSparsity();
    }

    /**
     * Runs conv2d_backward_filter: validates and binds the operands and executes the workers.
     * Unlike the other conv2d entry points, the value returned by the workers is ignored and
     * the non-zero count is recomputed from the output block instead.
     *
     * @param matrixBlock           input data (validated against N and C*H*W)
     * @param matrixBlock2          input errors (validated against N and K*P*Q)
     * @param matrixBlock3          output block
     * @param convolutionParameters operation parameters; input1/input2/output are overwritten
     * @throws DMLRuntimeException if validation fails or a worker fails
     */
    public static void conv2dBackwardFilter(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        checkInputsConv2dBackwardFilter(matrixBlock, matrixBlock2, matrixBlock3, convolutionParameters);

        // Return value intentionally unused; nnz is derived from the output below.
        execute(LibMatrixDNNConv2d.getConv2dBackwardFilterWorkers(convolutionParameters), convolutionParameters);

        matrixBlock3.recomputeNonZeros();
        matrixBlock3.examSparsity();
    }

    /**
     * Runs a pooling operation: binds input and output on the parameters, checks that the input
     * is N x (C*H*W), prepares the window index arrays when required, executes the workers and
     * stores their summed result as the output's non-zero count.
     *
     * @param matrixBlock           input block
     * @param matrixBlock2          output block
     * @param convolutionParameters operation parameters; input1/output are overwritten
     * @param poolingType           pooling variant forwarded to the worker factory
     * @throws DMLRuntimeException if the input dimensions do not match or a worker fails
     */
    public static void pooling(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, ConvolutionParameters convolutionParameters, PoolingType poolingType) throws DMLRuntimeException {
        convolutionParameters.input1 = matrixBlock;
        convolutionParameters.output = matrixBlock2;

        final boolean shapeOk =
            matrixBlock.getNumColumns() == convolutionParameters.C * convolutionParameters.H * convolutionParameters.W
                && matrixBlock.getNumRows() == convolutionParameters.N;
        if (!shapeOk) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling:" + matrixBlock.getNumRows() + " " + matrixBlock.getNumColumns() + " " + convolutionParameters.N + " " + (convolutionParameters.C * convolutionParameters.H * convolutionParameters.W));
        }

        // The index arrays are skipped only for a dense input with stride 1 and no padding.
        final boolean indexesNotNeeded = convolutionParameters.isStride1Pad0() && !matrixBlock.sparse;
        if (!indexesNotNeeded) {
            fillIndexesArray(convolutionParameters);
        }

        final long nnz = execute(LibMatrixDNNPooling.getPoolingWorkers(convolutionParameters, poolingType), convolutionParameters);
        matrixBlock2.setNonZeros(nnz);
        matrixBlock2.examSparsity();
    }

    /**
     * Runs the backward pass of a pooling operation.
     *
     * <p>Binds the operands on the parameters, validates their dimensions, updates the
     * fine-grained statistics counters, prepares the window index arrays when required,
     * executes the workers and stores their summed result as the output's non-zero count.
     *
     * @param matrixBlock           input block; its shape N x (C*H*W) is only validated for MAX
     * @param matrixBlock2          dout block; must be N x (C*P*Q)
     * @param matrixBlock3          output block; must not be in sparse format
     * @param convolutionParameters operation parameters; input1/input2/output are overwritten
     * @param z                     flag forwarded unchanged to the worker factory
     * @param poolingType           pooling variant
     * @throws DMLRuntimeException on a dimension mismatch, a sparse output, or a worker failure
     */
    public static void poolingBackward(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, ConvolutionParameters convolutionParameters, boolean z, PoolingType poolingType) throws DMLRuntimeException {
        convolutionParameters.input1 = matrixBlock;
        convolutionParameters.input2 = matrixBlock2;
        convolutionParameters.output = matrixBlock3;

        final boolean isMax = poolingType == PoolingType.MAX;

        // Report the value that was actually compared (C*H*W), not K*P*Q.
        final long expectedInputCols = convolutionParameters.C * convolutionParameters.H * convolutionParameters.W;
        if (isMax && (matrixBlock.getNumColumns() != expectedInputCols || matrixBlock.getNumRows() != convolutionParameters.N)) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling_backward:" + matrixBlock.getNumRows() + " " + matrixBlock.getNumColumns() + " " + convolutionParameters.N + " " + expectedInputCols);
        }

        // Report the dimensions of dout (the block being checked) and the compared value C*P*Q.
        final long expectedDoutCols = convolutionParameters.C * convolutionParameters.P * convolutionParameters.Q;
        if (matrixBlock2.getNumColumns() != expectedDoutCols || matrixBlock2.getNumRows() != convolutionParameters.N) {
            throw new DMLRuntimeException("Incorrect dout dimensions in pooling_backward:" + matrixBlock2.getNumRows() + " " + matrixBlock2.getNumColumns() + " " + convolutionParameters.N + " " + expectedDoutCols);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            // For MAX either operand being sparse counts as sparse; otherwise only dout matters.
            final boolean countedAsSparse = isMax
                ? (matrixBlock.isInSparseFormat() || matrixBlock2.isInSparseFormat())
                : matrixBlock2.isInSparseFormat();
            if (countedAsSparse) {
                maxPoolBwdSparseCount.addAndGet(1L);
            } else {
                maxPoolBwdDenseCount.addAndGet(1L);
            }
        }

        if (convolutionParameters.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse pooling_backward is not supported");
        }

        // Index arrays are skipped only for the non-AVG case with a sparse input and a dense dout.
        if (poolingType == PoolingType.AVG
                || !convolutionParameters.input1.isInSparseFormat()
                || convolutionParameters.input2.isInSparseFormat()) {
            fillIndexesArray(convolutionParameters);
        }

        matrixBlock3.setNonZeros(execute(LibMatrixDNNPooling.getPoolingBackwardWorkers(convolutionParameters, z, poolingType), convolutionParameters));
        matrixBlock3.examSparsity();
    }

    /**
     * Runs relu_backward over two blocks of identical shape and stores the workers' summed
     * result as the output's non-zero count.
     *
     * @param matrixBlock  first operand; its row count is passed as the first parameter value
     * @param matrixBlock2 second operand; must have the same dimensions as the first
     * @param matrixBlock3 output block
     * @param i            value passed as the last constructor argument of the parameters
     *                     (NOTE(review): presumably the requested thread count — confirm)
     * @throws DMLRuntimeException if the operand dimensions differ or a worker fails
     */
    public static void reluBackward(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, int i) throws DMLRuntimeException {
        final ConvolutionParameters params = new ConvolutionParameters(matrixBlock.getNumRows(), -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, i);
        params.input1 = matrixBlock;
        params.input2 = matrixBlock2;
        params.output = matrixBlock3;

        final boolean sameShape = matrixBlock.getNumRows() == matrixBlock2.getNumRows()
            && matrixBlock.getNumColumns() == matrixBlock2.getNumColumns();
        if (!sameShape) {
            throw new DMLRuntimeException("Incorrect dimensions for relu_backward:" + matrixBlock.getNumRows() + " != " + matrixBlock2.getNumRows() + " || " + matrixBlock.getNumColumns() + " != " + matrixBlock2.getNumColumns());
        }

        final long nnz = execute(LibMatrixDNNRelu.getReluBackwardWorkers(params), params);
        matrixBlock3.setNonZeros(nnz);
        matrixBlock3.examSparsity();
    }

    /**
     * Adds a column-vector bias to the input and writes the result to the output block.
     *
     * <p>The bias must have exactly one column, and the input's column count must be a multiple
     * of the bias length. Each input row is treated as consecutive groups of
     * {@code columns / biasLength} cells, and group {@code k} receives {@code bias[k]}.
     *
     * @param matrixBlock  input block
     * @param matrixBlock2 bias block (converted to dense in place if it is sparse)
     * @param matrixBlock3 output block; written through its dense value array
     * @param i            unused by this method
     * @throws DMLRuntimeException if the bias/input shapes are incompatible
     */
    public static void biasAdd(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, int i) throws DMLRuntimeException {
        final int rows = matrixBlock.getNumRows();
        final int biasLen = matrixBlock2.getNumRows();
        final int groupWidth = matrixBlock.getNumColumns() / biasLen;

        final boolean shapesOk = matrixBlock2.getNumColumns() == 1 && matrixBlock.getNumColumns() % biasLen == 0;
        if (!shapesOk) {
            throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + rows + " X " + matrixBlock.getNumColumns() + "] and bias[" + biasLen + " X " + matrixBlock2.getNumColumns() + "]");
        }

        // NOTE(review): this reference is taken before copy() below; the in-place update assumes
        // copy() keeps writing into the same dense array — confirm against MatrixBlock.copy.
        final double[] out = matrixBlock3.getDenseBlockValues();

        if (matrixBlock.isEmptyBlock()) {
            // Empty input: delegate to fillBias one row at a time.
            for (int r = 0; r < rows; r++) {
                ConvolutionUtils.fillBias(matrixBlock2, out, r, r + 1, rows, biasLen, groupWidth);
            }
        } else {
            matrixBlock3.copy(matrixBlock);
            if (matrixBlock2.isInSparseFormat()) {
                matrixBlock2.sparseToDense();
            }
            final double[] biasValues = matrixBlock2.getDenseBlockValues();

            // Walk the output linearly: row by row, bias group by bias group.
            int cursor = 0;
            for (int r = 0; r < rows; r++) {
                for (int k = 0; k < biasLen; k++) {
                    final double b = biasValues[k];
                    for (int c = 0; c < groupWidth; c++, cursor++) {
                        out[cursor] += b;
                    }
                }
            }
        }

        matrixBlock3.recomputeNonZeros();
        matrixBlock3.examSparsity();
    }

    /**
     * Multiplies the input by a column-vector bias and writes the result to the output block.
     *
     * <p>The bias must have exactly one column, and the input's column count must be a multiple
     * of the bias length. Each input row is treated as consecutive groups of
     * {@code columns / biasLength} cells, and every cell in group {@code k} is multiplied by
     * {@code bias[k]}. If the input or the bias is empty, only the output's non-zero count is
     * set to 0.
     *
     * @param matrixBlock  input block
     * @param matrixBlock2 bias block (converted to dense in place if it is sparse)
     * @param matrixBlock3 output block; receives a copy of the input that is scaled in place
     * @param i            value passed as the last constructor argument of the parameters
     * @throws DMLRuntimeException if the bias/input shapes are incompatible
     */
    public static void biasMultiply(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, int i) throws DMLRuntimeException {
        int numRows = matrixBlock.getNumRows();
        int biasLen = matrixBlock2.getNumRows();
        int groupWidth = matrixBlock.getNumColumns() / biasLen;
        ConvolutionParameters convolutionParameters = new ConvolutionParameters(numRows, groupWidth, -1, -1, biasLen, -1, -1, -1, -1, -1, -1, i);
        convolutionParameters.input1 = matrixBlock;
        convolutionParameters.input2 = matrixBlock2;
        convolutionParameters.output = matrixBlock3;
        if (matrixBlock2.getNumColumns() != 1 || matrixBlock.getNumColumns() % biasLen != 0) {
            throw new DMLRuntimeException("Incorrect inputs for bias_multiply: input[" + numRows + " X " + matrixBlock.getNumColumns() + "] and bias[" + biasLen + " X " + matrixBlock2.getNumColumns() + "]");
        }
        if (matrixBlock.isEmptyBlock() || matrixBlock2.isEmptyBlock()) {
            // Anything multiplied by an all-zero operand is zero.
            convolutionParameters.output.setNonZeros(0L);
            return;
        }

        matrixBlock3.copy(matrixBlock);
        if (matrixBlock2.isInSparseFormat()) {
            matrixBlock2.sparseToDense();
        }
        double[] biasValues = matrixBlock2.getDenseBlockValues();

        if (matrixBlock.isInSparseFormat()) {
            // First drop every cell whose bias is zero: group k spans columns
            // [k*groupWidth, (k+1)*groupWidth).
            for (int k = 0; k < biasLen; k++) {
                if (biasValues[k] == 0.0d) {
                    for (int r = 0; r < numRows; r++) {
                        matrixBlock3.sparseBlock.deleteIndexRange(r, k * groupWidth, (k + 1) * groupWidth);
                    }
                }
            }
            // Then scale the remaining non-zeros.
            for (int r = 0; r < numRows; r++) {
                if (!matrixBlock3.sparseBlock.isEmpty(r)) {
                    int pos = matrixBlock3.sparseBlock.pos(r);
                    int size = matrixBlock3.sparseBlock.size(r);
                    int[] indexes = matrixBlock3.sparseBlock.indexes(r);
                    double[] values = matrixBlock3.sparseBlock.values(r);
                    for (int j = pos; j < pos + size; j++) {
                        // Column = k*groupWidth + offset, so the bias index is the quotient.
                        // (The previous remainder-based lookup picked the offset within the
                        // group, which disagrees with the dense path and the deletion above.)
                        int k = indexes[j] / groupWidth;
                        if (biasValues[k] != 0.0d) {
                            values[j] *= biasValues[k];
                        }
                    }
                }
            }
        } else {
            // Dense path: walk the output linearly, row by row and group by group.
            double[] out = matrixBlock3.getDenseBlockValues();
            int cursor = 0;
            for (int r = 0; r < numRows; r++) {
                for (int k = 0; k < biasLen; k++) {
                    double b = biasValues[k];
                    for (int c = 0; c < groupWidth; c++, cursor++) {
                        out[cursor] *= b;
                    }
                }
            }
        }
        convolutionParameters.output.recomputeNonZeros();
        convolutionParameters.output.examSparsity();
    }

    /**
     * Runs all tasks and returns the sum of their results.
     *
     * <p>When the constrained thread count is 1 the tasks are called sequentially on the
     * calling thread, without creating a pool. Otherwise they are submitted to a pool of
     * {@code min(threads, N)} workers obtained from {@code CommonThreadPool}.
     *
     * @param arrayList             tasks to run; each returns a partial count
     * @param convolutionParameters supplies the requested thread count and N
     * @return the sum of all task results
     * @throws DMLRuntimeException wrapping any exception raised while running the tasks
     */
    private static long execute(ArrayList<Callable<Long>> arrayList, ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        int numThreads = OptimizerUtils.getConstrainedNumThreads(convolutionParameters.numThreads);
        long total = 0;
        try {
            if (numThreads == 1) {
                for (Callable<Long> task : arrayList) {
                    total += task.call();
                }
            } else {
                ExecutorService pool = CommonThreadPool.get(Math.min(numThreads, convolutionParameters.N));
                List<Future<Long>> results;
                try {
                    // invokeAll blocks until every task has completed.
                    results = pool.invokeAll(arrayList);
                } finally {
                    // Always release the pool, even if invokeAll is interrupted or rejects the tasks.
                    pool.shutdown();
                }
                for (Future<Long> result : results) {
                    total += result.get();
                }
            }
            return total;
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack before wrapping.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Error while executing multi-threaded tasks", e);
        } catch (Exception e) {
            throw new DMLRuntimeException("Error while executing multi-threaded tasks", e);
        }
    }

    /**
     * Verifies that two values are equal.
     *
     * @param str message prefix used when the check fails
     * @param j   actual value
     * @param j2  expected value
     * @throws DMLRuntimeException if {@code j} differs from {@code j2}
     */
    private static void checkOrThrowException(String str, long j, long j2) throws DMLRuntimeException {
        if (j == j2) {
            return;
        }
        throw new DMLRuntimeException(str + ":" + j + " != " + j2);
    }

    /**
     * Verifies that a value equals the product of three factors.
     *
     * @param str message prefix used when the check fails
     * @param j   actual value
     * @param j2  first expected factor
     * @param j3  second expected factor
     * @param j4  third expected factor
     * @throws DMLRuntimeException if {@code j} differs from {@code j2 * j3 * j4}
     */
    private static void checkOrThrowException(String str, long j, long j2, long j3, long j4) throws DMLRuntimeException {
        if (j != j2 * j3 * j4) {
            // The message previously opened a parenthesis without closing it.
            throw new DMLRuntimeException(str + ":" + j + " != (" + j2 + " * " + j3 + " * " + j4 + ")");
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Binds the operands of conv2d_backward_data on the parameters and validates them.
     *
     * <p>Checks that the filter is K x (C*R*S), that the input errors are N x (K*P*Q) and that
     * both strides are positive. When fine-grained statistics are enabled, also increments the
     * sparse or dense invocation counter for this operation.
     *
     * @param matrixBlock           filter
     * @param matrixBlock2          input errors
     * @param matrixBlock3          output block
     * @param convolutionParameters operation parameters; input1/input2/output are overwritten
     * @throws DMLRuntimeException if any check fails
     */
    public static void checkInputsConv2dBackwardData(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        final MatrixBlock filter = matrixBlock;
        final MatrixBlock dout = matrixBlock2;
        convolutionParameters.input1 = filter;
        convolutionParameters.input2 = dout;
        convolutionParameters.output = matrixBlock3;

        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input filter != number of filters in filter_shape", filter.getNumRows(), convolutionParameters.K);
        // The expected column count is C*R*S, i.e. height times width (message previously said height twice).
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input filter != channels*filter_height*filter_width in filter_shape", filter.getNumColumns(), convolutionParameters.C, convolutionParameters.R, convolutionParameters.S);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input errors != batch size in input_shape", dout.getNumRows(), convolutionParameters.N);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input errors != expected input error channels*height*width", dout.getNumColumns(), convolutionParameters.K, convolutionParameters.P, convolutionParameters.Q);

        if (convolutionParameters.stride_h <= 0 || convolutionParameters.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + convolutionParameters.stride_h + ", " + convolutionParameters.stride_w);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            if (filter.isInSparseFormat() || dout.isInSparseFormat()) {
                conv2dBwdDataSparseCount.addAndGet(1L);
            } else {
                conv2dBwdDataDenseCount.addAndGet(1L);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Binds the operands of conv2d_backward_filter on the parameters and validates them.
     *
     * <p>Checks that the input data is N x (C*H*W), that the input errors are N x (K*P*Q) and
     * that both strides are positive. When fine-grained statistics are enabled, also increments
     * the sparse or dense invocation counter for this operation.
     *
     * @param matrixBlock           input data
     * @param matrixBlock2          input errors
     * @param matrixBlock3          output block
     * @param convolutionParameters operation parameters; input1/input2/output are overwritten
     * @throws DMLRuntimeException if any check fails
     */
    public static void checkInputsConv2dBackwardFilter(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        final MatrixBlock input = matrixBlock;
        final MatrixBlock dout = matrixBlock2;
        convolutionParameters.input1 = input;
        convolutionParameters.input2 = dout;
        convolutionParameters.output = matrixBlock3;

        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input data != batch size in input_shape", input.getNumRows(), convolutionParameters.N);
        // The expected column count is C*H*W, i.e. height times width (message previously said height twice).
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input data != channels*input_height*input_width in input_shape", input.getNumColumns(), convolutionParameters.C, convolutionParameters.H, convolutionParameters.W);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input errors != batch size in input_shape", dout.getNumRows(), convolutionParameters.N);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input errors != expected input error channels*height*width", dout.getNumColumns(), convolutionParameters.K, convolutionParameters.P, convolutionParameters.Q);

        if (convolutionParameters.stride_h <= 0 || convolutionParameters.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + convolutionParameters.stride_h + ", " + convolutionParameters.stride_w);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            if (input.isInSparseFormat() || dout.isInSparseFormat()) {
                conv2dBwdFilterSparseCount.addAndGet(1L);
            } else {
                conv2dBwdFilterDenseCount.addAndGet(1L);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Binds the operands of conv2d on the parameters and validates them.
     *
     * <p>Checks that the filter is K x (C*R*S), that the input data is N x (C*H*W) and that both
     * strides are positive. When fine-grained statistics are enabled, also increments the sparse
     * or dense invocation counter for this operation.
     *
     * @param matrixBlock           input data
     * @param matrixBlock2          filter
     * @param matrixBlock3          output block
     * @param convolutionParameters operation parameters; input1/input2/output are overwritten
     * @throws DMLRuntimeException if any check fails
     */
    public static void checkInputsConv2d(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        final MatrixBlock input = matrixBlock;
        final MatrixBlock filter = matrixBlock2;
        convolutionParameters.input1 = input;
        convolutionParameters.input2 = filter;
        convolutionParameters.output = matrixBlock3;

        checkOrThrowException("Incorrect input to conv2d: Number of rows of input filter != number of filters in filter_shape", filter.getNumRows(), convolutionParameters.K);
        // Expected column counts are C*R*S and C*H*W, i.e. height times width
        // (both messages previously said height twice).
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input filter != channels*filter_height*filter_width in filter_shape", filter.getNumColumns(), convolutionParameters.C, convolutionParameters.R, convolutionParameters.S);
        checkOrThrowException("Incorrect input to conv2d: Number of rows of input data != batch size in input_shape", input.getNumRows(), convolutionParameters.N);
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input data != channels*input_height*input_width in input_shape", input.getNumColumns(), convolutionParameters.C, convolutionParameters.H, convolutionParameters.W);

        if (convolutionParameters.stride_h <= 0 || convolutionParameters.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + convolutionParameters.stride_h + ", " + convolutionParameters.stride_w);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            if (input.isInSparseFormat() || filter.isInSparseFormat()) {
                conv2dSparseCount.addAndGet(1L);
            } else {
                conv2dDenseCount.addAndGet(1L);
            }
        }
    }

    /**
     * Allocates and fills the per-output-position window bounds on the parameters.
     *
     * <p>For output row {@code p} the window starts at {@code p * stride_h - pad_h} and is
     * {@code R} long; the stored start is clamped to 0 and the stored end to {@code H}. The
     * width axis is handled the same way using {@code Q}, {@code pad_w}, {@code stride_w},
     * {@code S} and {@code W}.
     *
     * @param convolutionParameters parameters whose start/end index arrays are replaced
     */
    private static void fillIndexesArray(ConvolutionParameters convolutionParameters) {
        final int[] startH = new int[convolutionParameters.P];
        convolutionParameters.start_indexes_h = startH;
        final int[] endH = new int[convolutionParameters.P];
        convolutionParameters.end_indexes_h = endH;
        final int[] startW = new int[convolutionParameters.Q];
        convolutionParameters.start_indexes_w = startW;
        final int[] endW = new int[convolutionParameters.Q];
        convolutionParameters.end_indexes_w = endW;

        // Height axis: 'top' is the unclamped first row covered by window p.
        for (int p = 0, top = -convolutionParameters.pad_h; p < convolutionParameters.P; p++, top += convolutionParameters.stride_h) {
            startH[p] = Math.max(top, 0);
            endH[p] = Math.min(top + convolutionParameters.R, convolutionParameters.H);
        }

        // Width axis: 'left' is the unclamped first column covered by window q.
        for (int q = 0, left = -convolutionParameters.pad_w; q < convolutionParameters.Q; q++, left += convolutionParameters.stride_w) {
            startW[q] = Math.max(left, 0);
            endW[q] = Math.min(left + convolutionParameters.S, convolutionParameters.W);
        }
    }
}
