package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;
import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.utils.GPUStatistics;

/**
 * Tracks and hands out device memory for one GPU context.
 *
 * <p>Every pointer obtained through {@link #malloc} is recorded in {@code allocatedGPUPointers} together
 * with its size in bytes. Pointers released lazily via {@link #free} with {@code eager == false} stay
 * allocated on the device and are parked in {@code rmvarGPUPointers}, grouped by size, so a later request
 * of the same size can reuse them without a cudaMalloc/cudaFree round trip. When a request cannot be
 * satisfied directly, {@link #malloc} progressively frees parked pointers and finally evicts unlocked
 * {@link GPUObject}s.
 *
 * <p>NOTE(review): no state in this class is synchronized; presumably each instance is confined to the
 * thread driving its GPU context — confirm with callers.
 */
public class GPUMemoryManager {
    protected static final Log LOG = LogFactory.getLog(GPUMemoryManager.class.getName());

    /** The constructor warns if less than this fraction of the device's total memory is currently free. */
    private static final double WARN_UTILIZATION_FACTOR = 0.7d;

    /** Fraction of the currently free device memory that {@link #getAvailableMemory()} reports as usable. */
    public double GPU_MEMORY_UTILIZATION_FACTOR = ConfigurationManager.getDMLConfig().getDoubleValue(DMLConfig.GPU_MEMORY_UTILIZATION_FACTOR);

    /** Lazily freed ("rmvar-ed") pointers that are still allocated on the device, keyed by size in bytes. */
    private final HashMap<Long, Set<Pointer>> rmvarGPUPointers = new HashMap<>();

    /** GPU objects registered through {@link #addGPUObject}; these are the eviction candidates. */
    private final ArrayList<GPUObject> allocatedGPUObjects = new ArrayList<>();

    /** Every device pointer allocated here and not yet cudaFree-d, mapped to its size in bytes. */
    private final HashMap<Pointer, Long> allocatedGPUPointers = new HashMap<>();

    /**
     * Orders GPU objects so that the preferred eviction candidate ends up LAST after sorting.
     *
     * <p>Locked objects always sort before unlocked ones (two locked objects compare equal), so the
     * eviction loop, which walks from the end of the list, stops as soon as it reaches a locked object.
     * Among unlocked objects:
     * <ul>
     *   <li>with {@code EvictionPolicy.MIN_EVICT}: objects smaller than {@code neededSize} come first in
     *       ascending size order, followed by objects of at least {@code neededSize} in descending size
     *       order, so the last element is the smallest object that alone covers the request;</li>
     *   <li>with any other policy: objects are ordered by descending {@code timestamp}, so the last
     *       element has the smallest timestamp value (presumably the least recently used one).</li>
     * </ul>
     */
    public static class GPUComparator implements Comparator<GPUObject> {
        /** Size in bytes of the allocation request that eviction is trying to satisfy. */
        private final long neededSize;

        public GPUComparator(long neededSize) {
            this.neededSize = neededSize;
        }

        @Override
        public int compare(GPUObject first, GPUObject second) {
            if (first.isLocked() && second.isLocked()) {
                return 0;
            }
            if (first.isLocked()) {
                return -1;
            }
            if (second.isLocked()) {
                return 1;
            }
            if (DMLScript.GPU_EVICTION_POLICY != DMLScript.EvictionPolicy.MIN_EVICT) {
                return Long.compare(second.timestamp.get(), first.timestamp.get());
            }
            try {
                long excessFirst = first.getSizeOnDevice() - this.neededSize;
                long excessSecond = second.getSizeOnDevice() - this.neededSize;
                if (excessFirst < 0 || excessSecond < 0) {
                    return Long.compare(excessFirst, excessSecond);
                }
                return Long.compare(excessSecond, excessFirst);
            } catch (DMLRuntimeException e) {
                // Comparator.compare cannot declare checked exceptions.
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Registers a GPU object as an eviction candidate.
     *
     * @param gpuObj the object to track
     */
    public void addGPUObject(GPUObject gpuObj) {
        this.allocatedGPUObjects.add(gpuObj);
    }

    /**
     * Stops tracking every registered entry that is {@code equals} to the given object.
     *
     * @param gpuObj the object to remove
     */
    public void removeGPUObject(GPUObject gpuObj) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("Removing the GPU object: " + gpuObj);
        }
        this.allocatedGPUObjects.removeIf(candidate -> candidate.equals(gpuObj));
    }

