package org.apache.sysml.runtime.matrix.data;

import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnActivationDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.class */
public class LibMatrixCuDNN extends LibMatrixCUDA {
    // Compile-time constant; not referenced in the visible part of this class (javac may have
    // inlined it at its use sites during compilation — NOTE(review): confirm whether it is still used).
    private static final boolean RECOMPUTE_DENSE_NNZ = false;
    // NOTE(review): mutable, protected convolution-preference setting; it is not read in the visible
    // part of this class — confirm its meaning at its use sites.
    protected static int CONVOLUTION_PREFERENCE = 0;
    // Commons-logging logger used for trace-level messages from the GPU operations below.
    private static final Log LOG = LogFactory.getLog(LibMatrixCuDNN.class.getName());

    /**
     * Returns the cuDNN handle obtained from the given GPU context.
     *
     * <p>Note: the decompiler reported that this method was originally {@code protected}; it is
     * kept {@code public} here so existing callers keep compiling.
     *
     * @param gCtx the GPU context whose cuDNN handle is requested
     * @return the handle returned by {@code gCtx.getCudnnHandle()}
     * @throws DMLRuntimeException propagated from the GPU context
     */
    public static cudnnHandle getCudnnHandle(GPUContext gCtx) throws DMLRuntimeException {
        final cudnnHandle handle = gCtx.getCudnnHandle();
        return handle;
    }

    /**
     * Fused "convolution followed by bias add".
     *
     * <p>First runs {@link #conv2d} of {@code image} with {@code filter} into {@code output}, then
     * calls {@code biasAdd} with {@code output} as both its input and its destination, so the bias is
     * added into the convolution result.
     *
     * @param gCtx                     the GPU context
     * @param instName                 instruction name used for statistics/allocations
     * @param image                    input image
     * @param bias                     bias to add after the convolution
     * @param filter                   convolution filter
     * @param output                   destination of both the convolution and the bias add
     * @param intermediateMemoryBudget memory budget forwarded to {@link #conv2d}
     * @throws DMLRuntimeException if either step fails
     */
    public static void conv2dBiasAdd(GPUContext gCtx, String instName, MatrixObject image, MatrixObject bias, MatrixObject filter, MatrixObject output, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Step 1: plain convolution into the output buffer.
        conv2d(gCtx, instName, image, filter, output,
                N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q,
                intermediateMemoryBudget);
        // Step 2: bias add, reading from and writing back to the same output.
        biasAdd(gCtx, instName, output, bias, output);
    }

