package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreePlanChecker.class */
public class OptTreePlanChecker {
    public static void checkProgramCorrectness(ProgramBlock programBlock, StatementBlock statementBlock, Set<String> set) throws HopsException, DMLRuntimeException {
        Program program = programBlock.getProgram();
        DMLProgram dMLProg = statementBlock.getDMLProg();
        if ((programBlock instanceof FunctionProgramBlock) && (statementBlock instanceof FunctionStatementBlock)) {
            FunctionProgramBlock functionProgramBlock = (FunctionProgramBlock) programBlock;
            FunctionStatement functionStatement = (FunctionStatement) ((FunctionStatementBlock) statementBlock).getStatement(0);
            for (int i = 0; i < functionProgramBlock.getChildBlocks().size(); i++) {
                checkProgramCorrectness(functionProgramBlock.getChildBlocks().get(i), functionStatement.getBody().get(i), set);
            }
            return;
        }
        if ((programBlock instanceof WhileProgramBlock) && (statementBlock instanceof WhileStatementBlock)) {
            WhileProgramBlock whileProgramBlock = (WhileProgramBlock) programBlock;
            WhileStatementBlock whileStatementBlock = (WhileStatementBlock) statementBlock;
            WhileStatement whileStatement = (WhileStatement) whileStatementBlock.getStatement(0);
            checkHopDagCorrectness(program, dMLProg, whileStatementBlock.getPredicateHops(), whileProgramBlock.getPredicate(), set);
            for (int i2 = 0; i2 < whileProgramBlock.getChildBlocks().size(); i2++) {
                checkProgramCorrectness(whileProgramBlock.getChildBlocks().get(i2), whileStatement.getBody().get(i2), set);
            }
            checkLinksProgramStatementBlock(whileProgramBlock, whileStatementBlock);
            return;
        }
        if ((programBlock instanceof IfProgramBlock) && (statementBlock instanceof IfStatementBlock)) {
            IfProgramBlock ifProgramBlock = (IfProgramBlock) programBlock;
            IfStatementBlock ifStatementBlock = (IfStatementBlock) statementBlock;
            IfStatement ifStatement = (IfStatement) ifStatementBlock.getStatement(0);
            checkHopDagCorrectness(program, dMLProg, ifStatementBlock.getPredicateHops(), ifProgramBlock.getPredicate(), set);
            for (int i3 = 0; i3 < ifProgramBlock.getChildBlocksIfBody().size(); i3++) {
                checkProgramCorrectness(ifProgramBlock.getChildBlocksIfBody().get(i3), ifStatement.getIfBody().get(i3), set);
            }
            for (int i4 = 0; i4 < ifProgramBlock.getChildBlocksElseBody().size(); i4++) {
                checkProgramCorrectness(ifProgramBlock.getChildBlocksElseBody().get(i4), ifStatement.getElseBody().get(i4), set);
            }
            checkLinksProgramStatementBlock(ifProgramBlock, ifStatementBlock);
            return;
        }
        if (!(programBlock instanceof ForProgramBlock) || !(statementBlock instanceof ForStatementBlock)) {
            checkHopDagCorrectness(program, dMLProg, statementBlock.getHops(), programBlock.getInstructions(), set);
            return;
        }
        ForProgramBlock forProgramBlock = (ForProgramBlock) programBlock;
        ForStatementBlock forStatementBlock = (ForStatementBlock) statementBlock;
        ForStatement forStatement = (ForStatement) statementBlock.getStatement(0);
        checkHopDagCorrectness(program, dMLProg, forStatementBlock.getFromHops(), forProgramBlock.getFromInstructions(), set);
        checkHopDagCorrectness(program, dMLProg, forStatementBlock.getToHops(), forProgramBlock.getToInstructions(), set);
        checkHopDagCorrectness(program, dMLProg, forStatementBlock.getIncrementHops(), forProgramBlock.getIncrementInstructions(), set);
        for (int i5 = 0; i5 < forProgramBlock.getChildBlocks().size(); i5++) {
            checkProgramCorrectness(forProgramBlock.getChildBlocks().get(i5), forStatement.getBody().get(i5), set);
        }
        checkLinksProgramStatementBlock(forProgramBlock, forStatementBlock);
    }

    private static void checkHopDagCorrectness(Program program, DMLProgram dMLProgram, ArrayList<Hop> arrayList, ArrayList<Instruction> arrayList2, Set<String> set) throws DMLRuntimeException, HopsException {
        if (arrayList != null) {
            Iterator<Hop> it = arrayList.iterator();
            while (it.hasNext()) {
                checkHopDagCorrectness(program, dMLProgram, it.next(), arrayList2, set);
            }
        }
    }

    private static void checkHopDagCorrectness(Program program, DMLProgram dMLProgram, Hop hop, ArrayList<Instruction> arrayList, Set<String> set) throws DMLRuntimeException, HopsException {
        checkFunctionNames(program, dMLProgram, hop, arrayList, set);
    }

    private static void checkLinksProgramStatementBlock(ProgramBlock programBlock, StatementBlock statementBlock) throws DMLRuntimeException {
        if (programBlock.getStatementBlock() != statementBlock) {
            throw new DMLRuntimeException("Links between programblocks and statementblocks are incorrect (" + programBlock + ").");
        }
    }

    private static void checkFunctionNames(Program program, DMLProgram dMLProgram, Hop hop, ArrayList<Instruction> arrayList, Set<String> set) throws DMLRuntimeException, HopsException {
        hop.resetVisitStatus();
        HashMap hashMap = new HashMap();
        getAllFunctionOps(hop, hashMap);
        Iterator<Instruction> it = arrayList.iterator();
        while (it.hasNext()) {
            Instruction next = it.next();
            if (next instanceof FunctionCallCPInstruction) {
                FunctionCallCPInstruction functionCallCPInstruction = (FunctionCallCPInstruction) next;
                String namespace = functionCallCPInstruction.getNamespace();
                String functionName = functionCallCPInstruction.getFunctionName();
                String constructFunctionKey = DMLProgram.constructFunctionKey(namespace, functionName);
                if (!hashMap.containsKey(constructFunctionKey)) {
                    throw new DMLRuntimeException("Function Check: instruction and hop names differ (" + constructFunctionKey + ", " + hashMap.keySet() + ")");
                }
                if (!program.getFunctionProgramBlocks().containsKey(constructFunctionKey)) {
                    throw new DMLRuntimeException("Function Check: function does not exits (" + constructFunctionKey + ")");
                }
                FunctionProgramBlock functionProgramBlock = program.getFunctionProgramBlock(namespace, functionName);
                FunctionStatementBlock functionStatementBlock = dMLProgram.getFunctionStatementBlock(namespace, functionName);
                if (!set.contains(constructFunctionKey)) {
                    set.add(constructFunctionKey);
                    checkProgramCorrectness(functionProgramBlock, functionStatementBlock, set);
                    set.remove(constructFunctionKey);
                }
            }
        }
    }

    private static void getAllFunctionOps(Hop hop, HashMap<String, FunctionOp> hashMap) {
        if (hop.isVisited()) {
            return;
        }
        if (hop instanceof FunctionOp) {
            FunctionOp functionOp = (FunctionOp) hop;
            hashMap.put(functionOp.getFunctionKey(), functionOp);
        }
        Iterator<Hop> it = hop.getInput().iterator();
        while (it.hasNext()) {
            getAllFunctionOps(it.next(), hashMap);
        }
        hop.setVisited();
    }
}
