package org.apache.sysml.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnTensorDescriptor;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.class */
/**
 * Holds the cuDNN descriptors, the selected algorithm id and the (optional) device workspace
 * needed to run one convolution operation.
 *
 * <p>Instances are created through the static {@code cudnnGetConvolution*Algorithm} factories
 * and must be closed by the caller (typically with try-with-resources) so that the descriptors
 * and the workspace are released.
 *
 * <p>The integer parameters shared by the constructor and the factories are, judging by the
 * names of the descriptor fields they populate (NOTE(review): confirm against callers):
 * N (batch size), C (input channels), H/W (input height/width), K (number of filters),
 * R/S (filter height/width), pad_h/pad_w, stride_h/stride_w and P/Q (output height/width).
 */
public class LibMatrixCuDNNConvolutionAlgorithm implements AutoCloseable {
    /** Upper bound applied to the caller-supplied workspace limit before asking cuDNN for an algorithm. */
    static long MAX_WORKSPACE_LIMIT_BYTES = 1000000000;

    /** Tensor format passed to cudnnSetTensor4dDescriptor / cudnnSetFilter4dDescriptor (CUDNN_TENSOR_NCHW in the cuDNN headers). */
    private static final int TENSOR_FORMAT_NCHW = 0;
    /** Convolution mode passed to cudnnSetConvolution2dDescriptor (CUDNN_CROSS_CORRELATION in the cuDNN headers). */
    private static final int CONVOLUTION_MODE_CROSS_CORRELATION = 1;
    /** Algorithm preference "specify workspace limit"; the value is 2 for the forward, backward-filter and backward-data enums alike. */
    private static final int PREFERENCE_SPECIFY_WORKSPACE_LIMIT = 2;
    /** Backward-data algorithm id used when cuDNN is not queried (CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 in the cuDNN headers). */
    private static final int BWD_DATA_ALGO_0 = 0;

    /** Algorithm id selected for the operation; -1 until a factory has filled it in. */
    public int algo = -1;
    /** Device workspace for the selected algorithm; only allocated when {@link #sizeInBytes} is non-zero. */
    public Pointer workSpace = new Pointer();
    /** Size of {@link #workSpace} in bytes; 0 means no workspace is owned by this object. */
    public long sizeInBytes = 0;
    cudnnTensorDescriptor nchwTensorDesc;
    cudnnTensorDescriptor nkpqTensorDesc;
    cudnnFilterDescriptor filterDesc;
    cudnnConvolutionDescriptor convDesc;
    GPUContext gCtx;
    String instName;

    /** Set by the first {@link #close()} so that repeated calls are no-ops. */
    private boolean closed = false;

    /**
     * Creates the four cuDNN descriptors. If any creation fails, the descriptors created so far
     * are destroyed before the failure is propagated, so nothing leaks from a half-built object.
     */
    private LibMatrixCuDNNConvolutionAlgorithm(GPUContext gCtx, String instName, int n, int c, int h, int w, int k, int r, int s, int padH, int padW, int strideH, int strideW, int p, int q) throws DMLRuntimeException {
        this.gCtx = gCtx;
        this.instName = instName;
        try {
            this.convDesc = allocateConvolutionDescriptor(new int[]{padH, padW}, new int[]{strideH, strideW});
            this.nchwTensorDesc = allocateTensorDescriptor(n, c, h, w);
            this.nkpqTensorDesc = allocateTensorDescriptor(n, k, p, q);
            this.filterDesc = allocateFilterDescriptor(k, c, r, s);
        } catch (Throwable failure) {
            // Cleanup-and-rethrow only: the original failure is always what the caller sees.
            try {
                destroyDescriptors();
            } catch (RuntimeException cleanupFailure) {
                failure.addSuppressed(cleanupFailure);
            }
            throw failure;
        }
    }