    /**
     * Returns the recorded size of a pointer allocated by this manager.
     *
     * @param pointer device pointer
     * @return its size in bytes, or {@code -1} if the pointer is not tracked
     */
    public long getSizeAllocatedGPUPointer(Pointer pointer) {
        Long size = this.allocatedGPUPointers.get(pointer);
        return size == null ? -1L : size.longValue();
    }

    /**
     * Logs the device's total and free memory, warning on apparent under-utilization.
     *
     * @param gPUContext the owning context; used only in log messages
     */
    public GPUMemoryManager(GPUContext gPUContext) {
        long[] free = {0};
        long[] total = {0};
        JCuda.cudaMemGetInfo(free, total);
        if (free[0] < WARN_UTILIZATION_FACTOR * total[0]) {
            LOG.warn("Potential under-utilization: GPU memory - Total: " + (total[0] * 1.0E-6d) + " MB, Available: " + (free[0] * 1.0E-6d) + " MB on " + gPUContext + ". This can happen if there are other processes running on the GPU at the same time.");
        } else {
            LOG.info("GPU memory - Total: " + (total[0] * 1.0E-6d) + " MB, Available: " + (free[0] * 1.0E-6d) + " MB on " + gPUContext);
        }
        if (GPUContextPool.initialGPUMemBudget() > OptimizerUtils.getLocalMemBudget()) {
            LOG.warn("Potential under-utilization: GPU memory (" + GPUContextPool.initialGPUMemBudget() + ") > driver memory budget (" + OptimizerUtils.getLocalMemBudget() + "). Consider increasing the driver memory budget.");
        }
    }

    /**
     * Calls cudaMalloc and records the pointer on success.
     *
     * @param pointer pointer object that cudaMalloc fills in
     * @param size number of bytes to allocate
     * @return {@code pointer} on success, or {@code null} if cudaMalloc threw a {@link CudaException}
     */
    private Pointer cudaMallocWarnIfFails(Pointer pointer, long size) {
        try {
            JCuda.cudaMalloc(pointer, size);
            this.allocatedGPUPointers.put(pointer, Long.valueOf(size));
            return pointer;
        } catch (CudaException e) {
            LOG.warn("cudaMalloc failed immediately after cudaMemGetInfo reported that memory of size " + size + " is available. This usually happens if there are external programs trying to grab on to memory in parallel.");
            return null;
        }
    }

    /**
     * Allocates a fresh pointer and emits a trace message describing the outcome.
     *
     * @param size number of bytes to allocate
     * @param failurePrefix trace prefix used when the allocation fails
     * @param successPrefix trace prefix used when the allocation succeeds
     * @return the new pointer, or {@code null} on failure
     */
    private Pointer mallocAndTrace(long size, String failurePrefix, String successPrefix) {
        Pointer allocated = cudaMallocWarnIfFails(new Pointer(), size);
        if (LOG.isTraceEnabled()) {
            LOG.trace((allocated == null ? failurePrefix : successPrefix) + size);
        }
        return allocated;
    }

