package org.apache.sysml.hops.globalopt.gdfgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.globalopt.Summary;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.utils.Explain;

/* loaded from: input_file:org/apache/sysml/hops/globalopt/gdfgraph/GraphBuilder.class */
/**
 * Static helpers that walk the program blocks of a runtime {@link Program} and
 * assemble a {@link GDFGraph} from the hop DAGs attached to them.
 *
 * <p>While walking, a map from variable name to the {@link GDFNode} most recently
 * registered under that name (the "roots" map) is threaded through the blocks, so
 * that transient reads in later blocks can be linked to the nodes registered by
 * earlier blocks.
 */
public class GraphBuilder {
    /**
     * When {@code true}, a variable that a loop reports as updated but that has no
     * node in the loop body's roots map is skipped rather than reported as an error.
     */
    private static final boolean IGNORE_UNBOUND_UPDATED_VARS = true;

    /**
     * Builds the graph for all top-level program blocks of {@code program} and
     * records the elapsed construction time in {@code summary}.
     *
     * @param program the program whose blocks are traversed in order
     * @param summary receives the elapsed time via {@code setTimeGDFGraph}
     * @return a graph whose root list holds every registered node that is not a
     *         {@link GDFCrossBlockNode}
     * @throws DMLRuntimeException if a function program block is encountered or a
     *         loop input variable has no registered node
     * @throws HopsException if a hop DAG root could not be constructed
     */
    public static GDFGraph constructGlobalDataFlowGraph(Program program, Summary summary) throws DMLRuntimeException, HopsException {
        Timing timing = new Timing(true);

        HashMap<String, GDFNode> roots = new HashMap<String, GDFNode>();
        for (ProgramBlock block : program.getProgramBlocks()) {
            constructGDFGraph(block, roots);
        }

        // Cross-block nodes are filtered out of the final root list.
        ArrayList<GDFNode> graphRoots = new ArrayList<GDFNode>();
        for (GDFNode node : roots.values()) {
            if (!(node instanceof GDFCrossBlockNode)) {
                graphRoots.add(node);
            }
        }

        GDFGraph graph = new GDFGraph(program, graphRoots);
        summary.setTimeGDFGraph(timing.stop());
        return graph;
    }

    /**
     * Adds the nodes of one program block (recursively for control-flow blocks) and
     * updates {@code hashMap} with the nodes the block registers.
     *
     * @param programBlock the block to process
     * @param hashMap roots map (variable name to node); read and updated in place
     * @throws DMLRuntimeException for function program blocks, or when a loop input
     *         variable has no registered node
     * @throws HopsException if a hop DAG root could not be constructed
     */
    private static void constructGDFGraph(ProgramBlock programBlock, HashMap<String, GDFNode> hashMap) throws DMLRuntimeException, HopsException {
        if (programBlock instanceof FunctionProgramBlock) {
            throw new DMLRuntimeException("FunctionProgramBlocks not implemented yet.");
        }

        if (programBlock instanceof WhileProgramBlock) {
            WhileProgramBlock whileBlock = (WhileProgramBlock) programBlock;
            WhileStatementBlock whileStmt = (WhileStatementBlock) programBlock.getStatementBlock();
            // The predicate is built first, against the roots visible before the loop.
            GDFNode predicate = constructGDFGraph(whileStmt.getPredicateHops(), whileBlock, new HashMap<Long, GDFNode>(), hashMap);
            constructLoopGraph(whileBlock, whileStmt, predicate, whileBlock.getChildBlocks(), hashMap);
            return;
        }

        if (programBlock instanceof IfProgramBlock) {
            IfProgramBlock ifBlock = (IfProgramBlock) programBlock;
            IfStatementBlock ifStmt = (IfStatementBlock) programBlock.getStatementBlock();
            if (ifStmt.getPredicateHops() != null) {
                Hop predicateHop = ifStmt.getPredicateHops();
                hashMap.put(predicateHop.getName(), constructGDFGraph(predicateHop, ifBlock, new HashMap<Long, GDFNode>(), hashMap));
            }

            // Each branch works on its own shallow copy so that the two branches
            // do not see each other's updates; they are merged afterwards.
            HashMap<String, GDFNode> ifRoots = new HashMap<String, GDFNode>(hashMap);
            HashMap<String, GDFNode> elseRoots = new HashMap<String, GDFNode>(hashMap);
            for (ProgramBlock child : ifBlock.getChildBlocksIfBody()) {
                constructGDFGraph(child, ifRoots);
            }
            if (ifBlock.getChildBlocksElseBody() != null) {
                for (ProgramBlock child : ifBlock.getChildBlocksElseBody()) {
                    constructGDFGraph(child, elseRoots);
                }
            }
            reconcileMergeIfProgramBlockOutputs(ifRoots, elseRoots, hashMap, ifBlock);
            return;
        }

        if (programBlock instanceof ForProgramBlock) {
            ForProgramBlock forBlock = (ForProgramBlock) programBlock;
            ForStatementBlock forStmt = (ForStatementBlock) programBlock.getStatementBlock();
            GDFNode predicate = constructForPredicateNode(forBlock, forStmt, hashMap);
            constructLoopGraph(forBlock, forStmt, predicate, forBlock.getChildBlocks(), hashMap);
            return;
        }

        // Any other block: build one DAG per hop root of its statement block.
        ArrayList<Hop> hops = programBlock.getStatementBlock().getHops();
        if (hops != null) {
            // Memo shared by all roots of this block, so hops reachable from several
            // roots map to a single node.
            HashMap<Long, GDFNode> memo = new HashMap<Long, GDFNode>();
            for (Hop root : hops) {
                GDFNode node = constructGDFGraph(root, programBlock, memo, hashMap);
                if (node == null) {
                    throw new HopsException("GDFGraphBuilder: failed to constuct dag root for: " + Explain.explain(root));
                }
                // Transient writes are wrapped in a cross-block node before being registered.
                if ((root instanceof DataOp) && ((DataOp) root).getDataOpType() == Hop.DataOpTypes.TRANSIENTWRITE) {
                    node = new GDFCrossBlockNode(root, programBlock, node, root.getName());
                }
                hashMap.put(root.getName(), node);
            }
        }
    }

