package org.apache.sysml.hops.ipa;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
import org.apache.sysml.udf.lib.DeNaNWrapper;
import org.apache.sysml.udf.lib.DeNegInfinityWrapper;
import org.apache.sysml.udf.lib.DynamicReadMatrixCP;
import org.apache.sysml.udf.lib.DynamicReadMatrixRcCP;
import org.apache.sysml.udf.lib.OrderWrapper;

/* loaded from: input_file:org/apache/sysml/hops/ipa/InterProceduralAnalysis.class */
public class InterProceduralAnalysis {
    /** Local debug switch; not read anywhere in the visible part of this class. */
    private static final boolean LDEBUG = false;
    /** Logger for IPA diagnostics (candidate lists, recompile-once flags). */
    private static final Log LOG = LogFactory.getLog(InterProceduralAnalysis.class.getName());
    // Feature switches. All are compile-time constants set to true, and none is referenced by
    // name in the visible methods (presumably javac inlined them before this class was
    // decompiled). Changing them here therefore likely has no effect on the code as shown.
    // TODO confirm. Their names appear to correspond to the steps of analyzeProgram: intra-
    // procedural propagation, UDF statistics, multiple calls, unused-function removal,
    // recompile-once flagging, checkpoint removal and constant binary-op removal.
    private static final boolean INTRA_PROCEDURAL_ANALYSIS = true;
    private static final boolean PROPAGATE_KNOWN_UDF_STATISTICS = true;
    private static final boolean ALLOW_MULTIPLE_FUNCTION_CALLS = true;
    private static final boolean REMOVE_UNUSED_FUNCTIONS = true;
    private static final boolean FLAG_FUNCTION_RECOMPILE_ONCE = true;
    private static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true;
    private static final boolean REMOVE_CONSTANT_BINARY_OPS = true;

    /**
     * Runs inter-procedural analysis (IPA) over a whole DML program, modifying it in place.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>If the program defines functions: collect the function calls reachable from the main
     *       program, remember every reachable function key, and prune candidates that are called
     *       with differing input statistics. Then record which call inputs have known nnz.</li>
     *   <li>Propagate size statistics across every top-level statement block, and into the
     *       surviving candidate functions. This happens even if there are no candidates.</li>
     *   <li>Remove functions that were not reached from the main program.</li>
     *   <li>Flag functions for recompile-once.</li>
     *   <li>In Spark execution mode only, remove unnecessary checkpoints.</li>
     *   <li>Remove constant binary operations.</li>
     * </ol>
     *
     * @param dMLProgram the program to analyze
     * @throws HopsException if statistics propagation or a rewrite step fails
     * @throws ParseException propagated from candidate collection / propagation
     * @throws LanguageException propagated from function removal / flagging
     */
    public void analyzeProgram(DMLProgram dMLProgram) throws HopsException, ParseException, LanguageException {
        Map<String, Integer> callCounts = new HashMap<String, Integer>();
        Map<String, FunctionOp> firstCalls = new HashMap<String, FunctionOp>();
        Map<String, Set<Long>> nnzSafeInputs = new HashMap<String, Set<Long>>();
        Set<String> reachableFunctions = new HashSet<String>();

        if (dMLProgram.getFunctionStatementBlocks().size() > 0) {
            for (StatementBlock sb : dMLProgram.getStatementBlocks()) {
                getFunctionCandidatesForStatisticPropagation(sb, callCounts, firstCalls);
            }
            // Snapshot before pruning. Every reachable function must survive
            // removeUnusedFunctions, even if it is not a propagation candidate.
            reachableFunctions.addAll(callCounts.keySet());
            pruneFunctionCandidatesForStatisticPropagation(callCounts, firstCalls);
            determineFunctionCandidatesNNZPropagation(firstCalls, nnzSafeInputs);
            DMLTranslator.resetHopsDAGVisitStatus(dMLProgram);
        }

        // Statistics are always propagated across the main program, even without candidates.
        LocalVariableMap callVars = new LocalVariableMap();
        for (StatementBlock sb : dMLProgram.getStatementBlocks()) {
            propagateStatisticsAcrossBlock(sb, callCounts.keySet(), callVars, nnzSafeInputs, new HashSet<String>());
        }

        removeUnusedFunctions(dMLProgram, reachableFunctions);
        flagFunctionsForRecompileOnce(dMLProgram);
        if (OptimizerUtils.isSparkExecutionMode()) {
            removeUnnecessaryCheckpoints(dMLProgram);
        }
        removeConstantBinaryOps(dMLProgram);
    }

    /**
     * Runs inter-procedural statistics propagation for a single statement block.
     *
     * <p>The method first collects the function calls reachable from {@code statementBlock}. It
     * then prunes functions that are called with differing input statistics and records which
     * call inputs have known nnz. If any candidates remain, statistics are propagated across the
     * block and into those functions. The block's hop DAG visit status is reset before and after
     * candidate collection.
     *
     * @param statementBlock the block to analyze; its hops are updated in place
     * @return keys of the function candidates that survived pruning (a live view of an internal
     *     map's key set)
     * @throws HopsException if statistics propagation fails
     * @throws ParseException propagated from candidate collection / propagation
     */
    public Set<String> analyzeSubProgram(StatementBlock statementBlock) throws HopsException, ParseException {
        DMLTranslator.resetHopsDAGVisitStatus(statementBlock);
        Map<String, Integer> callCounts = new HashMap<String, Integer>();
        Map<String, FunctionOp> firstCalls = new HashMap<String, FunctionOp>();
        Map<String, Set<Long>> nnzSafeInputs = new HashMap<String, Set<Long>>();

        getFunctionCandidatesForStatisticPropagation(statementBlock, callCounts, firstCalls);
        pruneFunctionCandidatesForStatisticPropagation(callCounts, firstCalls);
        determineFunctionCandidatesNNZPropagation(firstCalls, nnzSafeInputs);
        DMLTranslator.resetHopsDAGVisitStatus(statementBlock);

        if (!callCounts.isEmpty()) {
            propagateStatisticsAcrossBlock(statementBlock, callCounts.keySet(), new LocalVariableMap(), nnzSafeInputs, new HashSet<String>());
        }
        return callCounts.keySet();
    }

