package org.apache.sysml.yarn.ropt;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.api.records.NodeState;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.matrix.mapred.MRConfigurationNames;

/* loaded from: input_file:org/apache/sysml/yarn/ropt/YarnClusterAnalyzer.class */
/**
 * Collects and caches resource information about the local JVM and the YARN cluster for the
 * resource optimizer. This covers node capacities, scheduler allocation limits, MR task JVM heap
 * sizes and the HDFS block size.
 *
 * <p>All state lives in static fields. Cluster-derived fields start at {@code -1} (or
 * {@code null}) and are populated lazily by {@link #analyzeYarnCluster(boolean)} on first access.
 * Local-machine fields are populated by the static initializer. No synchronization is performed.
 */
public class YarnClusterAnalyzer {
    /** Fallback JVM max heap size in bytes (512 MB), used when no valid {@code -Xmx} option is found. */
    public static final long DEFAULT_JVM_SIZE = 536870912;
    /** Tasks per virtual core assumed when bounding remote task parallelism. */
    public static final int CPU_HYPER_FACTOR = 1;

    /**
     * Bytes per megabyte. Declared as a {@code long} so that MB-to-byte conversions of {@code int}
     * configuration values are done in 64-bit arithmetic and cannot overflow.
     */
    private static final long MB = 1024L * 1024L;

    /** Processors available to the local JVM (set by the static initializer). */
    public static int _localPar = -1;
    /** Max heap of the local JVM in bytes (set by the static initializer, see {@link #setLocalMaxMemory}). */
    public static long _localJVMMaxMem = -1;
    /** Number of nodes reported by the YARN cluster. */
    public static int _remotePar = -1;
    /** Max heap in bytes of remote map task JVMs, parsed from the configured java opts. */
    public static long _remoteJVMMaxMemMap = -1;
    /** Max heap in bytes of remote reduce task JVMs, parsed from the configured java opts. */
    public static long _remoteJVMMaxMemReduce = -1;
    /** Remote MR sort buffer size in bytes. */
    public static long _remoteMRSortMem = -1;
    /** Not read or written within this class. */
    public static boolean _localJT = false;
    /** HDFS block size in bytes. */
    public static long _blocksize = -1;
    /** Planned remote JVM max heap (bytes) per key; key {@code -1} acts as the default entry. */
    public static HashMap<Long, Long> remoteJVMMaxMemPlan = new HashMap<>();
    /** Keys that have been looked up via {@link #getSpecifiedRemoteMaxMemory(long)}. */
    public static HashSet<Long> probedSb = new HashSet<>();
    /** Node memory capacities in bytes, sorted in descending order. */
    public static List<Long> nodesMaxPhySorted = null;
    /** {@link #nodesMaxPhySorted} converted with {@code ResourceOptimizer.phyToBudget}, same order. */
    public static List<Double> nodesMaxBudgetSorted = null;
    /** Largest per-vcore memory share (MB) over all nodes. */
    public static int minimumMRContainerPhyMB = -1;
    /** MR application master memory in bytes. */
    public static long mrAMPhy = -1;
    /** Sum of node memory capacities in bytes. */
    public static long clusterTotalMem = -1;
    /** Number of nodes counted during the last cluster analysis. */
    public static int clusterTotalNodes = -1;
    /** Sum of node virtual cores. */
    public static int clusterTotalCores = -1;
    /** Scheduler minimum allocation in bytes. */
    public static long minimalPhyAllocate = -1;
    /** Scheduler maximum allocation in bytes. */
    public static long maximumPhyAllocate = -1;
    /** Lazily created client reused by {@link #getClusterUtilization()}. */
    private static YarnClient _client = null;

    /** Returns node memory capacities in bytes (descending), analyzing the cluster if needed. */
    public static List<Long> getNodesMaxPhySorted() {
        if (nodesMaxPhySorted == null) {
            analyzeYarnCluster(true);
        }
        return nodesMaxPhySorted;
    }

    /** Returns node memory budgets (descending), analyzing the cluster if needed. */
    public static List<Double> getNodesMaxBudgetSorted() {
        if (nodesMaxBudgetSorted == null) {
            analyzeYarnCluster(true);
        }
        return nodesMaxBudgetSorted;
    }

    /** Returns the MR application master memory in bytes, analyzing the cluster if needed. */
    public static long getMRARPhy() {
        if (mrAMPhy == -1) {
            analyzeYarnCluster(true);
        }
        return mrAMPhy;
    }