    /**
     * Shared tail of while/for handling: collects the loop inputs, processes the
     * body blocks on a copy of those inputs, collects the loop outputs, creates the
     * loop node and registers cross-block nodes for live-out outputs in {@code roots}.
     */
    private static void constructLoopGraph(ProgramBlock loopBlock, StatementBlock loopStmt, GDFNode predicate, Iterable<ProgramBlock> children, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
        HashMap<String, GDFNode> loopInputs = constructLoopInputNodes(loopBlock, loopStmt, roots);

        // The body starts from a shallow copy so the input map passed to the loop
        // node is not modified by the body's registrations.
        HashMap<String, GDFNode> bodyRoots = new HashMap<String, GDFNode>(loopInputs);
        for (ProgramBlock child : children) {
            constructGDFGraph(child, bodyRoots);
        }

        HashMap<String, GDFNode> loopOutputs = constructLoopOutputNodes(loopStmt, bodyRoots);
        GDFLoopNode loopNode = new GDFLoopNode(loopBlock, predicate, loopInputs, loopOutputs);
        constructLoopOutputCrossBlockNodes(loopStmt, loopNode, loopOutputs, roots, loopBlock);
    }

    /**
     * Recursively builds the node for {@code hop} and its inputs.
     *
     * @param hop the hop to convert
     * @param programBlock the block passed to every created node
     * @param hashMap memo from hop ID to already-created node; updated in place
     * @param hashMap2 roots map used to resolve transient reads by variable name
     * @return the existing node for the hop's ID if present, otherwise a new node
     */
    private static GDFNode constructGDFGraph(Hop hop, ProgramBlock programBlock, HashMap<Long, GDFNode> hashMap, HashMap<String, GDFNode> hashMap2) {
        Long hopId = Long.valueOf(hop.getHopID());
        if (hashMap.containsKey(hopId)) {
            return hashMap.get(hopId);
        }

        ArrayList<GDFNode> inputs = new ArrayList<GDFNode>();
        for (Hop input : hop.getInput()) {
            inputs.add(constructGDFGraph(input, programBlock, hashMap, hashMap2));
        }

        // A transient read additionally gets whatever is registered under its name
        // in the roots map as an input (this is null when nothing is registered).
        if ((hop instanceof DataOp) && ((DataOp) hop).getDataOpType() == Hop.DataOpTypes.TRANSIENTREAD) {
            inputs.add(hashMap2.get(hop.getName()));
        }

        GDFNode node = new GDFNode(hop, programBlock, inputs);
        hashMap.put(hopId, node);
        return node;
    }

    /**
     * Builds a node (created with a {@code null} hop) whose three inputs are the
     * nodes for the from, to and increment hops of a for loop, in that order. An
     * input is {@code null} when the corresponding hop is {@code null}.
     */
    private static GDFNode constructForPredicateNode(ForProgramBlock forProgramBlock, ForStatementBlock forStatementBlock, HashMap<String, GDFNode> hashMap) {
        // One memo is shared by the three expressions.
        HashMap<Long, GDFNode> memo = new HashMap<Long, GDFNode>();

        GDFNode from = null;
        if (forStatementBlock.getFromHops() != null) {
            from = constructGDFGraph(forStatementBlock.getFromHops(), forProgramBlock, memo, hashMap);
        }
        GDFNode to = null;
        if (forStatementBlock.getToHops() != null) {
            to = constructGDFGraph(forStatementBlock.getToHops(), forProgramBlock, memo, hashMap);
        }
        GDFNode increment = null;
        if (forStatementBlock.getIncrementHops() != null) {
            increment = constructGDFGraph(forStatementBlock.getIncrementHops(), forProgramBlock, memo, hashMap);
        }

        ArrayList<GDFNode> inputs = new ArrayList<GDFNode>();
        inputs.add(from);
        inputs.add(to);
        inputs.add(increment);
        return new GDFNode(null, forProgramBlock, inputs);
    }