    /**
     * Returns a zero-filled device pointer of the requested size, trying increasingly aggressive
     * strategies until one succeeds.
     *
     * <ol>
     *   <li>Reuse a parked pointer of exactly {@code size} bytes, or cudaMalloc if enough memory is
     *       available.</li>
     *   <li>Free the smallest parked pointer larger than {@code size}, then cudaMalloc.</li>
     *   <li>Free every parked pointer, then cudaMalloc if enough memory is available.</li>
     *   <li>Evict unlocked GPU objects (ordered by {@link GPUComparator}) until enough memory is
     *       available or only locked objects remain, then cudaMalloc.</li>
     * </ol>
     *
     * @param str opcode of the requesting instruction, used for fine-grained statistics; may be null
     * @param size number of bytes requested
     * @return the device pointer, after cudaMemset to 0
     * @throws DMLRuntimeException if {@code size} is negative or no strategy yields memory
     */
    public Pointer malloc(String str, long size) throws DMLRuntimeException {
        if (size < 0) {
            throw new DMLRuntimeException("Cannot allocate memory of size " + size);
        }
        long allocStart = DMLScript.STATISTICS ? System.nanoTime() : 0L;

        // Step 1: exact-size reuse, otherwise a fresh allocation if it fits.
        Pointer result = getRmvarPointer(str, size);
        if (result == null && size <= getAvailableMemory()) {
            result = mallocAndTrace(size, "Couldnot allocate a new pointer in the GPU memory:", "Allocated a new pointer in the GPU memory:");
        }

        // Step 2: release the smallest parked pointer that is strictly larger than the request.
        if (result == null) {
            long bestSize = Long.MAX_VALUE;
            for (long parkedSize : this.rmvarGPUPointers.keySet()) {
                if (parkedSize > size) {
                    bestSize = Math.min(bestSize, parkedSize);
                }
            }
            if (bestSize != Long.MAX_VALUE) {
                guardedCudaFree(getRmvarPointer(str, bestSize));
                result = mallocAndTrace(size, "Couldnot reuse non-exact match of rmvarGPUPointers:", "Reuses a non-exact match from rmvarGPUPointers:");
            }
        }

        // Step 3: release every parked pointer. Copy first, because guardedCudaFree mutates rmvarGPUPointers.
        if (result == null) {
            HashSet<Pointer> parked = new HashSet<>();
            for (Set<Pointer> group : this.rmvarGPUPointers.values()) {
                parked.addAll(group);
            }
            for (Pointer toFree : parked) {
                guardedCudaFree(toFree);
            }
            if (size <= getAvailableMemory()) {
                result = mallocAndTrace(size, "Couldnot allocate a new pointer in the GPU memory after eager free:", "Allocated a new pointer in the GPU memory after eager free:");
            }
        }
        addMiscTime(str, GPUStatistics.cudaAllocTime, GPUStatistics.cudaAllocCount, GPUInstruction.MISC_TIMER_ALLOCATE, allocStart);

        // Step 4: evict unlocked GPU objects, taking the preferred candidate from the end of the sorted list.
        if (result == null) {
            long evictStart = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            Collections.sort(this.allocatedGPUObjects, new GPUComparator(size));
            while (size > getAvailableMemory() && !this.allocatedGPUObjects.isEmpty()) {
                GPUObject victim = this.allocatedGPUObjects.get(this.allocatedGPUObjects.size() - 1);
                if (victim.isLocked()) {
                    // Locked objects sort first, so everything that remains is locked.
                    break;
                }
                if (victim.dirty) {
                    victim.copyFromDeviceToHost(str, true);
                }
                // NOTE(review): loop progress relies on clearData removing the victim from
                // allocatedGPUObjects (or freeing enough memory); confirm in GPUObject.
                victim.clearData(true);
            }
            // Argument order is (time adder, count adder), matching addMiscTime's signature.
            addMiscTime(str, GPUStatistics.cudaEvictTime, GPUStatistics.cudaEvictionCount, GPUInstruction.MISC_TIMER_EVICT, evictStart);
            if (size <= getAvailableMemory()) {
                result = mallocAndTrace(size, "Couldnot allocate a new pointer in the GPU memory after eviction:", "Allocated a new pointer in the GPU memory after eviction:");
            }
        }

        if (result == null) {
            throw new DMLRuntimeException("There is not enough memory on device for this matrix, request (" + size + "). " + toString());
        }
        long memsetStart = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        JCuda.cudaMemset(result, 0, size);
        addMiscTime(str, GPUStatistics.cudaMemSet0Time, GPUStatistics.cudaMemSet0Count, GPUInstruction.MISC_TIMER_SET_ZERO, memsetStart);
        return result;
    }

    /**
     * cudaFree-s a tracked pointer and removes it from all bookkeeping maps.
     * A {@code null} pointer is ignored.
     *
     * @param pointer device pointer to release
     * @throws RuntimeException if the pointer was not allocated by this manager
     */
    private void guardedCudaFree(Pointer pointer) {
        // The original compared against a freshly constructed Pointer, which is never the same reference;
        // the intended "nothing to free" guard is a null check.
        if (pointer == null) {
            return;
        }
        Long size = this.allocatedGPUPointers.remove(pointer);
        if (size == null) {
            throw new RuntimeException("Attempting to free an unaccounted pointer:" + pointer);
        }
        Set<Pointer> parked = this.rmvarGPUPointers.get(size);
        if (parked != null && parked.contains(pointer)) {
            remove(this.rmvarGPUPointers, size.longValue(), pointer);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Free-ing up the pointer: " + pointer);
        }
        JCuda.cudaFree(pointer);
    }