    /** Returns the total cluster memory in bytes, analyzing the cluster if needed. */
    public static long getClusterTotalMem() {
        if (clusterTotalMem == -1) {
            analyzeYarnCluster(true);
        }
        return clusterTotalMem;
    }

    /** Returns the scheduler maximum allocation in bytes, analyzing the cluster if needed. */
    public static long getMaxPhyAllocate() {
        if (maximumPhyAllocate == -1) {
            analyzeYarnCluster(true);
        }
        return maximumPhyAllocate;
    }

    /** Returns the largest per-vcore memory share (MB) over all nodes, analyzing the cluster if needed. */
    public static int getMinMRContarinerPhyMB() {
        if (minimumMRContainerPhyMB == -1) {
            analyzeYarnCluster(true);
        }
        return minimumMRContainerPhyMB;
    }

    /** Returns the number of processors available to the local JVM. */
    public static int getLocalParallelism() {
        return _localPar;
    }

    /** Returns the number of cluster nodes, analyzing the cluster if needed. */
    public static int getRemoteParallelNodes() {
        if (_remotePar == -1) {
            analyzeYarnCluster(true);
        }
        return _remotePar;
    }

    /**
     * Returns how many map tasks can run in parallel for the given plan key. The count is bounded
     * both by memory and by total cluster cores times {@link #CPU_HYPER_FACTOR}.
     */
    public static int getRemoteParallelMapTasks(long j) {
        if (clusterTotalCores == -1) {
            analyzeYarnCluster(true);
        }
        int byMemory = getRemoteParallelTasksGivenMem(getRemoteMaxMemoryMap(j));
        return Math.min(byMemory, clusterTotalCores * CPU_HYPER_FACTOR);
    }

    /**
     * Returns how many reduce tasks can run in parallel for the given plan key. The count is bounded
     * both by memory and by total cluster cores times {@link #CPU_HYPER_FACTOR}.
     */
    public static int getRemoteParallelReduceTasks(long j) {
        if (clusterTotalCores == -1) {
            analyzeYarnCluster(true);
        }
        int byMemory = getRemoteParallelTasksGivenMem(getRemoteMaxMemoryReduce(j));
        return Math.min(byMemory, clusterTotalCores * CPU_HYPER_FACTOR);
    }

    /**
     * Rounds a physical memory request up to what the YARN scheduler would actually allocate.
     *
     * @param j requested physical memory in bytes
     * @return the request rounded up to a multiple of the scheduler minimum allocation, capped at
     *     the scheduler maximum allocation
     * @throws RuntimeException if the request exceeds the scheduler maximum allocation
     */
    public static long getYarnPhyAllocate(long j) {
        if (minimalPhyAllocate == -1) {
            analyzeYarnCluster(true);
        }
        if (j > maximumPhyAllocate) {
            throw new RuntimeException("Requested " + OptimizerUtils.toMB(j) + "MB, while the maximum yarn allocate is " + OptimizerUtils.toMB(maximumPhyAllocate) + "MB");
        }
        long allocated = j;
        if (minimalPhyAllocate > 0) {
            // Integer ceiling division. A plain j / minimalPhyAllocate truncates (rounds down),
            // which would turn small requests into a 0-byte allocation.
            allocated = ((j + minimalPhyAllocate - 1) / minimalPhyAllocate) * minimalPhyAllocate;
        }
        return Math.min(allocated, maximumPhyAllocate);
    }

    /**
     * Estimates how many containers for a JVM with the given max heap fit into the cluster.
     *
     * <p>One container for the local JVM is reserved on the largest node. One container for the MR
     * application master is reserved on whichever of the two largest nodes has more memory left.
     *
     * @param j remote JVM max heap in bytes
     * @return the number of containers that fit, or -1 if there are no nodes or the reservations
     *     do not fit
     */
    public static int getRemoteParallelTasksGivenMem(long j) {
        long taskPhy = getYarnPhyAllocate(ResourceOptimizer.jvmToPhy(j, false));
        long localPhy = getYarnPhyAllocate(ResourceOptimizer.jvmToPhy(getLocalMaxMemory(), false));
        long amPhy = getYarnPhyAllocate(getMRARPhy());
        if (nodesMaxPhySorted == null) {
            analyzeYarnCluster(true);
        }
        List<Long> nodes = nodesMaxPhySorted;
        if (nodes.isEmpty()) {
            return -1;
        }
        if (nodes.size() == 1) {
            long remaining = nodes.get(0).longValue() - localPhy - amPhy;
            if (remaining < 0) {
                return -1;
            }
            return (int) (remaining / taskPhy);
        }
        long first = nodes.get(0).longValue() - localPhy;
        long second = nodes.get(1).longValue();
        if (first >= second) {
            first -= amPhy;
        } else {
            second -= amPhy;
        }
        if (first < 0 || second < 0) {
            return -1;
        }
        long total = (first / taskPhy) + (second / taskPhy);
        for (int k = 2; k < nodes.size(); k++) {
            total += nodes.get(k).longValue() / taskPhy;
        }
        return (int) total;
    }