    /**
     * Materialises the dense im2col buffer of {@code image} on the GPU.
     *
     * <p>The buffer is allocated with {@code C*R*S*N*P*Q} elements and filled by either the
     * {@code sparse_dense_im2col} kernel (sparse input) or the {@code dense_dense_im2col} kernel
     * (dense input).
     *
     * @param gCtx          the GPU context
     * @param instName      instruction name used for allocation and statistics
     * @param image         the input image matrix
     * @param isSparseImage whether {@code image} is stored in CSR format on the device
     * @return the newly allocated im2col buffer, or {@code null} when a sparse image has no non-zeros
     * @throws DMLRuntimeException if a sparse image reports a negative (unknown) nnz, or a GPU call fails
     */
    private static Pointer denseIm2col(GPUContext gCtx, String instName, MatrixObject image, boolean isSparseImage, long N, long C, long H, long W, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) throws DMLRuntimeException {
        final long startTime = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        final Pointer im2colPointer;
        final String timerName;
        if (!isSparseImage) {
            im2colPointer = gCtx.allocate(instName, C * R * S * N * P * Q * sizeOfDataType);
            getCudaKernels(gCtx).launchKernel("dense_dense_im2col",
                    ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N * C * H * W)),
                    getDensePointerForCuDNN(gCtx, image, instName), im2colPointer,
                    N * C * H * W, C * H * W, H * W, W,
                    R, S, P, Q, P * Q, R * S, N * P * Q,
                    stride_h, stride_w, pad_h, pad_w);
            timerName = GPUInstruction.MISC_TIMER_DENSE_IM2COL_KERNEL;
        } else {
            CSRPointer csr = getSparsePointer(gCtx, image, instName);
            if (csr.nnz < 0) {
                throw new DMLRuntimeException("Unknown number of nonzeroes in denseIm2col");
            }
            if (csr.nnz == 0) {
                // An all-zero sparse image has nothing to unfold.
                return null;
            }
            im2colPointer = gCtx.allocate(instName, C * R * S * N * P * Q * sizeOfDataType);
            getCudaKernels(gCtx).launchKernel("sparse_dense_im2col",
                    ExecutionConfig.getConfigForSimpleVectorOperations(toInt(csr.nnz)),
                    csr.val, csr.rowPtr, csr.colInd, im2colPointer,
                    csr.nnz, N, C * H * W, H * W, W,
                    R, S, P, Q, P * Q, R * S, N * P * Q,
                    stride_h, stride_w, pad_h, pad_w);
            timerName = GPUInstruction.MISC_TIMER_SPARSE_IM2COL_KERNEL;
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(instName, timerName, System.nanoTime() - startTime);
        }
        return im2colPointer;
    }

    /**
     * Performs the forward 2D convolution of {@code image} with {@code filter} and writes the result
     * into the dense device buffer of {@code outputBlock}.
     *
     * <p>There are two strategies:
     * <ul>
     *   <li>Sparse filter, with im2col intermediates smaller than both the cuDNN workspace limit and
     *       {@code intermediateMemoryBudget}: explicit im2col, then a sparse-dense matrix multiply, then
     *       the {@code reorg_knpq} kernel.</li>
     *   <li>Otherwise cuDNN's forward convolution is used. It runs on the whole batch when the estimated
     *       dense size of the sparse inputs fits in the budget, and row by row otherwise.</li>
     * </ul>
     *
     * <p>If the filter or the image has no non-zeros, the method returns without writing the output.
     * NOTE(review): this presumably relies on the caller having zero-initialised the output; confirm.
     *
     * @param gCtx                     the GPU context
     * @param instName                 instruction name used for allocations and statistics
     * @param image                    input image, N x CHW
     * @param filter                   filter, K x CRS
     * @param outputBlock              output, N x KPQ
     * @param intermediateMemoryBudget byte budget for intermediates / cuDNN workspace
     * @throws DMLRuntimeException if a tensor exceeds cuDNN's size limit or a GPU call fails
     */
    public static void conv2d(GPUContext gCtx, String instName, MatrixObject image, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying. A product of three ints is evaluated in 32-bit
        // arithmetic and can wrap around, which would let oversized tensors slip past the
        // cuDNN size check and corrupt the memory estimates below.
        long CHW = (long) C * H * W;
        long KPQ = (long) K * P * Q;
        long CRS = (long) C * R * S;
        long NCHW = N * CHW;
        long NKPQ = N * KPQ;
        long KCRS = K * CRS;
        long NPQ = (long) N * P * Q;

        boolean isSparseFilter = isInSparseFormat(gCtx, filter);
        if (getNnz(gCtx, instName, filter, false) == 0) {
            return;
        }
        boolean isSparseImage = isInSparseFormat(gCtx, image);
        if (getNnz(gCtx, instName, image, false) == 0) {
            return;
        }
        Pointer dstPointer = getDensePointerForCuDNN(gCtx, outputBlock, instName);
        if (NCHW >= maxNumElementsOfCuDNNTensor || NKPQ >= maxNumElementsOfCuDNNTensor || KCRS >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(N, CHW, K, CRS, N, KPQ);
            return; // not reached: throwCuDNNDimensionError always throws
        }

        if (isSparseFilter && OptimizerUtils.estimateSizeExactSparsity(CRS, NPQ, 1.0d) + OptimizerUtils.estimateSizeExactSparsity(K, NPQ, 1.0d) < Math.min(LibMatrixCuDNNConvolutionAlgorithm.MAX_WORKSPACE_LIMIT_BYTES, intermediateMemoryBudget)) {
            // Explicit im2col path: (K x CRS sparse filter) %*% (CRS x NPQ im2col) -> K x NPQ.
            Pointer im2colPointer = denseIm2col(gCtx, instName, image, isSparseImage, N, C, H, W, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
            CSRPointer sparseFilterPointer = filter.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            Pointer matmultOutputPointer = gCtx.allocate(instName, NKPQ * sizeOfDataType);
            LibMatrixCuMatMult.sparseDenseMatMult(gCtx, instName, matmultOutputPointer, sparseFilterPointer, im2colPointer, K, CRS, CRS, NPQ, K, NPQ, false, false);
            // Release the im2col buffer before the reorg kernel to keep peak device memory lower.
            gCtx.cudaFreeHelper(instName, im2colPointer);

            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            // Copy the K x NPQ product into the N x KPQ output (kernel name suggests a KNPQ -> NKPQ reorg).
            getCudaKernels(gCtx).launchKernel("reorg_knpq", ExecutionConfig.getConfigForSimpleVectorOperations(toInt(NKPQ)),
                    matmultOutputPointer, dstPointer, NKPQ, NPQ, KPQ, P * Q);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_DENSE_REORG_KNPQ_KERNEL, System.nanoTime() - t1);
            }
            gCtx.cudaFreeHelper(instName, matmultOutputPointer);
            return;
        }

        // cuDNN path. Estimated dense size of whichever inputs are currently sparse.
        double overhead = (isSparseFilter ? OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0d) : 0.0d)
                + (isSparseImage ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0d) : 0.0d);
        Pointer filterPointer = getDensePointerForCuDNN(gCtx, filter, instName);
        long workspaceLimit = (long) (intermediateMemoryBudget - overhead);
        // Whole batch if it fits in the budget, otherwise one image row at a time.
        int batchSize = overhead <= intermediateMemoryBudget ? N : 1;
        try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionForwardAlgorithm(gCtx, instName, batchSize, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
            if (batchSize == N) {
                cudnnConv2d(gCtx, instName, getDensePointerForCuDNN(gCtx, image, instName), filterPointer, dstPointer, algo);
            } else {
                try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
                    for (int n = 0; n < N; n++) {
                        // Output row n starts n*KPQ elements into the output buffer.
                        cudnnConv2d(gCtx, instName, imgFetcher.getNthRow(n), filterPointer, dstPointer.withByteOffset(n * KPQ * sizeOfDataType), algo);
                    }
                }
            }
        }
    }

    /**
     * Computes a row-wise softmax of {@code in1} with cuDNN and stores it in the matrix named
     * {@code outputName}.
     *
     * <p>The input is described as an (rows x cols x 1 x 1) tensor, and cuDNN is called with
     * algorithm 1 ({@code CUDNN_SOFTMAX_ACCURATE}) and mode 1 ({@code CUDNN_SOFTMAX_MODE_CHANNEL}).
     * As a result, the softmax is taken across the columns of each row.
     *
     * @param ec         execution context used to look up and allocate the output
     * @param gCtx       the GPU context
     * @param instName   instruction name used for statistics
     * @param in1        input matrix
     * @param outputName name of the output variable
     * @throws DMLRuntimeException if allocation fails or cuDNN reports a non-success status
     */
    public static void softmax(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String outputName) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : softmax, GPUContext=" + gCtx);
        }
        cudnnTensorDescriptor tensorDesc = allocateTensorDescriptor(toInt(in1.getNumRows()), toInt(in1.getNumColumns()), 1, 1);
        // Destroy the descriptor even if any step below throws (it previously leaked on error).
        try {
            Pointer srcPointer = getDensePointerForCuDNN(gCtx, in1, instName);
            MatrixObject out = ec.getMatrixObject(outputName);
            ec.allocateGPUMatrixObject(outputName, in1.getNumRows(), in1.getNumColumns());
            out.getGPUObject(gCtx).allocateAndFillDense(0.0d);
            int status = JCudnn.cudnnSoftmaxForward(gCtx.getCudnnHandle(), 1, 1, one(), tensorDesc, srcPointer, zero(), tensorDesc, getDensePointerForCuDNN(gCtx, out, instName));
            // Fail loudly, like the other cuDNN helpers in this class, instead of silently producing garbage.
            if (status != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                throw new DMLRuntimeException("Could not execute cudnnSoftmaxForward: " + cudnnStatus.stringFor(status));
            }
        } finally {
            JCudnn.cudnnDestroyTensorDescriptor(tensorDesc);
        }
    }

    /**
     * Creates a 4-d cuDNN tensor descriptor with the given dimensions.
     *
     * <p>Uses tensor format 0 ({@code CUDNN_TENSOR_NCHW}) and {@code LibMatrixCUDA.CUDNN_DATA_TYPE}.
     * The caller owns the returned descriptor and must destroy it.
     *
     * @return the newly created and configured descriptor
     */
    private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
        final cudnnTensorDescriptor descriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(descriptor);
        final int nchwFormat = 0; // CUDNN_TENSOR_NCHW
        JCudnn.cudnnSetTensor4dDescriptor(descriptor, nchwFormat, LibMatrixCUDA.CUDNN_DATA_TYPE, N, C, H, W);
        return descriptor;
    }

    /**
     * Always throws a {@link DMLRuntimeException} reporting that one input and one output matrix
     * exceed the maximum element count supported by cuDNN tensors.
     *
     * @param inRows  input row count
     * @param inCols  input column count
     * @param outRows output row count
     * @param outCols output column count
     * @throws DMLRuntimeException always
     */
    private static void throwCuDNNDimensionError(long inRows, long inCols, long outRows, long outCols) throws DMLRuntimeException {
        StringBuilder message = new StringBuilder("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:");
        message.append(maxNumElementsOfCuDNNTensor)
                .append(". Given input matrix dimensions: [").append(inRows).append(',').append(inCols)
                .append("]. Output dimension:  [").append(outRows).append(',').append(outCols)
                .append("].");
        throw new DMLRuntimeException(message.toString());
    }

    /**
     * Always throws a {@link DMLRuntimeException} reporting that two input matrices and one output
     * matrix exceed the maximum element count supported by cuDNN tensors.
     *
     * @throws DMLRuntimeException always
     */
    private static void throwCuDNNDimensionError(long in1Rows, long in1Cols, long in2Rows, long in2Cols, long outRows, long outCols) throws DMLRuntimeException {
        StringBuilder message = new StringBuilder("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:");
        message.append(maxNumElementsOfCuDNNTensor)
                .append(". Given input matrix dimensions: [").append(in1Rows).append(',').append(in1Cols)
                .append("], [").append(in2Rows).append(',').append(in2Cols)
                .append("]. Output dimension: [").append(outRows).append(',').append(outCols)
                .append(']');
        throw new DMLRuntimeException(message.toString());
    }

    /**
     * Thin wrapper around {@code cudnnConvolutionForward} that uses the descriptors, algorithm and
     * workspace held by {@code algo}.
     *
     * <p>It calls cuDNN with alpha = 1 and beta = 0, records timing when fine-grained statistics
     * are enabled, and converts both non-success statuses and {@link CudaException}s into
     * {@link DMLRuntimeException}s.
     *
     * @param image      device pointer described by {@code algo.nchwTensorDesc}
     * @param filter     device pointer described by {@code algo.filterDesc}
     * @param dstPointer device pointer described by {@code algo.nkpqTensorDesc}
     * @throws DMLRuntimeException on any cuDNN failure
     */
    private static void cudnnConv2d(GPUContext gCtx, String instName, Pointer image, Pointer filter, Pointer dstPointer, LibMatrixCuDNNConvolutionAlgorithm algo) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2d, GPUContext=" + gCtx);
        }
        final int status;
        try {
            final long start = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            status = JCudnn.cudnnConvolutionForward(getCudnnHandle(gCtx), one(),
                    algo.nchwTensorDesc, image,
                    algo.filterDesc, filter,
                    algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes,
                    zero(), algo.nkpqTensorDesc, dstPointer);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CONVOLUTION_FORWARD_LIB, System.nanoTime() - start);
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
        if (status != cudnnStatus.CUDNN_STATUS_SUCCESS) {
            throw new DMLRuntimeException("Could not executed cudnnConvolutionForward: " + cudnnStatus.stringFor(status));
        }
    }

    /**
     * Computes the gradient of a 2D convolution with respect to the filter (dw) and writes it into
     * the dense device buffer of {@code outputBlock}.
     *
     * <p>When the estimated dense size of the sparse inputs fits in {@code intermediateMemoryBudget},
     * cuDNN processes the whole batch at once. Otherwise each image/dout row pair is processed
     * separately into a temporary K x CRS buffer, which is then added into the output with the
     * {@code inplace_add} kernel.
     * NOTE(review): the row-wise path accumulates into the output, so it presumably relies on the
     * output being zero-initialised by the caller; confirm.
     *
     * <p>If {@code dout} or {@code image} has no non-zeros, the method returns without writing the output.
     *
     * @param image       input image, N x CHW
     * @param dout        gradient w.r.t. the convolution output, N x KPQ
     * @param outputBlock filter gradient, K x CRS
     * @throws DMLRuntimeException if a tensor exceeds cuDNN's size limit or a GPU call fails
     */
    public static void conv2dBackwardFilter(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying; 32-bit products can wrap and bypass the size check.
        long CHW = (long) C * H * W;
        long KPQ = (long) K * P * Q;
        long CRS = (long) C * R * S;
        long NCHW = N * CHW;
        long NKPQ = N * KPQ;
        long KCRS = K * CRS;

        boolean isSparseDout = isInSparseFormat(gCtx, dout);
        if (getNnz(gCtx, instName, dout, false) == 0) {
            return;
        }
        boolean isSparseImage = isInSparseFormat(gCtx, image);
        if (getNnz(gCtx, instName, image, false) == 0) {
            return;
        }
        if (NCHW >= maxNumElementsOfCuDNNTensor || NKPQ >= maxNumElementsOfCuDNNTensor || KCRS >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(N, CHW, N, KPQ, K, CRS);
            return; // not reached: throwCuDNNDimensionError always throws
        }
        Pointer dwPointer = getDensePointerForCuDNN(gCtx, outputBlock, instName);
        // Estimated dense size of whichever inputs are currently sparse.
        double overhead = (isSparseImage ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0d) : 0.0d)
                + (isSparseDout ? OptimizerUtils.estimateSizeExactSparsity(N, KPQ, 1.0d) : 0.0d);
        long workspaceLimit = (long) (intermediateMemoryBudget - overhead);
        int batchSize = overhead <= intermediateMemoryBudget ? N : 1;
        try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardFilterAlgorithm(gCtx, instName, batchSize, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
            if (batchSize == N) {
                cudnnConv2dBackwardFilter(gCtx, instName, getDensePointerForCuDNN(gCtx, image, instName), getDensePointerForCuDNN(gCtx, dout, instName), dwPointer, algo);
            } else {
                try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
                     LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
                    Pointer tempdwPointer = gCtx.allocate(KCRS * sizeOfDataType);
                    // Free the temporary buffer even if a row fails (it previously leaked on error).
                    try {
                        for (int n = 0; n < N; n++) {
                            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
                            JCuda.cudaMemset(tempdwPointer, 0, KCRS * sizeOfDataType);
                            if (DMLScript.FINEGRAINED_STATISTICS) {
                                GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_SET_ZERO, System.nanoTime() - t1);
                            }
                            cudnnConv2dBackwardFilter(gCtx, instName, imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), tempdwPointer, algo);
                            // Accumulate this row's contribution into dw.
                            getCudaKernels(gCtx).launchKernel("inplace_add", ExecutionConfig.getConfigForSimpleMatrixOperations(K, toInt(CRS)),
                                    tempdwPointer, dwPointer, K, toInt(CRS));
                        }
                    } finally {
                        gCtx.cudaFreeHelper(tempdwPointer, true);
                    }
                }
            }
        }
    }

    /**
     * Thin wrapper around {@code cudnnConvolutionBackwardFilter} that uses the descriptors,
     * algorithm and workspace held by {@code algo}.
     *
     * <p>It calls cuDNN with alpha = 1 and beta = 0, records timing when fine-grained statistics
     * are enabled, and converts failures into {@link DMLRuntimeException}s. The wrapped-exception
     * message now names this operation; it previously said "conv2d", which pointed debugging at the
     * forward pass.
     *
     * @param imagePointer device pointer described by {@code algo.nchwTensorDesc}
     * @param doutPointer  device pointer described by {@code algo.nkpqTensorDesc}
     * @param dwPointer    device pointer described by {@code algo.filterDesc} (written)
     * @throws DMLRuntimeException on any cuDNN failure
     */
    private static void cudnnConv2dBackwardFilter(GPUContext gCtx, String instName, Pointer imagePointer, Pointer doutPointer, Pointer dwPointer, LibMatrixCuDNNConvolutionAlgorithm algo) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardFilter, GPUContext=" + gCtx);
        }
        try {
            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            int status = JCudnn.cudnnConvolutionBackwardFilter(getCudnnHandle(gCtx), one(),
                    algo.nchwTensorDesc, imagePointer,
                    algo.nkpqTensorDesc, doutPointer,
                    algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes,
                    zero(), algo.filterDesc, dwPointer);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_FILTER_LIB, System.nanoTime() - t1);
            }
            if (status != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardFilter: " + cudnnStatus.stringFor(status));
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2dBackwardFilter in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Computes the gradient of a 2D convolution with respect to its input (dx) and writes it into
     * the dense device buffer of {@code output}.
     *
     * <p>When the estimated dense size of the sparse inputs fits in {@code intermediateMemoryBudget},
     * cuDNN processes the whole batch. Otherwise each dout row is processed separately, writing into
     * the matching N-th row of dx.
     *
     * <p>Fix: the row-wise path used to pass the dout row as the filter argument and the filter as
     * the dout argument of {@link #cudnnConv2dBackwardData}. That helper binds its first pointer to
     * {@code filterDesc} and its second to {@code nkpqTensorDesc}, so both arguments were interpreted
     * with the wrong descriptors. The argument order now matches the full-batch path.
     *
     * <p>If {@code filter} or {@code dout} has no non-zeros, the method returns without writing the output.
     *
     * @param filter the convolution filter, K x CRS
     * @param dout   gradient w.r.t. the convolution output, N x KPQ
     * @param output gradient w.r.t. the convolution input, N x CHW
     * @throws DMLRuntimeException if a tensor exceeds cuDNN's size limit or a GPU call fails
     */
    public static void conv2dBackwardData(GPUContext gCtx, String instName, MatrixObject filter, MatrixObject dout, MatrixObject output, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying; 32-bit products can wrap and bypass the size check.
        long CHW = (long) C * H * W;
        long KPQ = (long) K * P * Q;
        long CRS = (long) C * R * S;
        long NCHW = N * CHW;
        long NKPQ = N * KPQ;
        long KCRS = K * CRS;

        boolean isSparseFilter = isInSparseFormat(gCtx, filter);
        if (getNnz(gCtx, instName, filter, false) == 0) {
            return;
        }
        boolean isSparseDout = isInSparseFormat(gCtx, dout);
        if (getNnz(gCtx, instName, dout, false) == 0) {
            return;
        }
        if (NCHW >= maxNumElementsOfCuDNNTensor || NKPQ >= maxNumElementsOfCuDNNTensor || KCRS >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(N, CHW, N, KPQ, K, CRS);
            return; // not reached: throwCuDNNDimensionError always throws
        }
        // Estimated dense size of whichever inputs are currently sparse.
        double overhead = (isSparseFilter ? OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0d) : 0.0d)
                + (isSparseDout ? OptimizerUtils.estimateSizeExactSparsity(N, KPQ, 1.0d) : 0.0d);
        Pointer wPointer = getDensePointerForCuDNN(gCtx, filter, instName);
        Pointer dx = getDensePointerForCuDNN(gCtx, output, instName);
        long workspaceLimit = (long) (intermediateMemoryBudget - overhead);
        int batchSize = overhead <= intermediateMemoryBudget ? N : 1;
        try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardDataAlgorithm(gCtx, instName, batchSize, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
            if (batchSize == N) {
                cudnnConv2dBackwardData(gCtx, instName, wPointer, getDensePointerForCuDNN(gCtx, dout, instName), dx, algo);
            } else {
                try (LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
                    for (int n = 0; n < N; n++) {
                        // (filter, dout row, dx row): same argument order as the full-batch call above.
                        cudnnConv2dBackwardData(gCtx, instName, wPointer, doutFetcher.getNthRow(n), dx.withByteOffset(n * CHW * sizeOfDataType), algo);
                    }
                }
            }
        }
    }

    /**
     * Thin wrapper around {@code cudnnConvolutionBackwardData} that uses the descriptors, algorithm
     * and workspace held by {@code algo}.
     *
     * <p>It calls cuDNN with alpha = 1 and beta = 0, records timing when fine-grained statistics
     * are enabled, and converts failures into {@link DMLRuntimeException}s. The wrapped-exception
     * message now names this operation; it previously said "conv2d".
     *
     * @param w  device pointer described by {@code algo.filterDesc}
     * @param dy device pointer described by {@code algo.nkpqTensorDesc}
     * @param dx device pointer described by {@code algo.nchwTensorDesc} (written)
     * @throws DMLRuntimeException on any cuDNN failure
     */
    private static void cudnnConv2dBackwardData(GPUContext gCtx, String instName, Pointer w, Pointer dy, Pointer dx, LibMatrixCuDNNConvolutionAlgorithm algo) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardData, GPUContext=" + gCtx);
        }
        try {
            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            int status = JCudnn.cudnnConvolutionBackwardData(getCudnnHandle(gCtx), one(),
                    algo.filterDesc, w,
                    algo.nkpqTensorDesc, dy,
                    algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes,
                    zero(), algo.nchwTensorDesc, dx);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CONVOLUTION_BACKWARD_DATA_LIB, System.nanoTime() - t1);
            }
            if (status != cudnnStatus.CUDNN_STATUS_SUCCESS) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardData: " + cudnnStatus.stringFor(status));
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2dBackwardData in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Performs forward pooling of {@code image} into the dense device buffer of {@code outputBlock}
     * using cuDNN.
     *
     * <p>If the estimated dense size of a sparse image fits in {@code intermediateMemoryBudget},
     * the whole batch is pooled at once. Otherwise each image row is pooled separately into the
     * matching row of the output.
     *
     * <p>Fixes:
     * <ul>
     *   <li>The CHW/CPQ products are now computed in 64-bit arithmetic, so they cannot wrap and
     *       bypass the cuDNN size check.</li>
     *   <li>The row fetcher is now closed even when pooling a row throws.</li>
     * </ul>
     *
     * @param image       input image, N x CHW
     * @param outputBlock pooled output, N x CPQ
     * @throws DMLRuntimeException if a tensor exceeds cuDNN's size limit or a GPU call fails
     */
    public static void pooling(GPUContext gCtx, String instName, MatrixObject image, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, LibMatrixDNN.PoolingType poolingType, double intermediateMemoryBudget) throws DMLRuntimeException {
        long CHW = (long) C * H * W;
        long CPQ = (long) C * P * Q;
        long NCHW = N * CHW;
        long NCPQ = N * CPQ;
        if (NCHW >= maxNumElementsOfCuDNNTensor || NCPQ >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(N, CHW, N, CPQ);
            return; // not reached: throwCuDNNDimensionError always throws
        }
        long overhead = isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0d) : 0L;
        Pointer y = getDensePointerForCuDNN(gCtx, outputBlock, instName);
        if (overhead <= intermediateMemoryBudget) {
            cudnnPoolingHelper(gCtx, instName, getDensePointerForCuDNN(gCtx, image, instName), y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
            return;
        }
        try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
            for (int n = 0; n < N; n++) {
                // Output row n starts n*CPQ elements into the output buffer.
                cudnnPoolingHelper(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(n * CPQ * sizeOfDataType), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
            }
        }
    }

    /**
     * Runs cudnnPoolingForward from the dense device buffer {@code pointer} into {@code pointer2}.
     *
     * <p>Pooling descriptors are built from {@code i..i13} and {@code poolingType}. {@code i} is the
     * number of rows in this call; callers pass 1 when pooling row by row. The descriptors are closed
     * on every exit path, including failures.
     *
     * @param gPUContext GPU context whose cuDNN handle is used
     * @param str instruction name, used for statistics
     * @param pointer dense input on the device
     * @param pointer2 dense output on the device
     * @param poolingType pooling variant to perform
     * @throws DMLRuntimeException if cuDNN returns a non-zero status or raises a {@link CudaException}
     */
    private static void cudnnPoolingHelper(GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, LibMatrixDNN.PoolingType poolingType) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : perform pooling, GPUContext=" + gPUContext);
        }
        // Start the init timer before the descriptors are built so MISC_TIMER_CUDNN_INIT measures their setup.
        long initStart = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        try (LibMatrixCuDNNPoolingDescriptors desc = LibMatrixCuDNNPoolingDescriptors.cudnnPoolingDescriptors(gPUContext, str, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType)) {
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - initStart);
            }
            long forwardStart = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            int status = JCudnn.cudnnPoolingForward(getCudnnHandle(gPUContext), desc.poolingDesc, one(), desc.xDesc, pointer, zero(), desc.yDesc, pointer2);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - forwardStart);
            }
            // 0 is the success status code.
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnPoolingForward: " + cudnnStatus.stringFor(status));
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in pooling in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Computes the pooling gradient dx into the dense device buffer of {@code matrixObject4}.
     *
     * <p>{@code matrixObject} is the pooling input x. {@code matrixObject2} is the gradient dy of the
     * pooling output. {@code matrixObject3} is the pooling output y; it may be {@code null}, in which
     * case the helper recomputes y. x and dx have {@code i} rows of {@code i2 * i3 * i4} columns; dy and
     * y have {@code i} rows of {@code i2 * i12 * i13} columns. When densifying sparse x and dy would
     * exceed the budget {@code d}, the inputs are densified and processed one row at a time.
     *
     * @param gPUContext GPU context to run on
     * @param str instruction name, used for statistics and temporary allocations
     * @param poolingType pooling variant to perform
     * @param d budget compared against {@code OptimizerUtils.estimateSizeExactSparsity} (presumably bytes)
     * @throws DMLRuntimeException if a tensor exceeds the cuDNN element limit or a cuDNN call fails
     */
    public static void poolingBackward(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, MatrixObject matrixObject4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, LibMatrixDNN.PoolingType poolingType, double d) throws DMLRuntimeException {
        // Widen to long before multiplying: a pure int product can overflow before the assignment.
        long inputCols = (long) i2 * i3 * i4;
        long outputCols = (long) i2 * i12 * i13;
        boolean hasY = matrixObject3 != null;
        if (i * inputCols >= maxNumElementsOfCuDNNTensor || i * outputCols >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(i, inputCols, i, outputCols);
            return;
        }
        long denseXSize = isInSparseFormat(gPUContext, matrixObject)
                ? OptimizerUtils.estimateSizeExactSparsity(i, inputCols, 1.0d)
                : 0L;
        long denseDySize = isInSparseFormat(gPUContext, matrixObject2)
                ? OptimizerUtils.estimateSizeExactSparsity(i, outputCols, 1.0d)
                : 0L;
        Pointer dxPointer = getDensePointerForCuDNN(gPUContext, matrixObject4, str);
        if (denseXSize + denseDySize <= d) {
            Pointer xPointer = getDensePointerForCuDNN(gPUContext, matrixObject, str);
            Pointer dyPointer = getDensePointerForCuDNN(gPUContext, matrixObject2, str);
            Pointer yPointer = hasY ? getDensePointerForCuDNN(gPUContext, matrixObject3, str) : null;
            cudnnPoolingBackwardHelper(gPUContext, str, xPointer, dyPointer, yPointer, dxPointer, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType);
            return;
        }
        // Densifying everything would exceed the budget: process one row (batch size 1) at a time.
        LibMatrixCuDNNInputRowFetcher xFetcher = null;
        LibMatrixCuDNNInputRowFetcher dyFetcher = null;
        LibMatrixCuDNNInputRowFetcher yFetcher = null;
        try {
            xFetcher = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject);
            dyFetcher = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject2);
            if (hasY) {
                yFetcher = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject3);
            }
            for (int n = 0; n < i; n++) {
                Pointer xRow = xFetcher.getNthRow(n);
                Pointer dyRow = dyFetcher.getNthRow(n);
                Pointer yRow = hasY ? yFetcher.getNthRow(n) : null;
                Pointer dxRow = dxPointer.withByteOffset(n * inputCols * sizeOfDataType);
                cudnnPoolingBackwardHelper(gPUContext, str, xRow, dyRow, yRow, dxRow, 1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType);
            }
        } finally {
            // Close every fetcher that was created, even if processing or an earlier close() throws.
            try {
                if (xFetcher != null) {
                    xFetcher.close();
                }
            } finally {
                try {
                    if (dyFetcher != null) {
                        dyFetcher.close();
                    }
                } finally {
                    if (yFetcher != null) {
                        yFetcher.close();
                    }
                }
            }
        }
    }

    /* JADX WARN: Failed to calculate best type for var: r38v1 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Failed to calculate best type for var: r39v0 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 38, insn: 0x019f: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r38 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:79:0x019f */
    /* JADX WARN: Not initialized variable reg: 39, insn: 0x01a4: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r39 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:81:0x01a4 */
    /* JADX WARN: Type inference failed for: r38v1, types: [org.apache.sysml.runtime.matrix.data.LibMatrixCuDNNPoolingDescriptors] */
    /* JADX WARN: Type inference failed for: r39v0, types: [java.lang.Throwable] */
    private static void cudnnPoolingBackwardHelper(GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, Pointer pointer4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, LibMatrixDNN.PoolingType poolingType) throws DMLRuntimeException {
        ?? r38;
        ?? r39;
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : maxpoolingBackward, GPUContext=" + gPUContext);
        }
        boolean z = pointer3 != null;
        try {
            try {
                try {
                    LibMatrixCuDNNPoolingDescriptors cudnnPoolingBackwardDescriptors = LibMatrixCuDNNPoolingDescriptors.cudnnPoolingBackwardDescriptors(gPUContext, str, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType);
                    Throwable th = null;
                    long j = 0;
                    long j2 = 0;
                    long j3 = 0;
                    if (!z) {
                        if (DMLScript.FINEGRAINED_STATISTICS) {
                            j = System.nanoTime();
                        }
                        pointer3 = gPUContext.allocate(i * i2 * i12 * i13 * sizeOfDataType);
                        if (DMLScript.FINEGRAINED_STATISTICS) {
                            GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - j);
                        }
                        if (DMLScript.FINEGRAINED_STATISTICS) {
                            j2 = System.nanoTime();
                        }
                        int cudnnPoolingForward = JCudnn.cudnnPoolingForward(getCudnnHandle(gPUContext), cudnnPoolingBackwardDescriptors.poolingDesc, one(), cudnnPoolingBackwardDescriptors.xDesc, pointer, zero(), cudnnPoolingBackwardDescriptors.yDesc, pointer3);
                        if (DMLScript.FINEGRAINED_STATISTICS) {
                            GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - j2);
                        }
                        if (cudnnPoolingForward != 0) {
                            throw new DMLRuntimeException("Could not executed cudnnPoolingForward before cudnnPoolingBackward: " + cudnnStatus.stringFor(cudnnPoolingForward));
                        }
                    }
                    if (DMLScript.FINEGRAINED_STATISTICS) {
                        j3 = System.nanoTime();
                    }
                    int cudnnPoolingBackward = JCudnn.cudnnPoolingBackward(getCudnnHandle(gPUContext), cudnnPoolingBackwardDescriptors.poolingDesc, one(), cudnnPoolingBackwardDescriptors.yDesc, pointer3, cudnnPoolingBackwardDescriptors.dyDesc, pointer2, cudnnPoolingBackwardDescriptors.xDesc, pointer, zero(), cudnnPoolingBackwardDescriptors.dxDesc, pointer4);
                    if (DMLScript.FINEGRAINED_STATISTICS) {
                        GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_MAXPOOLING_BACKWARD_LIB, System.nanoTime() - j3);
                    }
                    if (cudnnPoolingBackward != 0) {
                        throw new DMLRuntimeException("Could not executed cudnnPoolingBackward: " + cudnnStatus.stringFor(cudnnPoolingBackward));
                    }
                    if (cudnnPoolingBackwardDescriptors != null) {
                        if (0 != 0) {
                            try {
                                cudnnPoolingBackwardDescriptors.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            cudnnPoolingBackwardDescriptors.close();
                        }
                    }
                    long j4 = 0;
                    if (DMLScript.FINEGRAINED_STATISTICS) {
                        j4 = System.nanoTime();
                    }
                    if (!z) {
                        gPUContext.cudaFreeHelper(str, pointer3);
                    }
                    if (DMLScript.FINEGRAINED_STATISTICS) {
                        GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - j4);
                    }
                } catch (Throwable th3) {
                    long j5 = 0;
                    if (DMLScript.FINEGRAINED_STATISTICS) {
                        j5 = System.nanoTime();
                    }
                    if (!z) {
                        gPUContext.cudaFreeHelper(str, pointer3);
                    }
                    if (DMLScript.FINEGRAINED_STATISTICS) {
                        GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - j5);
                    }
                    throw th3;
                }
            } catch (Throwable th4) {
                if (r38 != 0) {
                    if (r39 != 0) {
                        try {
                            r38.close();
                        } catch (Throwable th5) {
                            r39.addSuppressed(th5);
                        }
                    } else {
                        r38.close();
                    }
                }
                throw th4;
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Applies ReLU with cudnnActivationForward from the dense device buffer of {@code matrixObject}
     * into {@code pointer}.
     *
     * <p>The same tensor descriptor is used for both input and output. The activation descriptor
     * created here is destroyed on every exit path.
     *
     * @param gPUContext GPU context whose cuDNN handle is used
     * @param str instruction name, used for statistics
     * @param matrixObject input matrix
     * @param pointer dense output buffer on the device
     * @param cudnntensordescriptor tensor descriptor describing both input and output
     * @throws DMLRuntimeException if the input exceeds the cuDNN element limit or cuDNN raises a {@link CudaException}
     */
    private static void cudnnReLU(GPUContext gPUContext, String str, MatrixObject matrixObject, Pointer pointer, cudnnTensorDescriptor cudnntensordescriptor) throws DMLRuntimeException {
        cudnnActivationDescriptor activationDesc = null;
        try {
            if (LOG.isTraceEnabled()) {
                LOG.trace("GPU : performCuDNNReLU, GPUContext=" + gPUContext);
            }
            Pointer srcPointer = getDensePointerForCuDNN(gPUContext, matrixObject, str);
            cudnnActivationDescriptor created = new cudnnActivationDescriptor();
            JCudnn.cudnnCreateActivationDescriptor(created);
            // Remember the descriptor only after creation succeeded, so cleanup never destroys an uncreated one.
            activationDesc = created;
            // Presumably mode 1 = CUDNN_ACTIVATION_RELU and nan-propagation 1 = CUDNN_PROPAGATE_NAN;
            // the -1.0 coefficient is assumed unused for plain ReLU (TODO confirm against cuDNN enums).
            JCudnn.cudnnSetActivationDescriptor(activationDesc, 1, 1, -1.0d);
            long start = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            JCudnn.cudnnActivationForward(getCudnnHandle(gPUContext), activationDesc, one(), cudnntensordescriptor, srcPointer, zero(), cudnntensordescriptor, pointer);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_ACTIVATION_FORWARD_LIB, System.nanoTime() - start);
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in relu in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        } finally {
            long cleanupStart = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            // The activation descriptor was previously never destroyed, leaking one per call.
            if (activationDesc != null) {
                JCudnn.cudnnDestroyActivationDescriptor(activationDesc);
            }
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - cleanupStart);
            }
        }
    }

    /**
     * Computes ReLU of {@code matrixObject} into the dense output variable named {@code str2}.
     *
     * <p>If rows x columns is below {@code maxNumElementsOfCuDNNTensor}, cuDNN is used. Otherwise the
     * custom "relu" kernel is launched. That path deliberately uses {@code getDensePointer}, because
     * {@code getDensePointerForCuDNN} rejects tensors larger than the cuDNN limit, and this branch
     * exists precisely for those tensors.
     *
     * @param executionContext execution context that owns the output variable
     * @param gPUContext GPU context; must be the one registered with {@code executionContext}
     * @param str instruction name, used for statistics and allocations
     * @param matrixObject input matrix
     * @param str2 name of the output variable
     * @throws DMLRuntimeException on an inconsistent GPU context or a failing GPU call
     */
    public static void relu(ExecutionContext executionContext, GPUContext gPUContext, String str, MatrixObject matrixObject, String str2) throws DMLRuntimeException {
        if (executionContext.getGPUContext(0) != gPUContext) {
            throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
        }
        long numRows = matrixObject.getNumRows();
        long numColumns = matrixObject.getNumColumns();
        MatrixObject output = executionContext.getMatrixObject(str2);
        getDenseMatrixOutputForGPUInstruction(executionContext, str, str2, numRows, numColumns);
        if (numRows * numColumns < maxNumElementsOfCuDNNTensor) {
            cudnnTensorDescriptor tensorDesc = new cudnnTensorDescriptor();
            JCudnn.cudnnCreateTensorDescriptor(tensorDesc);
            try {
                // Describe the matrix as an (rows, 1, 1, cols) tensor.
                JCudnn.cudnnSetTensor4dDescriptor(tensorDesc, 0, CUDNN_DATA_TYPE, toInt(numRows), 1, 1, toInt(numColumns));
                cudnnReLU(gPUContext, str, matrixObject, getDensePointerForCuDNN(gPUContext, output, str), tensorDesc);
            } finally {
                // Destroy the descriptor even when the ReLU call fails.
                JCudnn.cudnnDestroyTensorDescriptor(tensorDesc);
            }
            return;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : relu custom kernel, GPUContext=" + gPUContext);
        }
        long start = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        // NOTE(review): the launch config takes int rows/cols; confirm ExecutionConfig handles
        // rows*cols beyond Integer.MAX_VALUE.
        getCudaKernels(gPUContext).launchKernel("relu", ExecutionConfig.getConfigForSimpleMatrixOperations(toInt(numRows), toInt(numColumns)), getDensePointer(gPUContext, matrixObject, str), getDensePointer(gPUContext, output, str), Integer.valueOf(toInt(numRows)), Integer.valueOf(toInt(numColumns)));
        if (DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_RELU_KERNEL, System.nanoTime() - start);
        }
    }

    /**
     * Returns the dense device pointer of {@code matrixObject} after checking that it fits in a cuDNN
     * tensor, i.e. that rows x columns does not exceed {@code maxNumElementsOfCuDNNTensor}.
     *
     * <p>Note: the decompiler reported that this method's access modifier was changed from
     * {@code protected}.
     *
     * @param gPUContext GPU context that owns the data
     * @param matrixObject matrix whose dense device pointer is requested
     * @param str instruction name, forwarded to {@code getDensePointer}
     * @return the pointer returned by {@code getDensePointer}
     * @throws DMLRuntimeException if the matrix has too many elements for cuDNN, or if
     *         {@code getDensePointer} fails
     */
    public static Pointer getDensePointerForCuDNN(GPUContext gPUContext, MatrixObject matrixObject, String str) throws DMLRuntimeException {
        final long elementCount = matrixObject.getNumRows() * matrixObject.getNumColumns();
        if (elementCount <= maxNumElementsOfCuDNNTensor) {
            return getDensePointer(gPUContext, matrixObject, str);
        }
        throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot have greater than 2 giga-elements, but has "
                + elementCount + " (i.e. [" + matrixObject.getNumRows() + " X " + matrixObject.getNumColumns()
                + "]). Hint: try reducing the mini-batch size.");
    }

    /**
     * Validates a status code returned by a cuDNN call.
     *
     * @param i status code; 0 means success
     * @throws DMLRuntimeException if {@code i} is non-zero; the message includes the cuDNN status name
     */
    protected static void checkStatus(int i) throws DMLRuntimeException {
        if (i == 0) {
            return;
        }
        String statusName = cudnnStatus.stringFor(i);
        throw new DMLRuntimeException("Error status returned by CuDNN:" + statusName);
    }
}