    /**
     * Collects the nodes for every variable that the statement block reads and that
     * is also in its live-in set.
     *
     * @param programBlock not used by this method
     * @param statementBlock supplies the read and live-in variable sets
     * @param hashMap roots map to look the variables up in
     * @return a new map from variable name to its node
     * @throws DMLRuntimeException if such a variable has no node in {@code hashMap}
     */
    private static HashMap<String, GDFNode> constructLoopInputNodes(ProgramBlock programBlock, StatementBlock statementBlock, HashMap<String, GDFNode> hashMap) throws DMLRuntimeException {
        HashMap<String, GDFNode> inputs = new HashMap<String, GDFNode>();
        for (String name : statementBlock.variablesRead().getVariableNames()) {
            if (!statementBlock.liveIn().containsVariable(name)) {
                continue;
            }
            GDFNode node = hashMap.get(name);
            if (node == null) {
                throw new DMLRuntimeException("GDFGraphBuilder: Non-existing input node for variable: " + name);
            }
            inputs.put(name, node);
        }
        return inputs;
    }

    /**
     * Collects the nodes for every variable the statement block reports as updated.
     *
     * @param statementBlock supplies the updated variable set
     * @param hashMap roots map produced by processing the loop body
     * @return a new map from variable name to its node; variables without a node are
     *         omitted while {@link #IGNORE_UNBOUND_UPDATED_VARS} is {@code true}
     * @throws HopsException if a variable has no node and
     *         {@link #IGNORE_UNBOUND_UPDATED_VARS} is {@code false}
     */
    private static HashMap<String, GDFNode> constructLoopOutputNodes(StatementBlock statementBlock, HashMap<String, GDFNode> hashMap) throws HopsException {
        HashMap<String, GDFNode> outputs = new HashMap<String, GDFNode>();
        for (String name : statementBlock.variablesUpdated().getVariableNames()) {
            GDFNode node = hashMap.get(name);
            if (node == null) {
                if (!IGNORE_UNBOUND_UPDATED_VARS) {
                    throw new HopsException("GDFGraphBuilder: Non-existing output node for variable: " + name);
                }
                continue;
            }
            outputs.put(name, node);
        }
        return outputs;
    }

    /**
     * Merges the roots maps of the two branches of an if block into {@code hashMap3}.
     *
     * <p>For every entry of the if-branch map: when the else-branch map holds the
     * very same node instance, that node is registered; otherwise a
     * {@link GDFCrossBlockNode} over both nodes is registered (the else-side node is
     * {@code null} when the name is absent there). Entries present only in the
     * else-branch map are registered unchanged.
     *
     * @param hashMap roots after the if branch
     * @param hashMap2 roots after the else branch
     * @param hashMap3 roots map to update
     * @param ifProgramBlock block passed to created cross-block nodes
     */
    private static void reconcileMergeIfProgramBlockOutputs(HashMap<String, GDFNode> hashMap, HashMap<String, GDFNode> hashMap2, HashMap<String, GDFNode> hashMap3, IfProgramBlock ifProgramBlock) {
        for (Map.Entry<String, GDFNode> ifEntry : hashMap.entrySet()) {
            String name = ifEntry.getKey();
            GDFNode ifNode = ifEntry.getValue();
            GDFNode elseNode = hashMap2.get(name);
            // Identity comparison is intentional: only the same instance counts as unchanged.
            GDFNode merged = ifNode;
            if (ifNode != elseNode) {
                merged = new GDFCrossBlockNode(null, ifProgramBlock, ifNode, elseNode, name);
            }
            hashMap3.put(name, merged);
        }

        for (Map.Entry<String, GDFNode> elseEntry : hashMap2.entrySet()) {
            if (!hashMap.containsKey(elseEntry.getKey())) {
                hashMap3.put(elseEntry.getKey(), elseEntry.getValue());
            }
        }
    }

    /**
     * Registers a {@link GDFCrossBlockNode} in {@code hashMap2} for every loop output
     * whose variable is in the statement block's live-out set.
     *
     * <p>If a node is already registered under the name, the cross-block node is
     * created over that node and the loop node; otherwise over the loop node alone.
     *
     * @param statementBlock supplies the live-out set
     * @param gDFLoopNode the loop node
     * @param hashMap loop outputs (only the keys are used)
     * @param hashMap2 roots map to update
     * @param programBlock block passed to created cross-block nodes
     */
    private static void constructLoopOutputCrossBlockNodes(StatementBlock statementBlock, GDFLoopNode gDFLoopNode, HashMap<String, GDFNode> hashMap, HashMap<String, GDFNode> hashMap2, ProgramBlock programBlock) {
        for (Map.Entry<String, GDFNode> output : hashMap.entrySet()) {
            String name = output.getKey();
            if (!statementBlock.liveOut().containsVariable(name)) {
                continue;
            }
            GDFNode crossBlock;
            if (hashMap2.containsKey(name)) {
                crossBlock = new GDFCrossBlockNode(null, programBlock, hashMap2.get(name), gDFLoopNode, name);
            } else {
                crossBlock = new GDFCrossBlockNode(null, programBlock, gDFLoopNode, name);
            }
            hashMap2.put(name, crossBlock);
        }
    }
}