    /**
     * Checks whether the current memory configuration fits the cluster.
     *
     * @param z if {@code true}, checks that at least one remote task with the largest planned heap
     *     fits; otherwise checks that the local JVM fits on the largest node
     */
    public static boolean checkValidMemPlan(boolean z) {
        if (nodesMaxPhySorted == null) {
            analyzeYarnCluster(true);
        }
        if (z) {
            return getRemoteParallelTasksGivenMem(getMaximumRemoteMaxMemory()) > 0;
        }
        long localPhy = getYarnPhyAllocate(ResourceOptimizer.jvmToPhy(getLocalMaxMemory(), false));
        return nodesMaxPhySorted.get(0).longValue() >= localPhy;
    }

    /** Returns the local JVM max heap in bytes. */
    public static long getLocalMaxMemory() {
        return _localJVMMaxMem;
    }

    /** Overrides the local JVM max heap in bytes. */
    public static void setLocalMaxMemory(long j) {
        _localJVMMaxMem = j;
    }

    /** Returns the largest remote JVM heap among the map/reduce defaults and all plan entries. */
    public static long getMaximumRemoteMaxMemory() {
        if (_remoteJVMMaxMemMap == -1) {
            analyzeYarnCluster(true);
        }
        long max = Math.max(_remoteJVMMaxMemMap, _remoteJVMMaxMemReduce);
        for (Long planned : remoteJVMMaxMemPlan.values()) {
            max = Math.max(max, planned.longValue());
        }
        return max;
    }

    /** Returns the planned map JVM heap for key {@code j}, or the configured map default. */
    public static long getRemoteMaxMemoryMap(long j) {
        if (_remoteJVMMaxMemMap == -1) {
            analyzeYarnCluster(true);
        }
        long specified = getSpecifiedRemoteMaxMemory(j);
        return specified == -1 ? _remoteJVMMaxMemMap : specified;
    }

    /** Returns the planned reduce JVM heap for key {@code j}, or the configured reduce default. */
    public static long getRemoteMaxMemoryReduce(long j) {
        if (_remoteJVMMaxMemReduce == -1) {
            analyzeYarnCluster(true);
        }
        long specified = getSpecifiedRemoteMaxMemory(j);
        return specified == -1 ? _remoteJVMMaxMemReduce : specified;
    }

    /**
     * Looks up the planned remote JVM heap for key {@code j} and records the key as probed.
     *
     * @return the plan entry for {@code j}, else the plan entry for key {@code -1}, else -1
     */
    public static long getSpecifiedRemoteMaxMemory(long j) {
        probedSb.add(j);
        Long planned = remoteJVMMaxMemPlan.get(j);
        if (planned != null) {
            return planned.longValue();
        }
        Long fallback = remoteJVMMaxMemPlan.get(-1L);
        if (fallback != null) {
            return fallback.longValue();
        }
        return -1L;
    }

    /** Replaces the remote memory plan, converting each budget via {@code ResourceOptimizer.budgetToJvm}. */
    public static void setRemoteMaxMemPlan(HashMap<Long, Double> hashMap) {
        remoteJVMMaxMemPlan.clear();
        for (Map.Entry<Long, Double> entry : hashMap.entrySet()) {
            remoteJVMMaxMemPlan.put(entry.getKey(), Long.valueOf(ResourceOptimizer.budgetToJvm(entry.getValue().doubleValue())));
        }
    }

    /** Clears the set of probed plan keys. */
    public static void resetSBProbedSet() {
        probedSb.clear();
    }

    /** Returns the (live, mutable) set of probed plan keys. */
    public static HashSet<Long> getSBProbedSet() {
        return probedSb;
    }

    /** Prints {@code str} followed by the probed keys in ascending order, each followed by a comma. */
    public static void printProbedSet(String str) {
        List<Long> sorted = new ArrayList<>(probedSb);
        Collections.sort(sorted);
        StringBuilder line = new StringBuilder(str);
        for (Long key : sorted) {
            line.append(key).append(',');
        }
        System.out.println(line);
    }

