package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcudnn.JCudnn;
import jcuda.jcusparse.JCusparse;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.class */
/**
 * Process-wide registry of {@link GPUContext} instances, one per GPU selected through
 * {@link #AVAILABLE_GPUS}.
 *
 * <p>CUDA is initialized lazily on first use. All contexts are handed out together via
 * {@link #reserveAllGPUContexts()} and returned via {@link #freeAllGPUContexts()}.
 * Every method that reads or initializes the shared state is {@code synchronized} on this class.
 */
public class GPUContextPool {
    /**
     * Specification of the GPUs to use, in the format accepted by
     * {@link #parseListString(String, int)}: {@code "-1"} (all), {@code "a-b"} (inclusive range),
     * {@code "a,b,c"} (explicit list) or a single index. Presumably populated from the
     * {@code systemml.gpu.availableGPUs} setting (the name used in the fallback warning).
     * A {@code null} or invalid value makes {@link #initializeGPU()} fall back to all GPUs.
     */
    public static String AVAILABLE_GPUS;

    /** Device properties indexed by device number; entries for GPUs that are not in use stay {@code null}. */
    static cudaDeviceProp[] deviceProperties;

    protected static final Log LOG = LogFactory.getLog(GPUContextPool.class.getName());

    /** Separator between the two bounds of a range specification such as {@code "0-3"}. */
    private static final String RANGE_SEPARATOR = "-";

    /** Smallest available memory over all pooled contexts, measured during initialization; -1 before. */
    private static long INITIAL_GPU_MEMORY_BUDGET = -1;

    /** True once initialization has started (and, after it returns normally, completed). */
    static boolean initialized = false;

    /** Number of CUDA devices reported by the driver; -1 before initialization. */
    static int deviceCount = -1;

    /** One context per GPU in use. */
    static List<GPUContext> pool = new LinkedList<>();

    /** True while the pool is handed out through {@link #reserveAllGPUContexts()}. */
    static boolean reserved = false;

    /**
     * Initializes CUDA, discovers the devices, creates one {@link GPUContext} per GPU selected by
     * {@link #AVAILABLE_GPUS} and records the initial memory budget (the minimum available memory
     * over the created contexts, or {@code Long.MAX_VALUE} if no context was created).
     *
     * <p>If initialization fails part-way, the {@code initialized} flag is cleared and the
     * contexts added by this attempt are removed from the pool, so that the next caller retries
     * instead of observing a half-built pool.
     *
     * @throws DMLRuntimeException if the GPU contexts cannot be set up
     */
    public static synchronized void initializeGPU() throws DMLRuntimeException {
        final int poolSizeBefore = pool.size();
        // Must be set before any work is done: this method calls initialGPUMemBudget() below,
        // which would otherwise re-enter initializeGPU() and recurse without end.
        // NOTE(review): GPUContext may call back into this class during construction as well - confirm.
        initialized = true;
        boolean succeeded = false;
        try {
            GPUContext.LOG.info("Initializing CUDA");
            long startNanos = System.nanoTime();
            JCuda.setExceptionsEnabled(true);
            JCudnn.setExceptionsEnabled(true);
            JCublas2.setExceptionsEnabled(true);
            JCusparse.setExceptionsEnabled(true);
            JCudaDriver.setExceptionsEnabled(true);
            JCudaDriver.cuInit(0);
            int[] countHolder = {0};
            JCudaDriver.cuDeviceGetCount(countHolder);
            deviceCount = countHolder[0];
            deviceProperties = new cudaDeviceProp[deviceCount];

            // Only the parsing step may trigger the fallback; keeping device registration
            // outside this try avoids adding every device on top of a partially filled pool.
            List<Integer> deviceIds;
            try {
                deviceIds = parseListString(AVAILABLE_GPUS, deviceCount);
            } catch (IllegalArgumentException e) {
                LOG.warn("Invalid setting for setting systemml.gpu.availableGPUs, defaulting to use ALL GPUs");
                deviceIds = parseListString("-1", deviceCount);
            }
            for (int deviceId : deviceIds) {
                cudaDeviceProp properties = new cudaDeviceProp();
                JCuda.cudaGetDeviceProperties(properties, deviceId);
                deviceProperties[deviceId] = properties;
                pool.add(new GPUContext(deviceId));
            }

            // The budget is the minimum over all contexts, i.e. the most constrained GPU.
            long minAvailable = Long.MAX_VALUE;
            for (GPUContext context : pool) {
                context.initializeThread();
                minAvailable = Math.min(minAvailable, context.getAvailableMemory());
            }
            INITIAL_GPU_MEMORY_BUDGET = minAvailable;
            GPUContext.LOG.info("Total number of GPUs on the machine: " + deviceCount);
            GPUContext.LOG.info("GPUs being used: " + AVAILABLE_GPUS);
            GPUContext.LOG.info("Initial GPU memory: " + initialGPUMemBudget());
            GPUStatistics.cudaInitTime = System.nanoTime() - startNanos;
            succeeded = true;
        } finally {
            if (!succeeded) {
                // Roll back so a later call retries rather than using incomplete state.
                initialized = false;
                if (pool.size() > poolSizeBefore) {
                    pool.subList(poolSizeBefore, pool.size()).clear();
                }
            }
        }
    }