    /**
     * Collects function-call candidates for statistics propagation that are reachable from a
     * statement block.
     *
     * <p>Function, while, if (both branches) and for blocks are descended recursively. For any
     * other block, each root of its hop DAG is scanned for function calls.
     *
     * @param statementBlock block to scan
     * @param map function key to call count (1 plus the number of later calls whose input
     *     statistics differ from the first call)
     * @param map2 function key to the first call site encountered
     */
    private void getFunctionCandidatesForStatisticPropagation(StatementBlock statementBlock, Map<String, Integer> map, Map<String, FunctionOp> map2) throws HopsException, ParseException {
        if (statementBlock instanceof FunctionStatementBlock) {
            FunctionStatementBlock fsb = (FunctionStatementBlock) statementBlock;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            for (StatementBlock child : fstmt.getBody()) {
                getFunctionCandidatesForStatisticPropagation(child, map, map2);
            }
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) statementBlock;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            for (StatementBlock child : wstmt.getBody()) {
                getFunctionCandidatesForStatisticPropagation(child, map, map2);
            }
        } else if (statementBlock instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) statementBlock;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            for (StatementBlock child : istmt.getIfBody()) {
                getFunctionCandidatesForStatisticPropagation(child, map, map2);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                getFunctionCandidatesForStatisticPropagation(child, map, map2);
            }
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatementBlock forSb = (ForStatementBlock) statementBlock;
            ForStatement forStmt = (ForStatement) forSb.getStatement(0);
            for (StatementBlock child : forStmt.getBody()) {
                getFunctionCandidatesForStatisticPropagation(child, map, map2);
            }
        } else {
            // last-level block: scan its hop DAG roots (a block may have no hops)
            ArrayList<Hop> roots = statementBlock.get_hops();
            if (roots == null) {
                return;
            }
            for (Hop root : roots) {
                getFunctionCandidatesForStatisticPropagation(statementBlock.getDMLProg(), root, map, map2);
            }
        }
    }

    /**
     * Scans a hop DAG for calls to functions outside {@code DMLProgram.INTERNAL_NAMESPACE} and
     * records them as statistics-propagation candidates. Each hop is processed once, guarded by
     * its visit status. Inputs are visited after the hop itself.
     *
     * <p>On the first call of a function, the method registers it with count 1, remembers the
     * call site, and descends into the function body. Every later call is compared input by input
     * with that first call. Inputs match only if both have known dims, equal dims and equal nnz,
     * and, when the first call's input is a literal, an equal literal value. If any input differs,
     * the function's count is incremented.
     *
     * @param dMLProgram program used to look up function bodies
     * @param hop current hop
     * @param map function key to call count
     * @param map2 function key to first call site
     */
    private void getFunctionCandidatesForStatisticPropagation(DMLProgram dMLProgram, Hop hop, Map<String, Integer> map, Map<String, FunctionOp> map2) throws HopsException, ParseException {
        if (hop.getVisited() == Hop.VisitStatus.DONE) {
            return;
        }
        boolean userCall = (hop instanceof FunctionOp)
                && !((FunctionOp) hop).getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE);
        if (userCall) {
            FunctionOp call = (FunctionOp) hop;
            String fkey = DMLProgram.constructFunctionKey(call.getFunctionNamespace(), call.getFunctionName());
            if (!map.containsKey(fkey)) {
                // first call site: register and descend into the callee
                map.put(fkey, 1);
                map2.put(fkey, call);
                FunctionStatementBlock callee = dMLProgram.getFunctionStatementBlock(call.getFunctionNamespace(), call.getFunctionName());
                getFunctionCandidatesForStatisticPropagation(callee, map, map2);
            } else {
                FunctionOp firstCall = map2.get(fkey);
                int numInputs = firstCall.getInput().size();
                boolean sameStats = true;
                for (int idx = 0; idx < numInputs; idx++) {
                    Hop prev = firstCall.getInput().get(idx);
                    Hop cur = call.getInput().get(idx);
                    boolean sameDims = prev.dimsKnown() && cur.dimsKnown()
                            && prev.getDim1() == cur.getDim1()
                            && prev.getDim2() == cur.getDim2()
                            && prev.getNnz() == cur.getNnz();
                    boolean sameLiteral = !(prev instanceof LiteralOp)
                            || ((cur instanceof LiteralOp) && HopRewriteUtils.isEqualValue((LiteralOp) prev, (LiteralOp) cur));
                    sameStats = sameStats && sameDims && sameLiteral;
                }
                if (!sameStats) {
                    map.put(fkey, map.get(fkey) + 1);
                }
            }
        }
        for (Hop input : hop.getInput()) {
            getFunctionCandidatesForStatisticPropagation(dMLProgram, input, map, map2);
        }
        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Removes every candidate whose call count is greater than 1, i.e. functions called more than
     * once with differing input statistics. The candidate list is logged before and after pruning
     * when debug logging is enabled.
     *
     * <p>Removal goes through a typed entry-set iterator. The previous raw
     * {@code for (String str : new HashSet(map.keySet()))} loop does not compile, because a raw
     * {@code HashSet} yields {@code Object} elements.
     *
     * @param map function key to call count; pruned in place
     * @param map2 function key to first call site (not modified)
     */
    private void pruneFunctionCandidatesForStatisticPropagation(Map<String, Integer> map, Map<String, FunctionOp> map2) {
        if (LOG.isDebugEnabled()) {
            for (Map.Entry<String, Integer> entry : map.entrySet()) {
                LOG.debug("IPA: FUNC statistic propagation candidate: " + entry.getKey() + ", callCount=" + entry.getValue());
            }
        }
        Iterator<Map.Entry<String, Integer>> entries = map.entrySet().iterator();
        while (entries.hasNext()) {
            Integer callCount = entries.next().getValue();
            if (callCount != null && callCount.intValue() > 1) {
                entries.remove();
            }
        }
        if (LOG.isDebugEnabled()) {
            for (String fkey : map.keySet()) {
                LOG.debug("IPA: FUNC statistic propagation candidate (after pruning): " + fkey);
            }
        }
    }

    /**
     * For each function, records which inputs of its first call site have a known number of
     * non-zeros ({@code getNnz() >= 0}).
     *
     * @param map function key to first call site
     * @param map2 output: function key to the set of input hop IDs with known nnz
     */
    private void determineFunctionCandidatesNNZPropagation(Map<String, FunctionOp> map, Map<String, Set<Long>> map2) {
        for (String fkey : map.keySet()) {
            FunctionOp firstCall = map.get(fkey);
            Set<Long> knownNnzInputs = new HashSet<Long>();
            for (Hop input : firstCall.getInput()) {
                if (input.getNnz() < 0) {
                    continue;
                }
                knownNnzInputs.add(Long.valueOf(input.getHopID()));
            }
            map2.put(fkey, knownNnzInputs);
        }
    }

    /**
     * Propagates size statistics through one statement block, recursing into nested blocks.
     *
     * <ul>
     *   <li>Function block: every body block is visited with the same variable map.</li>
     *   <li>While block: the predicate is refreshed, updated scalars are dropped, and the body is
     *       visited. If {@code Recompiler.reconcileUpdatedCallVarsLoops} reports a change against
     *       the pre-body snapshot, the predicate is refreshed and the body is visited again.</li>
     *   <li>If block: the predicate is refreshed. The if-branch is visited with
     *       {@code localVariableMap} and the else-branch with a separate clone, and both are then
     *       reconciled with {@code Recompiler.reconcileUpdatedCallVarsIf}.</li>
     *   <li>For block: from/to/increment are refreshed, updated scalars are dropped, and the body
     *       is visited, a second time if the loop reconciliation reports a change.</li>
     *   <li>Any other (last-level) block: statistics are updated across its hop DAG, and then
     *       propagated into the functions it calls.</li>
     * </ul>
     *
     * @param statementBlock block to process
     * @param set keys of functions whose bodies may receive propagated statistics
     * @param localVariableMap current variable statistics; updated in place
     * @param map function key to input hop IDs whose nnz may be propagated
     * @param set2 keys of functions currently being propagated into (recursion guard)
     */
    private void propagateStatisticsAcrossBlock(StatementBlock statementBlock, Set<String> set, LocalVariableMap localVariableMap, Map<String, Set<Long>> map, Set<String> set2) throws HopsException, ParseException {
        if (statementBlock instanceof FunctionStatementBlock) {
            FunctionStatement funcStmt = (FunctionStatement) ((FunctionStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock child : funcStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, set, localVariableMap, map, set2);
            }
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) statementBlock;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), localVariableMap);
            Recompiler.removeUpdatedScalars(localVariableMap, wsb);
            LocalVariableMap varsBeforeBody = (LocalVariableMap) localVariableMap.clone();
            for (StatementBlock child : wstmt.getBody()) {
                propagateStatisticsAcrossBlock(child, set, localVariableMap, map, set2);
            }
            if (Recompiler.reconcileUpdatedCallVarsLoops(varsBeforeBody, localVariableMap, wsb)) {
                // statistics changed inside the loop: redo predicate and body with reconciled stats
                propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), localVariableMap);
                for (StatementBlock child : wstmt.getBody()) {
                    propagateStatisticsAcrossBlock(child, set, localVariableMap, map, set2);
                }
            }
            Recompiler.removeUpdatedScalars(localVariableMap, statementBlock);
        } else if (statementBlock instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) statementBlock;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), localVariableMap);
            LocalVariableMap varsBeforeIf = (LocalVariableMap) localVariableMap.clone();
            LocalVariableMap varsElse = (LocalVariableMap) localVariableMap.clone();
            for (StatementBlock child : istmt.getIfBody()) {
                propagateStatisticsAcrossBlock(child, set, localVariableMap, map, set2);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                propagateStatisticsAcrossBlock(child, set, varsElse, map, set2);
            }
            Recompiler.removeUpdatedScalars(
                    Recompiler.reconcileUpdatedCallVarsIf(varsBeforeIf, localVariableMap, varsElse, isb),
                    statementBlock);
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatementBlock forSb = (ForStatementBlock) statementBlock;
            ForStatement forStmt = (ForStatement) forSb.getStatement(0);
            propagateStatisticsAcrossPredicateDAG(forSb.getFromHops(), localVariableMap);
            propagateStatisticsAcrossPredicateDAG(forSb.getToHops(), localVariableMap);
            propagateStatisticsAcrossPredicateDAG(forSb.getIncrementHops(), localVariableMap);
            Recompiler.removeUpdatedScalars(localVariableMap, forSb);
            LocalVariableMap varsBeforeBody = (LocalVariableMap) localVariableMap.clone();
            for (StatementBlock child : forStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, set, localVariableMap, map, set2);
            }
            if (Recompiler.reconcileUpdatedCallVarsLoops(varsBeforeBody, localVariableMap, forSb)) {
                for (StatementBlock child : forStmt.getBody()) {
                    propagateStatisticsAcrossBlock(child, set, localVariableMap, map, set2);
                }
            }
            Recompiler.removeUpdatedScalars(localVariableMap, statementBlock);
        } else {
            // last-level block: refresh DAG statistics, then push them into called functions
            Recompiler.removeUpdatedScalars(localVariableMap, statementBlock);
            ArrayList<Hop> roots = statementBlock.get_hops();
            DMLProgram prog = statementBlock.getDMLProg();
            Hop.resetVisitStatus(roots);
            propagateStatisticsAcrossDAG(roots, localVariableMap);
            Hop.resetVisitStatus(roots);
            propagateStatisticsIntoFunctions(prog, roots, set, localVariableMap, map, set2);
        }
    }

    /**
     * Refreshes the size statistics of a single predicate (or loop-bound) hop DAG from the
     * current variable statistics.
     *
     * @param hop root of the DAG; {@code null} means there is nothing to refresh
     * @param localVariableMap current variable statistics
     * @throws HopsException wrapping any failure of {@code Recompiler.rUpdateStatistics}
     */
    private void propagateStatisticsAcrossPredicateDAG(Hop hop, LocalVariableMap localVariableMap) throws HopsException {
        if (hop != null) {
            hop.resetVisitStatus();
            try {
                Recompiler.rUpdateStatistics(hop, localVariableMap);
            } catch (Exception e) {
                throw new HopsException("Failed to update Hop DAG statistics.", e);
            }
        }
    }

    /**
     * Refreshes the statistics of every root of a hop DAG. The resulting output statistics are
     * then written back into the variable map via {@code Recompiler.extractDAGOutputStatistics}.
     *
     * @param arrayList DAG roots; {@code null} means there is nothing to refresh
     * @param localVariableMap current variable statistics; updated with DAG outputs
     * @throws HopsException wrapping any failure during update or extraction
     */
    private void propagateStatisticsAcrossDAG(ArrayList<Hop> arrayList, LocalVariableMap localVariableMap) throws HopsException {
        if (arrayList != null) {
            try {
                for (Hop root : arrayList) {
                    Recompiler.rUpdateStatistics(root, localVariableMap);
                }
                Recompiler.extractDAGOutputStatistics(arrayList, localVariableMap, true);
            } catch (Exception e) {
                throw new HopsException("Failed to update Hop DAG statistics.", e);
            }
        }
    }

    /**
     * Propagates statistics into the functions called from a list of hop DAG roots.
     *
     * <p>A {@code null} root list is treated as empty. This is consistent with
     * {@code propagateStatisticsAcrossDAG}, which receives the same {@code get_hops()} result
     * immediately before this call. Previously a block without hops caused a
     * {@link NullPointerException} here.
     *
     * @param dMLProgram program used to look up function bodies
     * @param arrayList DAG roots, may be {@code null}
     * @param set keys of functions whose bodies may receive propagated statistics
     * @param localVariableMap caller variable statistics; updated with call outputs
     * @param map function key to input hop IDs whose nnz may be propagated
     * @param set2 keys of functions currently being propagated into (recursion guard)
     */
    private void propagateStatisticsIntoFunctions(DMLProgram dMLProgram, ArrayList<Hop> arrayList, Set<String> set, LocalVariableMap localVariableMap, Map<String, Set<Long>> map, Set<String> set2) throws HopsException, ParseException {
        if (arrayList == null) {
            return;
        }
        for (Hop root : arrayList) {
            propagateStatisticsIntoFunctions(dMLProgram, root, set, localVariableMap, map, set2);
        }
    }

    /**
     * Walks a hop DAG in post-order (each hop once, via its visit status) and handles every
     * function call it contains.
     *
     * <ul>
     *   <li>DML function that is a candidate in {@code set} and not already in {@code set2}: build
     *       a fresh variable map for the call, propagate statistics through the function body,
     *       and copy the output statistics back to the caller. The key is held in {@code set2}
     *       while this runs, which guards against recursive calls.</li>
     *   <li>Any other DML function: mark its matrix outputs as having unknown statistics.</li>
     *   <li>External function (file or in-memory): derive the output statistics via
     *       {@code extractExternalFunctionCallReturnStatistics}.</li>
     * </ul>
     */
    private void propagateStatisticsIntoFunctions(DMLProgram dMLProgram, Hop hop, Set<String> set, LocalVariableMap localVariableMap, Map<String, Set<Long>> map, Set<String> set2) throws HopsException, ParseException {
        if (hop.getVisited() == Hop.VisitStatus.DONE) {
            return;
        }
        for (Hop input : hop.getInput()) {
            propagateStatisticsIntoFunctions(dMLProgram, input, set, localVariableMap, map, set2);
        }
        if (hop instanceof FunctionOp) {
            FunctionOp call = (FunctionOp) hop;
            String fkey = DMLProgram.constructFunctionKey(call.getFunctionNamespace(), call.getFunctionName());
            FunctionOp.FunctionType ftype = call.getFunctionType();
            if (ftype == FunctionOp.FunctionType.DML) {
                FunctionStatementBlock calleeBlock = dMLProgram.getFunctionStatementBlock(call.getFunctionNamespace(), call.getFunctionName());
                FunctionStatement calleeStmt = (FunctionStatement) calleeBlock.getStatement(0);
                boolean propagateInto = set.contains(fkey) && !set2.contains(fkey);
                if (propagateInto) {
                    set2.add(fkey);
                    LocalVariableMap calleeVars = new LocalVariableMap();
                    populateLocalVariableMapForFunctionCall(calleeStmt, call, calleeVars, map.get(fkey));
                    propagateStatisticsAcrossBlock(calleeBlock, set, calleeVars, map, set2);
                    extractFunctionCallReturnStatistics(calleeStmt, call, calleeVars, localVariableMap, true);
                    set2.remove(fkey);
                } else {
                    extractFunctionCallUnknownReturnStatistics(calleeStmt, call, localVariableMap);
                }
            } else if (ftype == FunctionOp.FunctionType.EXTERNAL_FILE || ftype == FunctionOp.FunctionType.EXTERNAL_MEM) {
                ExternalFunctionStatement extStmt = (ExternalFunctionStatement) dMLProgram
                        .getFunctionStatementBlock(call.getFunctionNamespace(), call.getFunctionName())
                        .getStatement(0);
                extractExternalFunctionCallReturnStatistics(extStmt, call, localVariableMap);
            }
        }
        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Code restructure failed: missing block: B:23:0x012c, code lost:
    
        r15.put(r0.getName(), r23);
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * NOTE(review): decompilation stub. The body does not contain the original logic; it
     * unconditionally throws {@link UnsupportedOperationException}.
     *
     * <p>propagateStatisticsIntoFunctions calls this for every DML function call whose key is a
     * surviving propagation candidate and is not already being propagated into. Reaching that
     * path therefore aborts the analysis with this exception, because none of the visible callers
     * catch it. This method must be restored from the original source.
     *
     * <p>Presumably (judging from the name, the parameters and the call site, so TODO confirm),
     * it fills {@code r15} with callee-side variables for the call's inputs. It would use
     * {@code r16} to decide whether an input's nnz may be propagated. The JADX note above shows
     * the lost code performed {@code r15.put(<param>.getName(), ...)}.
     *
     * @param r13 the called function's statement
     * @param r14 the call site
     * @param r15 fresh variable map to populate for the function body
     * @param r16 hop IDs of the first call's inputs whose nnz is known ({@code getNnz() >= 0})
     * @throws HopsException declared by the signature
     */
    private void populateLocalVariableMapForFunctionCall(org.apache.sysml.parser.FunctionStatement r13, org.apache.sysml.hops.FunctionOp r14, org.apache.sysml.runtime.controlprogram.LocalVariableMap r15, java.util.Set<java.lang.Long> r16) throws org.apache.sysml.hops.HopsException {
        /*
            Method dump skipped, instructions count: 318
            To view this dump add '--comments-level debug' option
        */
        // NOTE(review): always throws; see the method comment above.
        throw new UnsupportedOperationException("Method not decompiled: org.apache.sysml.hops.ipa.InterProceduralAnalysis.populateLocalVariableMapForFunctionCall(org.apache.sysml.parser.FunctionStatement, org.apache.sysml.hops.FunctionOp, org.apache.sysml.runtime.controlprogram.LocalVariableMap, java.util.Set):void");
    }

    /**
     * Copies the statistics of a DML function's matrix outputs back into the caller's variable map.
     *
     * <p>A formal output parameter is considered only if it is a matrix and has an entry in
     * {@code localVariableMap}. For each such output, the caller variable bound at the same
     * position in {@code functionOp.getOutputVariableNames()} is updated as follows:
     * <ul>
     *   <li>If the caller has no entry for it, or {@code z} is {@code true}, a new matrix object
     *       with the function output's dims and nnz is stored.</li>
     *   <li>Otherwise, the existing caller statistics are overwritten only if their size estimate
     *       (using their sparsity) is smaller than the dense size estimate of the new dims.</li>
     * </ul>
     *
     * <p>The existing sparsity is computed in floating point. It was previously computed with
     * long division, {@code (nnz / rows) / cols}, which truncates to 0 for any matrix that is not
     * fully dense.
     *
     * @param functionStatement the called function (provides the formal outputs)
     * @param functionOp the call site (provides the caller-side output variable names)
     * @param localVariableMap variables at the end of the function body
     * @param localVariableMap2 caller variables to update
     * @param z if {@code true}, always overwrite the caller statistics
     * @throws HopsException wrapping any failure while handling a single output
     */
    private void extractFunctionCallReturnStatistics(FunctionStatement functionStatement, FunctionOp functionOp, LocalVariableMap localVariableMap, LocalVariableMap localVariableMap2, boolean z) throws HopsException {
        ArrayList<DataIdentifier> formalOutputs = functionStatement.getOutputParams();
        String[] callerOutputs = functionOp.getOutputVariableNames();
        String fkey = DMLProgram.constructFunctionKey(functionOp.getFunctionNamespace(), functionOp.getFunctionName());
        for (int i = 0; i < formalOutputs.size(); i++) {
            try {
                DataIdentifier formal = formalOutputs.get(i);
                String innerName = formal.getName();
                String callerName = callerOutputs[i];
                if (formal.getDataType() != Expression.DataType.MATRIX || !localVariableMap.keySet().contains(innerName)) {
                    continue;
                }
                MatrixObject inner = (MatrixObject) localVariableMap.get(innerName);
                if (!localVariableMap2.keySet().contains(callerName) || z) {
                    localVariableMap2.put(callerName, createOutputMatrix(inner.getNumRows(), inner.getNumColumns(), inner.getNnz()));
                } else {
                    Data existing = localVariableMap2.get(callerName);
                    if (existing instanceof MatrixObject) {
                        MatrixCharacteristics mc = ((MatrixObject) existing).getMatrixCharacteristics();
                        // floating-point division: long division would truncate sparsity to 0
                        double oldSparsity = (mc.getNonZeros() > 0)
                                ? ((double) mc.getNonZeros()) / mc.getRows() / mc.getCols()
                                : 1.0d;
                        if (OptimizerUtils.estimateSizeExactSparsity(mc.getRows(), mc.getCols(), oldSparsity)
                                < OptimizerUtils.estimateSize(inner.getNumRows(), inner.getNumColumns())) {
                            mc.setDimension(inner.getNumRows(), inner.getNumColumns());
                            mc.setNonZeros(inner.getNnz());
                        }
                    }
                }
            } catch (Exception e) {
                throw new HopsException("Failed to extract output statistics of function " + fkey + ".", e);
            }
        }
    }

    /**
     * Marks every matrix output of a function call as having unknown statistics. Each such output
     * is stored in the caller map as a matrix with dims and nnz set to -1.
     *
     * @param functionStatement the called function (provides the formal outputs)
     * @param functionOp the call site (provides the caller-side output variable names)
     * @param localVariableMap caller variables to update
     * @throws HopsException wrapping any failure while handling a single output
     */
    private void extractFunctionCallUnknownReturnStatistics(FunctionStatement functionStatement, FunctionOp functionOp, LocalVariableMap localVariableMap) throws HopsException {
        ArrayList<DataIdentifier> formals = functionStatement.getOutputParams();
        String[] callerNames = functionOp.getOutputVariableNames();
        String fkey = DMLProgram.constructFunctionKey(functionOp.getFunctionNamespace(), functionOp.getFunctionName());
        int idx = 0;
        while (idx < formals.size()) {
            try {
                DataIdentifier formal = formals.get(idx);
                String callerName = callerNames[idx];
                boolean isMatrix = formal.getDataType() == Expression.DataType.MATRIX;
                if (isMatrix) {
                    localVariableMap.put(callerName, createOutputMatrix(-1L, -1L, -1L));
                }
            } catch (Exception e) {
                throw new HopsException("Failed to extract output statistics of function " + fkey + ".", e);
            }
            idx++;
        }
    }

    /**
     * Derives output statistics for calls of known external UDF classes, identified by the
     * statement's class-name parameter.
     *
     * <ul>
     *   <li>Order / DeNaN / DeNegInfinity wrappers: the output takes the dims of input 0. The nnz
     *       of input 0 is kept only for the order wrapper; otherwise it is unknown.</li>
     *   <li>Eigen wrapper: the outputs are (rows of input 0) x 1 and
     *       (rows of input 0) x (rows of input 0), both with unknown nnz.</li>
     *   <li>Linear-solver wrapper: the output is (rows of input 1) x 1 with unknown nnz.</li>
     *   <li>Dynamic-read wrappers: the output dims come from inputs 1 and 2, but only when both
     *       are literals; otherwise the caller map is left untouched.</li>
     *   <li>Any other class: all matrix outputs get unknown statistics.</li>
     * </ul>
     */
    private void extractExternalFunctionCallReturnStatistics(ExternalFunctionStatement externalFunctionStatement, FunctionOp functionOp, LocalVariableMap localVariableMap) throws HopsException {
        String className = externalFunctionStatement.getOtherParams().get(ExternalFunctionStatement.CLASS_NAME);

        boolean isOrder = className.equals(OrderWrapper.class.getName());
        if (isOrder
                || className.equals(DeNaNWrapper.class.getCanonicalName())
                || className.equals(DeNegInfinityWrapper.class.getCanonicalName())) {
            Hop first = functionOp.getInput().get(0);
            long nnz = isOrder ? first.getNnz() : -1L;
            localVariableMap.put(functionOp.getOutputVariableNames()[0], createOutputMatrix(first.getDim1(), first.getDim2(), nnz));
            return;
        }

        if (className.equals("org.apache.sysml.udf.lib.EigenWrapper")) {
            Hop first = functionOp.getInput().get(0);
            String[] outs = functionOp.getOutputVariableNames();
            localVariableMap.put(outs[0], createOutputMatrix(first.getDim1(), 1L, -1L));
            localVariableMap.put(functionOp.getOutputVariableNames()[1], createOutputMatrix(first.getDim1(), first.getDim1(), -1L));
            return;
        }

        if (className.equals("org.apache.sysml.udf.lib.LinearSolverWrapperCP")) {
            long rows = functionOp.getInput().get(1).getDim1();
            localVariableMap.put(functionOp.getOutputVariableNames()[0], createOutputMatrix(rows, 1L, -1L));
            return;
        }

        boolean isDynamicRead = className.equals(DynamicReadMatrixCP.class.getName())
                || className.equals(DynamicReadMatrixRcCP.class.getName());
        if (!isDynamicRead) {
            extractFunctionCallUnknownReturnStatistics(externalFunctionStatement, functionOp, localVariableMap);
            return;
        }

        Hop rowsHop = functionOp.getInput().get(1);
        Hop colsHop = functionOp.getInput().get(2);
        if (rowsHop instanceof LiteralOp && colsHop instanceof LiteralOp) {
            long rows = ((LiteralOp) rowsHop).getLongValue();
            long cols = ((LiteralOp) colsHop).getLongValue();
            localVariableMap.put(functionOp.getOutputVariableNames()[0], createOutputMatrix(rows, cols, -1L));
        }
    }

    /**
     * Creates a double-valued matrix object that carries only metadata. The block size is
     * {@code DMLTranslator.DMLBlockSize} in both dimensions.
     *
     * @param j number of rows (-1 if unknown)
     * @param j2 number of columns (-1 if unknown)
     * @param j3 number of non-zeros (-1 if unknown)
     * @return the new matrix object
     */
    private MatrixObject createOutputMatrix(long j, long j2, long j3) {
        MatrixObject out = new MatrixObject(Expression.ValueType.DOUBLE, null);
        int blen = DMLTranslator.DMLBlockSize;
        MatrixCharacteristics mc = new MatrixCharacteristics(j, j2, blen, blen, j3);
        out.setMetaData(new MatrixFormatMetaData(mc, null, null));
        return out;
    }

    /**
     * Removes every function statement block whose key (namespace plus function name) is not in
     * {@code set}.
     *
     * @param dMLProgram program whose function blocks are pruned in place
     * @param set keys of the functions to keep
     * @throws LanguageException propagated from the function-block lookup
     */
    public void removeUnusedFunctions(DMLProgram dMLProgram, Set<String> set) throws LanguageException {
        for (String namespace : dMLProgram.getNamespaces().keySet()) {
            Iterator<String> fnames = dMLProgram.getFunctionStatementBlocks(namespace).keySet().iterator();
            while (fnames.hasNext()) {
                boolean keep = set.contains(DMLProgram.constructFunctionKey(namespace, fnames.next()));
                if (!keep) {
                    fnames.remove();
                }
            }
        }
    }

    /**
     * Marks every function for recompile-once when {@code rFlagFunctionForRecompileOnce} returns
     * {@code true} for it, and logs each flagged function at debug level.
     *
     * @param dMLProgram program whose function blocks are flagged in place
     * @throws LanguageException propagated from the function-block lookup
     */
    public void flagFunctionsForRecompileOnce(DMLProgram dMLProgram) throws LanguageException {
        for (String namespace : dMLProgram.getNamespaces().keySet()) {
            for (String fname : dMLProgram.getFunctionStatementBlocks(namespace).keySet()) {
                FunctionStatementBlock fsb = dMLProgram.getFunctionStatementBlock(namespace, fname);
                if (!rFlagFunctionForRecompileOnce(fsb, false)) {
                    continue;
                }
                fsb.setRecompileOnce(true);
                LOG.debug("IPA: FUNC flagged for recompile-once: " + DMLProgram.constructFunctionKey(namespace, fname));
            }
        }
    }

    /**
     * Decides whether a block, or any block nested in it, warrants recompile-once.
     *
     * <ul>
     *   <li>While and for blocks always return {@code true}.</li>
     *   <li>Function blocks return {@code true} if any body block does.</li>
     *   <li>If blocks return {@code true} if any branch block does, or if {@code z} is set and the
     *       predicate requires recompilation.</li>
     *   <li>Last-level blocks return {@code true} only if {@code z} is set and the block requires
     *       recompilation.</li>
     * </ul>
     * Every nested block is always visited (no short-circuit across children).
     *
     * @param statementBlock block to inspect
     * @param z whether recompilation requirements of if-predicates and last-level blocks count
     * @return {@code true} if the block should trigger recompile-once
     */
    public boolean rFlagFunctionForRecompileOnce(StatementBlock statementBlock, boolean z) {
        if (statementBlock instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) statementBlock).getStatement(0);
            boolean anyChild = false;
            for (StatementBlock child : fstmt.getBody()) {
                anyChild |= rFlagFunctionForRecompileOnce(child, z);
            }
            return anyChild;
        }
        if (statementBlock instanceof WhileStatementBlock) {
            return true;
        }
        if (statementBlock instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) statementBlock;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            boolean result = z && isb.requiresPredicateRecompilation();
            for (StatementBlock child : istmt.getIfBody()) {
                result |= rFlagFunctionForRecompileOnce(child, z);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                result |= rFlagFunctionForRecompileOnce(child, z);
            }
            return result;
        }
        if (statementBlock instanceof ForStatementBlock) {
            return true;
        }
        return z && statementBlock.requiresRecompilation();
    }

    /**
     * Clears the checkpoint flag on checkpoint hops that are superseded by a later checkpoint of
     * the same variable before anything reads them. Only the main program's top-level statement
     * blocks are scanned, in order. Called by analyzeProgram in Spark execution mode.
     *
     * <p>The tracking state maps a variable name to the most recent checkpoint hop collected for
     * it (see collectCheckpoints). For each block:
     * <ol>
     *   <li>Reads. A tracked variable that the block reads but does not update stops being
     *       tracked, unless the block has hops and none of its roots contains a read of it
     *       ({@code HopRewriteUtils.rContainsRead}). A real read therefore keeps the checkpoint
     *       flag set.</li>
     *   <li>Updates. A tracked variable updated by an if/while/for block stops being tracked. One
     *       updated by a last-level block stops being tracked if a root hop of that name fails
     *       {@code HopRewriteUtils.rHasSimpleReadChain}.</li>
     *   <li>New checkpoints. Each checkpoint in the block replaces the tracked entry for its
     *       variable. The replaced hop, never read in between, gets
     *       {@code setRequiresCheckpoint(false)}.</li>
     * </ol>
     *
     * <p>The map and key-set copies are now properly typed. The previous raw
     * {@code for (String str : new HashSet(hashMap.keySet()))} did not compile.
     *
     * @param dMLProgram program whose hops are modified in place
     * @throws HopsException propagated from the hop utilities
     */
    private void removeUnnecessaryCheckpoints(DMLProgram dMLProgram) throws HopsException {
        HashMap<String, Hop> pending = new HashMap<String, Hop>();
        for (StatementBlock sb : dMLProgram.getStatementBlocks()) {
            // (1) reads: a real read of a tracked variable makes its checkpoint necessary
            for (String var : new HashSet<String>(pending.keySet())) {
                if (sb.variablesRead().containsVariable(var) && !sb.variablesUpdated().containsVariable(var)) {
                    boolean noHopReads = false;
                    if (sb.get_hops() != null) {
                        Hop.resetVisitStatus(sb.get_hops());
                        noHopReads = true;
                        for (Hop root : sb.get_hops()) {
                            noHopReads &= !HopRewriteUtils.rContainsRead(root, var, false);
                        }
                    }
                    if (!noHopReads) {
                        pending.remove(var);
                    }
                }
            }

            // (2) updates: stop tracking overwritten variables
            HashSet<String> trackedVars = new HashSet<String>(pending.keySet());
            boolean controlFlow = (sb instanceof IfStatementBlock) || (sb instanceof WhileStatementBlock) || (sb instanceof ForStatementBlock);
            if (controlFlow) {
                for (String var : trackedVars) {
                    if (sb.variablesUpdated().containsVariable(var)) {
                        pending.remove(var);
                    }
                }
            } else {
                for (String var : trackedVars) {
                    if (sb.variablesUpdated().containsVariable(var) && sb.get_hops() != null) {
                        ArrayList<Hop> roots = sb.get_hops();
                        Hop.resetVisitStatus(roots);
                        for (Hop root : roots) {
                            if (root.getName().equals(var) && !HopRewriteUtils.rHasSimpleReadChain(root, var)) {
                                pending.remove(var);
                            }
                        }
                    }
                }
            }

            // (3) new checkpoints: a still-tracked predecessor was never read, so drop its flag
            for (Hop ckpt : collectCheckpoints(sb.get_hops())) {
                if (pending.containsKey(ckpt.getName())) {
                    pending.get(ckpt.getName()).setRequiresCheckpoint(false);
                }
                pending.put(ckpt.getName(), ckpt);
            }
        }
    }

    /**
     * Gathers every hop reachable from the given DAG roots that requires a
     * checkpoint and whose only parent is a transient write.
     *
     * @param arrayList DAG roots of a statement block; may be {@code null}
     * @return the collected hops, or an empty list if {@code arrayList} is {@code null}
     */
    private ArrayList<Hop> collectCheckpoints(ArrayList<Hop> arrayList) {
        ArrayList<Hop> checkpoints = new ArrayList<>();
        if (arrayList == null) {
            return checkpoints;
        }
        // Clear visit flags so the recursive walk visits each shared hop exactly once.
        Hop.resetVisitStatus(arrayList);
        for (Hop root : arrayList) {
            rCollectCheckpoints(root, checkpoints);
        }
        return checkpoints;
    }

    /**
     * Recursively walks the DAG below {@code hop} and appends each not-yet-visited
     * hop that requires a checkpoint and has exactly one parent, where that parent is
     * a TRANSIENTWRITE {@link DataOp}.
     *
     * @param hop       current hop; skipped if already marked DONE
     * @param arrayList output list the matching hops are appended to
     */
    private void rCollectCheckpoints(Hop hop, ArrayList<Hop> arrayList) {
        if (hop.getVisited() == Hop.VisitStatus.DONE) {
            return;
        }
        if (hop.requiresCheckpoint() && hop.getParent().size() == 1) {
            Hop consumer = hop.getParent().get(0);
            boolean consumedByTransientWrite = consumer instanceof DataOp
                    && ((DataOp) consumer).getDataOpType() == Hop.DataOpTypes.TRANSIENTWRITE;
            if (consumedByTransientWrite) {
                arrayList.add(hop);
            }
        }
        for (Hop input : hop.getInput()) {
            rCollectCheckpoints(input, arrayList);
        }
        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Scans the program's statement blocks in order. Variables assigned a constant-1
     * {@link DataGenOp} are tracked, and cell-wise multiplications by such a variable
     * are rewritten via {@link #rRemoveConstantBinaryOp(StatementBlock, HashMap)}.
     *
     * <p>Before a block is processed, every variable it updates is dropped from the
     * tracked set, so only unmodified values are rewritten. New candidates are
     * collected only from blocks that are not if/while/for blocks.
     *
     * @param dMLProgram program whose statement blocks are processed
     * @throws HopsException if rewriting a block fails
     */
    private void removeConstantBinaryOps(DMLProgram dMLProgram) throws HopsException {
        HashMap<String, Hop> onesByName = new HashMap<>();
        for (StatementBlock block : dMLProgram.getStatementBlocks()) {
            // A (re)assigned variable can no longer be assumed to hold constant ones;
            // HashMap.remove is a no-op for absent keys.
            for (String updatedVar : block.variablesUpdated().getVariableNames()) {
                onesByName.remove(updatedVar);
            }
            if (!onesByName.isEmpty()) {
                rRemoveConstantBinaryOp(block, onesByName);
            }
            boolean isControlFlowBlock = block instanceof IfStatementBlock
                    || block instanceof WhileStatementBlock
                    || block instanceof ForStatementBlock;
            if (!isControlFlowBlock) {
                collectMatrixOfOnes(block.get_hops(), onesByName);
            }
        }
    }

    /**
     * Records, for each root that is a TRANSIENTWRITE whose first input is a
     * {@link DataGenOp} with constant value 1.0, a mapping from the written variable
     * name to that generator hop.
     *
     * @param arrayList DAG roots of a statement block; may be {@code null}
     * @param hashMap   map of variable name to constant-ones generator, updated in place
     */
    private void collectMatrixOfOnes(ArrayList<Hop> arrayList, HashMap<String, Hop> hashMap) {
        if (arrayList == null) {
            return;
        }
        for (Hop root : arrayList) {
            boolean isTransientWrite = root instanceof DataOp
                    && ((DataOp) root).getDataOpType() == Hop.DataOpTypes.TRANSIENTWRITE;
            if (!isTransientWrite) {
                continue;
            }
            Hop source = root.getInput().get(0);
            if (source instanceof DataGenOp && ((DataGenOp) source).hasConstantValue(1.0d)) {
                hashMap.put(root.getName(), source);
            }
        }
    }

    /**
     * Recursively descends into a statement block and applies the hop-level
     * constant-ones rewrite to every last-level block's DAG.
     *
     * <p>Only the bodies of if/else, while, and for blocks are visited. Their
     * predicate / loop-bound hops are not touched here.
     *
     * @param statementBlock block to process
     * @param hashMap        variable names currently known to hold constant ones
     * @throws HopsException propagated from nested processing
     */
    private void rRemoveConstantBinaryOp(StatementBlock statementBlock, HashMap<String, Hop> hashMap) throws HopsException {
        if (statementBlock instanceof IfStatementBlock) {
            IfStatement ifStmt = (IfStatement) ((IfStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock child : ifStmt.getIfBody()) {
                rRemoveConstantBinaryOp(child, hashMap);
            }
            if (ifStmt.getElseBody() != null) {
                for (StatementBlock child : ifStmt.getElseBody()) {
                    rRemoveConstantBinaryOp(child, hashMap);
                }
            }
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatement whileStmt = (WhileStatement) ((WhileStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock child : whileStmt.getBody()) {
                rRemoveConstantBinaryOp(child, hashMap);
            }
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatement forStmt = (ForStatement) ((ForStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock child : forStmt.getBody()) {
                rRemoveConstantBinaryOp(child, hashMap);
            }
        } else if (statementBlock.get_hops() != null) {
            // Last-level block: reset visit flags, then rewrite each DAG root.
            Hop.resetVisitStatus(statementBlock.get_hops());
            for (Hop root : statementBlock.get_hops()) {
                rRemoveConstantBinaryOp(root, hashMap);
            }
        }
    }

    /**
     * Recursively walks the DAG below {@code hop}. A non-outer {@link BinaryOp} MULT
     * qualifies when its left input is a MATRIX and its right input is a
     * {@link DataOp} named in {@code hashMap}. For each such op, the right input is
     * replaced with the literal 1.
     *
     * @param hop     current hop; skipped if already marked DONE
     * @param hashMap variable names currently known to hold constant ones
     */
    private void rRemoveConstantBinaryOp(Hop hop, HashMap<String, Hop> hashMap) {
        if (hop.getVisited() == Hop.VisitStatus.DONE) {
            return;
        }
        if (hop instanceof BinaryOp
                && ((BinaryOp) hop).getOp() == Hop.OpOp2.MULT
                && !((BinaryOp) hop).isOuterVectorOperator()) {
            Hop lhs = hop.getInput().get(0);
            if (lhs.getDataType() == Expression.DataType.MATRIX) {
                Hop rhs = hop.getInput().get(1);
                if (rhs instanceof DataOp && hashMap.containsKey(rhs.getName())) {
                    // Swap the ones-matrix read for a scalar literal at input position 1.
                    HopRewriteUtils.removeChildReferenceByPos(hop, rhs, 1);
                    HopRewriteUtils.addChildReference(hop, new LiteralOp(1L), 1);
                }
            }
        }
        for (Hop input : hop.getInput()) {
            rRemoveConstantBinaryOp(input, hashMap);
        }
        hop.setVisited(Hop.VisitStatus.DONE);
    }
}