    /**
     * Releases a pointer, either immediately or by parking it for reuse.
     *
     * @param str opcode of the requesting instruction, used for fine-grained statistics; may be null
     * @param pointer device pointer to release; {@code null} is ignored
     * @param z {@code true} to cudaFree immediately; {@code false} to park the pointer in the
     *          rmvar pool (it stays allocated on the device)
     * @throws DMLRuntimeException declared for callers; not thrown by the visible body
     * @throws RuntimeException if the pointer is untracked or is already parked
     */
    public void free(String str, Pointer pointer, boolean z) throws DMLRuntimeException {
        if (pointer == null) {
            return;
        }
        if (z) {
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            guardedCudaFree(pointer);
            addMiscTime(str, GPUStatistics.cudaDeAllocTime, GPUStatistics.cudaDeAllocCount, GPUInstruction.MISC_TIMER_CUDA_FREE, start);
            return;
        }
        Long size = this.allocatedGPUPointers.get(pointer);
        if (size == null) {
            throw new RuntimeException("ERROR : Internal state corrupted, cache block size map is not aware of a block it trying to free up");
        }
        Set<Pointer> parked = this.rmvarGPUPointers.computeIfAbsent(size, ignored -> new HashSet<>());
        if (!parked.add(pointer)) {
            throw new RuntimeException("GPU : Internal state corrupted, double free");
        }
    }

    /**
     * Releases all GPU state held by this manager.
     *
     * <p>Dirty objects are first read back to the host, then every object's device data is cleared,
     * and finally every remaining tracked pointer is cudaFree-d.
     *
     * @throws DMLRuntimeException if reading an object back to the host fails
     */
    public void clearMemory() throws DMLRuntimeException {
        // Iterate over a snapshot: clearData presumably calls back into removeGPUObject, which would
        // otherwise invalidate a live iterator over allocatedGPUObjects.
        for (GPUObject gpuObj : new ArrayList<>(this.allocatedGPUObjects)) {
            if (gpuObj.isDirty()) {
                LOG.debug("Attempted to free GPU Memory when a block[" + gpuObj + "] is still on GPU memory, copying it back to host.");
                gpuObj.acquireHostRead(null);
            }
            gpuObj.clearData(true);
        }
        this.allocatedGPUObjects.clear();
        // Copy the key set, because guardedCudaFree removes entries from allocatedGPUPointers.
        for (Pointer toFree : new HashSet<>(this.allocatedGPUPointers.keySet())) {
            guardedCudaFree(toFree);
        }
    }

    /**
     * Collects the device pointers that back dirty GPU objects.
     *
     * @return the non-null rowPtr/colInd/val pointers of dirty sparse objects and the dense pointer of
     *         dirty dense objects
     * @throws RuntimeException if a dirty object has no sparse/dense pointer
     */
    private HashSet<Pointer> getDirtyPointers() {
        HashSet<Pointer> dirty = new HashSet<>();
        for (GPUObject gpuObj : this.allocatedGPUObjects) {
            if (!gpuObj.isDirty()) {
                continue;
            }
            if (gpuObj.isSparse()) {
                CSRPointer csr = gpuObj.getSparseMatrixCudaPointer();
                if (csr == null) {
                    throw new RuntimeException("CSRPointer is null in clearTemporaryMemory");
                }
                if (csr.rowPtr != null) {
                    dirty.add(csr.rowPtr);
                }
                if (csr.colInd != null) {
                    dirty.add(csr.colInd);
                }
                if (csr.val != null) {
                    dirty.add(csr.val);
                }
            } else {
                Pointer dense = gpuObj.getJcudaDenseMatrixPtr();
                if (dense == null) {
                    throw new RuntimeException("Pointer is null in clearTemporaryMemory");
                }
                dirty.add(dense);
            }
        }
        return dirty;
    }

    /**
     * Computes a set difference.
     *
     * @param set source elements
     * @param set2 elements to exclude
     * @return a new set holding the elements of {@code set} that are not in {@code set2}
     */
    private Set<Pointer> nonIn(Set<Pointer> set, Set<Pointer> set2) {
        HashSet<Pointer> difference = new HashSet<>(set);
        difference.removeAll(set2);
        return difference;
    }

    /**
     * cudaFree-s every tracked pointer except those backing dirty GPU objects.
     *
     * <p>NOTE(review): this also frees pointers held by non-dirty GPU objects; presumably their owners
     * can re-materialize the data — confirm against GPUObject.
     */
    public void clearTemporaryMemory() {
        for (Pointer toFree : nonIn(this.allocatedGPUPointers.keySet(), getDirtyPointers())) {
            guardedCudaFree(toFree);
        }
    }