    /**
     * Destroys the descriptors and frees the workspace (if one was allocated). Safe to call
     * more than once; only the first call has an effect.
     *
     * @throws RuntimeException wrapping the {@link DMLRuntimeException} if freeing the workspace fails
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        if (this.closed) {
            return;
        }
        this.closed = true;
        long start = 0;
        if (DMLScript.FINEGRAINED_STATISTICS) {
            start = System.nanoTime();
        }
        destroyDescriptors();
        if (this.sizeInBytes != 0) {
            // Detach the workspace before freeing so this object never keeps a reference to freed memory.
            Pointer toFree = this.workSpace;
            this.workSpace = new Pointer();
            this.sizeInBytes = 0;
            try {
                this.gCtx.cudaFreeHelper(this.instName, toFree);
            } catch (DMLRuntimeException e) {
                throw new RuntimeException(e);
            }
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(this.instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - start);
        }
    }

    /**
     * Selects the forward-convolution algorithm under a workspace limit and allocates its workspace.
     *
     * @param gPUContext GPU context used for the cuDNN handle and for workspace allocation
     * @param str instruction name used for statistics
     * @param i N, @param i2 C, @param i3 H, @param i4 W (input tensor dimensions)
     * @param i5 K, @param i6 R, @param i7 S (filter count and filter dimensions)
     * @param i8 pad_h, @param i9 pad_w, @param i10 stride_h, @param i11 stride_w
     * @param i12 P, @param i13 Q (output tensor dimensions)
     * @param j workspace limit in bytes requested by the caller; capped at {@link #MAX_WORKSPACE_LIMIT_BYTES}
     * @return a holder that the caller must close
     * @throws DMLRuntimeException if descriptor creation, the cuDNN handle lookup or workspace allocation fails
     */
    public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionForwardAlgorithm(GPUContext gPUContext, String str, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, long j) throws DMLRuntimeException {
        long start = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gPUContext, str, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13);
        try {
            int[] selectedAlgo = {-1};
            long workspaceLimit = Math.min(j, MAX_WORKSPACE_LIMIT_BYTES);
            // Starts out as the limit and is then passed as the output array of the workspace-size query.
            long[] workspaceBytes = {workspaceLimit};
            JCudnn.cudnnGetConvolutionForwardAlgorithm(LibMatrixCuDNN.getCudnnHandle(gPUContext), ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc, PREFERENCE_SPECIFY_WORKSPACE_LIMIT, workspaceLimit, selectedAlgo);
            JCudnn.cudnnGetConvolutionForwardWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gPUContext), ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc, selectedAlgo[0], workspaceBytes);
            ret.adoptSelection(gPUContext, selectedAlgo[0], workspaceBytes[0]);
        } catch (Throwable failure) {
            // The caller never receives 'ret' on failure, so it must be released here.
            closeAfterFailure(ret, failure);
            throw failure;
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - start);
        }
        return ret;
    }

    /**
     * Selects the backward-filter algorithm under a workspace limit and allocates its workspace.
     * Parameters, return value and exceptions are as for
     * {@link #cudnnGetConvolutionForwardAlgorithm}.
     */
    public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionBackwardFilterAlgorithm(GPUContext gPUContext, String str, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, long j) throws DMLRuntimeException {
        long start = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gPUContext, str, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13);
        try {
            int[] selectedAlgo = {-1};
            long workspaceLimit = Math.min(j, MAX_WORKSPACE_LIMIT_BYTES);
            long[] workspaceBytes = {workspaceLimit};
            JCudnn.cudnnGetConvolutionBackwardFilterAlgorithm(LibMatrixCuDNN.getCudnnHandle(gPUContext), ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc, PREFERENCE_SPECIFY_WORKSPACE_LIMIT, workspaceLimit, selectedAlgo);
            JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gPUContext), ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc, selectedAlgo[0], workspaceBytes);
            ret.adoptSelection(gPUContext, selectedAlgo[0], workspaceBytes[0]);
        } catch (Throwable failure) {
            closeAfterFailure(ret, failure);
            throw failure;
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - start);
        }
        return ret;
    }

    /**
     * Selects the backward-data algorithm and allocates its workspace. When the input height
     * equals the filter height or the input width equals the filter width, cuDNN is not queried:
     * algorithm 0 is used and no workspace is allocated.
     * Parameters, return value and exceptions are otherwise as for
     * {@link #cudnnGetConvolutionForwardAlgorithm}.
     */
    public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionBackwardDataAlgorithm(GPUContext gPUContext, String str, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, long j) throws DMLRuntimeException {
        LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gPUContext, str, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13);
        // i3/i4 are the input height/width, i6/i7 the filter height/width.
        // NOTE(review): presumably a deliberate preference for the workspace-free algorithm when
        // the filter spans a whole input dimension -- confirm the rationale with the upstream project.
        if (i3 == i6 || i4 == i7) {
            ret.algo = BWD_DATA_ALGO_0;
            return ret;
        }
        long start = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        try {
            int[] selectedAlgo = {-1};
            long workspaceLimit = Math.min(j, MAX_WORKSPACE_LIMIT_BYTES);
            long[] workspaceBytes = {workspaceLimit};
            JCudnn.cudnnGetConvolutionBackwardDataAlgorithm(LibMatrixCuDNN.getCudnnHandle(gPUContext), ret.filterDesc, ret.nkpqTensorDesc, ret.convDesc, ret.nchwTensorDesc, PREFERENCE_SPECIFY_WORKSPACE_LIMIT, workspaceLimit, selectedAlgo);
            JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gPUContext), ret.filterDesc, ret.nkpqTensorDesc, ret.convDesc, ret.nchwTensorDesc, selectedAlgo[0], workspaceBytes);
            ret.adoptSelection(gPUContext, selectedAlgo[0], workspaceBytes[0]);
        } catch (Throwable failure) {
            closeAfterFailure(ret, failure);
            throw failure;
        }
        if (DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - start);
        }
        return ret;
    }

    /**
     * Records the selected algorithm and allocates the workspace it needs (nothing is allocated
     * for a zero-byte workspace). {@link #sizeInBytes} is only set after a successful allocation,
     * so {@link #close()} never tries to free a workspace that was not obtained.
     */
    private void adoptSelection(GPUContext context, int selectedAlgo, long workspaceBytes) throws DMLRuntimeException {
        if (workspaceBytes != 0) {
            this.workSpace = context.allocate(workspaceBytes);
        }
        this.sizeInBytes = workspaceBytes;
        this.algo = selectedAlgo;
    }

    /** Closes a partially initialised holder, attaching any cleanup error to the original failure instead of masking it. */
    private static void closeAfterFailure(LibMatrixCuDNNConvolutionAlgorithm partial, Throwable failure) {
        try {
            partial.close();
        } catch (RuntimeException cleanupFailure) {
            failure.addSuppressed(cleanupFailure);
        }
    }

    /** Destroys every descriptor that exists and clears the field so it cannot be destroyed twice. */
    private void destroyDescriptors() {
        if (this.nchwTensorDesc != null) {
            JCudnn.cudnnDestroyTensorDescriptor(this.nchwTensorDesc);
            this.nchwTensorDesc = null;
        }
        if (this.nkpqTensorDesc != null) {
            JCudnn.cudnnDestroyTensorDescriptor(this.nkpqTensorDesc);
            this.nkpqTensorDesc = null;
        }
        if (this.filterDesc != null) {
            JCudnn.cudnnDestroyFilterDescriptor(this.filterDesc);
            this.filterDesc = null;
        }
        if (this.convDesc != null) {
            JCudnn.cudnnDestroyConvolutionDescriptor(this.convDesc);
            this.convDesc = null;
        }
    }

    /** Creates a 4-d tensor descriptor with the four given dimensions, in the order given. */
    private static cudnnTensorDescriptor allocateTensorDescriptor(int n, int c, int h, int w) throws DMLRuntimeException {
        cudnnTensorDescriptor tensorDesc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(tensorDesc);
        JCudnn.cudnnSetTensor4dDescriptor(tensorDesc, TENSOR_FORMAT_NCHW, LibMatrixCUDA.CUDNN_DATA_TYPE, n, c, h, w);
        return tensorDesc;
    }

    /** Creates a 4-d filter descriptor with the four given dimensions, in the order given. */
    private static cudnnFilterDescriptor allocateFilterDescriptor(int k, int c, int r, int s) {
        cudnnFilterDescriptor filterDescriptor = new cudnnFilterDescriptor();
        JCudnn.cudnnCreateFilterDescriptor(filterDescriptor);
        JCudnn.cudnnSetFilter4dDescriptor(filterDescriptor, LibMatrixCUDA.CUDNN_DATA_TYPE, TENSOR_FORMAT_NCHW, k, c, r, s);
        return filterDescriptor;
    }

    /**
     * Creates a 2-d convolution descriptor.
     *
     * @param padding two-element array, passed as the first pair of geometry arguments (height, width padding per the cuDNN API)
     * @param strides two-element array, passed as the second pair of geometry arguments (vertical, horizontal stride per the cuDNN API)
     */
    private static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int[] padding, int[] strides) {
        cudnnConvolutionDescriptor convolutionDesc = new cudnnConvolutionDescriptor();
        JCudnn.cudnnCreateConvolutionDescriptor(convolutionDesc);
        // The two literal 1s are the arguments between the strides and the mode; they are left at 1 as in the original call.
        JCudnn.cudnnSetConvolution2dDescriptor(convolutionDesc, padding[0], padding[1], strides[0], strides[1], 1, 1, CONVOLUTION_MODE_CROSS_CORRELATION);
        return convolutionDesc;
    }
}