    /** Returns the remote MR sort buffer size in bytes, analyzing the cluster if needed. */
    public static long getRemoteMaxMemorySortBuffer() {
        if (_remoteMRSortMem == -1) {
            analyzeYarnCluster(true);
        }
        return _remoteMRSortMem;
    }

    /** Returns the local parallelism. */
    public static int getCkMaxCP() {
        return getLocalParallelism();
    }

    /** Returns the remote map-task parallelism for plan key {@code j}. */
    public static int getCkMaxMR(long j) {
        return getRemoteParallelMapTasks(j);
    }

    /** Returns the smaller of the local JVM heap and the remote map JVM heap for plan key {@code j}. */
    public static long getCmMax(long j) {
        return Math.min(getLocalMaxMemory(), getRemoteMaxMemoryMap(j));
    }

    /** Returns the HDFS block size in bytes, analyzing the cluster if needed. */
    public static long getHDFSBlockSize() {
        if (_blocksize == -1) {
            analyzeYarnCluster(true);
        }
        return _blocksize;
    }

    /**
     * Extracts the max heap size from a java opts string.
     *
     * <p>Recognizes {@code -Xmx<n>[kKmMgG]} as well as plain byte counts. If several {@code -Xmx}
     * options are present, the last one wins.
     *
     * @param str space-separated JVM options, may be {@code null}
     * @return max heap in bytes, or {@link #DEFAULT_JVM_SIZE} if no valid positive value is found
     */
    public static long extractMaxMemoryOpt(String str) {
        if (str == null) {
            return DEFAULT_JVM_SIZE;
        }
        long maxMem = -1;
        try {
            StringTokenizer tokens = new StringTokenizer(str, " ");
            while (tokens.hasMoreTokens()) {
                String token = tokens.nextToken();
                if (token.startsWith("-Xmx")) {
                    maxMem = parseJvmMemorySize(token.substring(4));
                }
            }
        } catch (NumberFormatException e) {
            // Malformed -Xmx value: fall back to the default rather than failing the analysis.
            return DEFAULT_JVM_SIZE;
        }
        return maxMem < 0 ? DEFAULT_JVM_SIZE : maxMem;
    }

    /** Parses a JVM memory size such as {@code 2g}, {@code 512M}, {@code 64k} or a plain byte count. */
    private static long parseJvmMemorySize(String size) {
        char unit = size.isEmpty() ? '\0' : size.charAt(size.length() - 1);
        switch (unit) {
            case 'g':
            case 'G':
                return Long.parseLong(size.substring(0, size.length() - 1)) * 1024 * 1024 * 1024;
            case 'm':
            case 'M':
                return Long.parseLong(size.substring(0, size.length() - 1)) * 1024 * 1024;
            case 'k':
            case 'K':
                return Long.parseLong(size.substring(0, size.length() - 1)) * 1024;
            default:
                // No unit suffix: the value is already in bytes.
                return Long.parseLong(size);
        }
    }

    /**
     * Rewrites every {@code -Xmx} option in the java opts stored under {@code str} to {@code j}
     * bytes, expressed in whole megabytes. If the key is not set, the configuration is left unchanged.
     */
    public static void setMaxMemoryOpt(JobConf jobConf, String str, long j) {
        String javaOpts = jobConf.get(str);
        if (javaOpts == null) {
            // No options under this key, hence no -Xmx to rewrite.
            return;
        }
        StringBuilder sb = new StringBuilder();
        for (String token : javaOpts.split(" ")) {
            if (token.startsWith("-Xmx")) {
                sb.append("-Xmx").append(j / MB).append("M");
            } else {
                sb.append(token);
            }
            sb.append(' ');
        }
        jobConf.set(str, sb.toString().trim());
    }

    /** Records the local processor count and JVM max heap. */
    private static void analyzeLocalMachine() {
        _localPar = Runtime.getRuntime().availableProcessors();
        _localJVMMaxMem = Runtime.getRuntime().maxMemory();
    }

    /**
     * Analyzes the cluster using a temporary YARN client built from the default configuration.
     *
     * @param z if {@code true}, prints per-node information to stdout
     */
    public static void analyzeYarnCluster(boolean z) {
        YarnConfiguration yarnConfiguration = new YarnConfiguration();
        YarnClient yarnClient = YarnClient.createYarnClient();
        yarnClient.init(yarnConfiguration);
        yarnClient.start();
        try {
            analyzeYarnCluster(yarnClient, yarnConfiguration, z);
        } finally {
            // The client is only needed for this one-shot analysis; stop it instead of leaking a
            // started service on every call.
            yarnClient.stop();
        }
    }