    /**
     * Records elapsed time and one occurrence when statistics are enabled, plus a per-opcode entry when
     * fine-grained statistics are enabled and an opcode is given.
     *
     * @param str opcode, or null
     * @param longAdder accumulator for elapsed nanoseconds
     * @param longAdder2 accumulator for the occurrence count
     * @param str2 misc-timer name for fine-grained statistics
     * @param j start time from {@link System#nanoTime()}
     */
    private void addMiscTime(String str, LongAdder longAdder, LongAdder longAdder2, String str2, long j) {
        if (!DMLScript.STATISTICS) {
            return;
        }
        long elapsed = System.nanoTime() - j;
        longAdder.add(elapsed);
        longAdder2.add(1L);
        if (str != null && DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, str2, elapsed);
        }
    }

    /**
     * Records a per-opcode misc time when fine-grained statistics are enabled and an opcode is given.
     *
     * @param str opcode, or null
     * @param str2 misc-timer name
     * @param j start time from {@link System#nanoTime()}
     */
    private void addMiscTime(String str, String str2, long j) {
        if (str != null && DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, str2, System.nanoTime() - j);
        }
    }

    /**
     * Removes and returns one parked pointer of exactly the given size.
     *
     * @param str opcode, used for fine-grained statistics; may be null
     * @param j size in bytes
     * @return a parked pointer, or {@code null} if none of that size is parked
     */
    private Pointer getRmvarPointer(String str, long j) {
        if (!this.rmvarGPUPointers.containsKey(Long.valueOf(j))) {
            return null;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Getting rmvar-ed pointers for size:" + j);
        }
        long start = (str != null && DMLScript.FINEGRAINED_STATISTICS) ? System.nanoTime() : 0L;
        Pointer reused = remove(this.rmvarGPUPointers, j);
        addMiscTime(str, GPUInstruction.MISC_TIMER_REUSE, start);
        return reused;
    }

    /**
     * Removes and returns an arbitrary pointer from the group of the given size.
     * Callers must ensure the group exists.
     */
    private Pointer remove(HashMap<Long, Set<Pointer>> hashMap, long j) {
        Pointer picked = hashMap.get(Long.valueOf(j)).iterator().next();
        remove(hashMap, j, picked);
        return picked;
    }

    /**
     * Removes a pointer from the group of the given size, dropping the group once it is empty.
     * Callers must ensure the group exists.
     */
    private void remove(HashMap<Long, Set<Pointer>> hashMap, long j, Pointer pointer) {
        Set<Pointer> group = hashMap.get(Long.valueOf(j));
        group.remove(pointer);
        if (group.isEmpty()) {
            hashMap.remove(Long.valueOf(j));
        }
    }

    /**
     * Summarizes object counts and sizes (split into locked/unlocked) and the total number of bytes
     * allocated by this manager.
     */
    @Override
    public String toString() {
        long lockedSize = 0;
        long lockedCount = 0;
        long unlockedSize = 0;
        long unlockedCount = 0;
        for (GPUObject gpuObj : this.allocatedGPUObjects) {
            try {
                if (gpuObj.isLocked()) {
                    lockedCount++;
                    lockedSize += gpuObj.getSizeOnDevice();
                } else {
                    unlockedCount++;
                    unlockedSize += gpuObj.getSizeOnDevice();
                }
            } catch (DMLRuntimeException e) {
                throw new RuntimeException(e);
            }
        }
        long totalAllocated = 0;
        for (Long size : this.allocatedGPUPointers.values()) {
            totalAllocated += size.longValue();
        }
        return "Num of GPU objects: [unlocked:" + unlockedCount + ", locked:" + lockedCount + "]. Size of GPU objects in bytes: [unlocked:" + unlockedSize + ", locked:" + lockedSize + "]. Total memory allocated by the current GPU context in bytes:" + totalAllocated;
    }

    /**
     * Returns the usable device memory.
     *
     * @return the free memory reported by cudaMemGetInfo, scaled by {@link #GPU_MEMORY_UTILIZATION_FACTOR}
     */
    public long getAvailableMemory() {
        long[] free = {0};
        long[] total = {0};
        JCuda.cudaMemGetInfo(free, total);
        return (long) (free[0] * this.GPU_MEMORY_UTILIZATION_FACTOR);
    }
}