    /**
     * Parses a device specification into a list of device indices.
     *
     * <p>Accepted forms (surrounding whitespace is ignored):
     * <ul>
     *   <li>{@code "-1"} - every index in {@code [0, i)}</li>
     *   <li>{@code "a-b"} - the inclusive range {@code a..b}</li>
     *   <li>{@code "a,b,c"} - the listed indices, in the given order</li>
     *   <li>{@code "a"} - a single index</li>
     * </ul>
     *
     * @param str the specification to parse
     * @param i   exclusive upper bound for valid indices (the number of devices)
     * @return the parsed indices
     * @throws IllegalArgumentException if {@code str} is {@code null}, malformed (this includes
     *         {@link NumberFormatException}, which is a subclass), describes a reversed range, or
     *         contains an index outside {@code [0, i)}
     */
    public static ArrayList<Integer> parseListString(String str, int i) {
        if (str == null) {
            throw new IllegalArgumentException("Invalid string to parse to a list of numbers : null");
        }
        ArrayList<Integer> indices = new ArrayList<>();
        String spec = str.trim();
        if (spec.equalsIgnoreCase("-1")) {
            for (int index = 0; index < i; index++) {
                indices.add(index);
            }
        } else if (spec.contains(RANGE_SEPARATOR)) {
            String[] bounds = spec.split(RANGE_SEPARATOR);
            if (bounds.length != 2) {
                throw new IllegalArgumentException("Invalid string to parse to a list of numbers : " + spec);
            }
            // Trim each bound so that "0 - 2" is accepted, matching the comma-separated form.
            int low = Integer.parseInt(bounds[0].trim());
            int high = Integer.parseInt(bounds[1].trim());
            if (low > high) {
                // A reversed range would otherwise silently select no device at all.
                throw new IllegalArgumentException("Invalid string to parse to a list of numbers : " + spec);
            }
            for (int index = low; index <= high; index++) {
                indices.add(index);
            }
        } else if (spec.contains(",")) {
            for (String part : spec.split(",")) {
                indices.add(Integer.parseInt(part.trim()));
            }
        } else {
            indices.add(Integer.parseInt(spec));
        }
        for (int index : indices) {
            if (index < 0 || index >= i) {
                throw new IllegalArgumentException("Invalid string (" + spec + ") parsed to a list of numbers ("
                        + indices + ") which exceeds the maximum range : [0, " + i + ")");
            }
        }
        return indices;
    }

    /**
     * Reserves every pooled context for the caller, initializing the GPUs first if necessary.
     *
     * @return the pool's own list (not a copy)
     * @throws DMLRuntimeException if the contexts are already reserved or initialization fails
     */
    public static synchronized List<GPUContext> reserveAllGPUContexts() throws DMLRuntimeException {
        if (reserved) {
            throw new DMLRuntimeException("Trying to re-reserve GPUs");
        }
        if (!initialized) {
            initializeGPU();
        }
        reserved = true;
        LOG.trace("GPU : Reserved all GPUs");
        return pool;
    }

    /**
     * Returns the number of contexts currently in the pool. This does not trigger
     * initialization, so it is 0 until the pool has been populated.
     *
     * @return the size of the pool
     */
    public static synchronized int getAvailableCount() {
        return pool.size();
    }

    /**
     * Returns the properties recorded for the given device, initializing the GPUs first if
     * necessary.
     *
     * @param i device number
     * @return the device's properties, or {@code null} if that device is not among the GPUs in use
     * @throws DMLRuntimeException if initialization fails
     */
    public static synchronized cudaDeviceProp getGPUProperties(int i) throws DMLRuntimeException {
        if (!initialized) {
            initializeGPU();
        }
        return deviceProperties[i];
    }

    /**
     * Returns the number of CUDA devices reported by the driver, initializing the GPUs first if
     * necessary.
     *
     * @return the device count
     * @throws DMLRuntimeException if initialization fails
     */
    public static synchronized int getDeviceCount() throws DMLRuntimeException {
        if (!initialized) {
            initializeGPU();
        }
        return deviceCount;
    }

    /**
     * Marks the pool as no longer reserved.
     *
     * @throws DMLRuntimeException if the pool is not currently reserved
     */
    public static synchronized void freeAllGPUContexts() throws DMLRuntimeException {
        if (!reserved) {
            throw new DMLRuntimeException("Trying to free unreserved GPUs");
        }
        reserved = false;
        LOG.trace("GPU : Unreserved all GPUs");
    }

    /**
     * Returns the memory budget measured during initialization, initializing the GPUs first if
     * necessary.
     *
     * @return the initial GPU memory budget
     * @throws RuntimeException wrapping the {@link DMLRuntimeException} if initialization fails
     */
    public static synchronized long initialGPUMemBudget() throws RuntimeException {
        try {
            if (!initialized) {
                initializeGPU();
            }
            return INITIAL_GPU_MEMORY_BUDGET;
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }
}