    /** Returns the scheduler minimum allocation in bytes, analyzing the cluster quietly if needed. */
    public static long getMinAllocationBytes() {
        if (minimalPhyAllocate < 0) {
            analyzeYarnCluster(false);
        }
        return minimalPhyAllocate;
    }

    /** Returns the scheduler maximum allocation in bytes, analyzing the cluster quietly if needed. */
    public static long getMaxAllocationBytes() {
        if (maximumPhyAllocate < 0) {
            analyzeYarnCluster(false);
        }
        return maximumPhyAllocate;
    }

    /** Returns the total number of cluster vcores, analyzing the cluster quietly if needed. */
    public static long getNumCores() {
        if (clusterTotalCores < 0) {
            analyzeYarnCluster(false);
        }
        return clusterTotalCores;
    }

    /** Returns the number of cluster nodes, analyzing the cluster quietly if needed. */
    public static long getNumNodes() {
        if (clusterTotalNodes < 0) {
            analyzeYarnCluster(false);
        }
        return clusterTotalNodes;
    }

    /** Builds a {@link YarnClusterConfig} snapshot (allocations in MB, nodes, cores). */
    public static YarnClusterConfig getClusterConfig() {
        YarnClusterConfig yarnClusterConfig = new YarnClusterConfig();
        yarnClusterConfig.setMinAllocationMB(getMinAllocationBytes() / MB);
        yarnClusterConfig.setMaxAllocationMB(getMaxAllocationBytes() / MB);
        yarnClusterConfig.setNumNodes(getNumNodes());
        yarnClusterConfig.setNumCores(getNumCores() * CPU_HYPER_FACTOR);
        return yarnClusterConfig;
    }

    /**
     * Returns the current cluster utilization.
     *
     * <p>The result is the larger of memory utilization and vcore utilization, each capped at 1.0.
     * A resource with zero total capacity contributes 0.
     *
     * @throws IOException wrapping any failure while querying node reports
     */
    public static double getClusterUtilization() throws IOException {
        try {
            if (_client == null) {
                _client = createYarnClient();
            }
            double totalMem = 0.0d;
            double usedMem = 0.0d;
            long totalCores = 0;
            long usedCores = 0;
            for (NodeReport nodeReport : _client.getNodeReports(new NodeState[0])) {
                Resource capability = nodeReport.getCapability();
                Resource used = nodeReport.getUsed();
                totalMem += capability.getMemory();
                totalCores += capability.getVirtualCores();
                // Like analyzeYarnCluster, treat a missing usage report as zero usage.
                if (used != null) {
                    usedMem += used.getMemory();
                    usedCores += used.getVirtualCores();
                }
            }
            double memUtil = totalMem > 0 ? Math.min(1.0d, usedMem / totalMem) : 0.0d;
            // Cast before dividing: long / long would truncate the ratio to 0 or 1.
            double coreUtil = totalCores > 0 ? Math.min(1.0d, (double) usedCores / totalCores) : 0.0d;
            return Math.max(memUtil, coreUtil);
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Queries node reports and configuration values, and populates all cluster-derived static fields.
     *
     * @param yarnClient started client used to fetch node reports
     * @param yarnConfiguration configuration read for heap, sort-buffer, block-size and allocation settings
     * @param z if {@code true}, prints per-node information to stdout
     * @throws RuntimeException wrapping any failure (e.g. no nodes, or a node with non-positive
     *     memory or vcores)
     */
    public static void analyzeYarnCluster(YarnClient yarnClient, YarnConfiguration yarnConfiguration, boolean z) {
        try {
            List<NodeReport> nodeReports = yarnClient.getNodeReports(new NodeState[0]);
            if (z) {
                System.out.println("There are " + nodeReports.size() + " nodes in the cluster");
            }
            if (nodeReports.isEmpty()) {
                throw new YarnException("There are zero available nodes in the yarn cluster");
            }

            // Accumulate into locals and publish only after every node report has been validated.
            // This way a failure part-way through leaves no half-populated static state behind
            // that the lazy getters would treat as initialized.
            List<Long> nodePhys = new ArrayList<>(nodeReports.size());
            long totalMem = 0L;
            int totalCores = 0;
            int totalNodes = 0;
            int maxPerCoreMB = -1;
            for (NodeReport nodeReport : nodeReports) {
                Resource capability = nodeReport.getCapability();
                Resource used = nodeReport.getUsed();
                if (used == null) {
                    used = Resource.newInstance(0, 0);
                }
                int memory = capability.getMemory();
                int virtualCores = capability.getVirtualCores();
                if (memory <= 0) {
                    throw new YarnException("A node has non-positive memory " + memory);
                }
                if (virtualCores <= 0) {
                    // Validate before the per-core division below, which would otherwise divide by zero.
                    throw new YarnException("A node has non-positive virtual cores " + virtualCores);
                }
                // Keep the largest per-core memory share across all nodes.
                int perCoreMB = memory / virtualCores / CPU_HYPER_FACTOR;
                if (maxPerCoreMB < perCoreMB) {
                    maxPerCoreMB = perCoreMB;
                }
                // memory is in MB. Multiplying by the long MB constant avoids int overflow for
                // nodes with 2048 MB or more.
                long memoryBytes = memory * MB;
                totalMem += memoryBytes;
                nodePhys.add(memoryBytes);
                totalCores += virtualCores;
                totalNodes++;
                if (z) {
                    System.out.println("\t" + nodeReport.getNodeId() + " has " + memory + " MB (" + used.getMemory() + " MB used) memory and " + capability.getVirtualCores() + " (" + used.getVirtualCores() + " used) cores");
                }
            }
            Collections.sort(nodePhys, Collections.reverseOrder());
            nodesMaxPhySorted = nodePhys;
            clusterTotalMem = totalMem;
            clusterTotalCores = totalCores;
            clusterTotalNodes = totalNodes;
            minimumMRContainerPhyMB = maxPerCoreMB;

            List<Double> nodeBudgets = new ArrayList<>(nodePhys.size());
            for (Long phy : nodePhys) {
                nodeBudgets.add(Double.valueOf(ResourceOptimizer.phyToBudget(phy.longValue())));
            }
            nodesMaxBudgetSorted = nodeBudgets;

            _remotePar = nodeReports.size();
            _remoteMRSortMem = MB * yarnConfiguration.getLong(CommonConfigurationKeysPublic.IO_SORT_MB_KEY, 100L);
            // Map/reduce-specific opts take precedence over the generic task opts.
            String taskOpts = yarnConfiguration.get(JobConf.MAPRED_TASK_JAVA_OPTS);
            String mapOpts = yarnConfiguration.get("mapreduce.map.java.opts", null);
            String reduceOpts = yarnConfiguration.get("mapreduce.reduce.java.opts", null);
            _remoteJVMMaxMemMap = extractMaxMemoryOpt(mapOpts != null ? mapOpts : taskOpts);
            _remoteJVMMaxMemReduce = extractMaxMemoryOpt(reduceOpts != null ? reduceOpts : taskOpts);
            _blocksize = Long.parseLong(yarnConfiguration.get(MRConfigurationNames.DFS_BLOCK_SIZE, "134217728"));
            // MB is a long, so these products are computed in 64 bits. An int product would wrap,
            // e.g. the 8192 MB default maximum would become 0.
            minimalPhyAllocate = MB * yarnConfiguration.getInt(YarnConfiguration.RM_SCHEDULER_MINIMUM_ALLOCATION_MB, 1024);
            maximumPhyAllocate = MB * yarnConfiguration.getInt(YarnConfiguration.RM_SCHEDULER_MAXIMUM_ALLOCATION_MB, 8192);
            mrAMPhy = MB * yarnConfiguration.getInt(MRJobConfig.MR_AM_VMEM_MB, MRJobConfig.DEFAULT_MR_AM_VMEM_MB);
        } catch (Exception e) {
            throw new RuntimeException("Unable to analyze yarn cluster ", e);
        }
    }

    /** Creates, initializes and starts a YARN client with the default configuration. */
    private static YarnClient createYarnClient() {
        YarnConfiguration yarnConfiguration = new YarnConfiguration();
        YarnClient createYarnClient = YarnClient.createYarnClient();
        createYarnClient.init(yarnConfiguration);
        createYarnClient.start();
        return createYarnClient;
    }

    // Must stay textually after the _localPar/_localJVMMaxMem initializers. Static initializers
    // run in source order, so placing this earlier would let the "-1" initializers overwrite the
    // local-machine values.
    static {
        analyzeLocalMachine();
    }
}
