package org.apache.sysml.hops.codegen;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.api.DMLException;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeCell;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
import org.apache.sysml.hops.codegen.cplan.CNodeRow;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.opt.PlanSelection;
import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseAll;
import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBased;
import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseCostBasedV2;
import org.apache.sysml.hops.codegen.opt.PlanSelectionFuseNoRedundancy;
import org.apache.sysml.hops.codegen.template.CPlanCSERewriter;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.CPlanOpRewriter;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.hops.recompile.RecompileStatus;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteCommonSubexpressionElimination;
import org.apache.sysml.hops.rewrite.RewriteRemoveUnnecessaryCasts;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.codegen.CodegenUtils;
import org.apache.sysml.runtime.codegen.SpoofCellwise;
import org.apache.sysml.runtime.codegen.SpoofRowwise;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.utils.Explain;
import org.apache.sysml.utils.Statistics;

/* loaded from: input_file:org/apache/sysml/hops/codegen/SpoofCompiler.class */
/**
 * Entry point for operator-fusion code generation ("codegen").
 *
 * <p>For each HOP DAG of a program, this compiler performs the following steps:
 * <ol>
 *   <li>Explores candidate fusion plans (cplans) with the registered templates.</li>
 *   <li>Prunes suboptimal plans.</li>
 *   <li>Constructs and cleans up the remaining cplans.</li>
 *   <li>Generates and compiles Java source for them.</li>
 *   <li>Replaces the covered HOP sub-DAGs with {@link SpoofFusedOp} operators.</li>
 * </ol>
 *
 * <p>Compiled classes are kept in a bounded, least-recently-used plan cache, so equal cplans are
 * not compiled twice.
 */
public class SpoofCompiler {
    public static final boolean RECOMPILE_CODEGEN = true;
    public static final boolean PRUNE_REDUNDANT_PLANS = true;

    /** Maximum number of compiled plans retained in the plan cache. */
    public static final int PLAN_CACHE_SIZE = 1024;

    // LRU cache of compiled classes keyed by cplan (assigned in the static initializer below).
    private static final PlanCache planCache;
    // HOP rewrites (CSE and removal of unnecessary casts) applied after fused operators are inserted.
    private static final ProgramRewriter rewriteCSE;

    private static final Log LOG = LogFactory.getLog(SpoofCompiler.class.getName());

    /** If true at class-initialization time, the codegen package logger is set to TRACE. */
    public static boolean LDEBUG = false;
    /** Java compiler backend for generated sources; see {@link #setExecTypeSpecificJavaCompiler()}. */
    public static CompilerType JAVA_COMPILER = CompilerType.JANINO;
    /** Plan selection policy; see {@link #createPlanSelector()}. */
    public static PlanSelector PLAN_SEL_POLICY = PlanSelector.FUSE_COST_BASED_V2;
    public static IntegrationType INTEGRATION = IntegrationType.RUNTIME;
    /** Plan cache policy; {@link PlanCachePolicy#NONE} disables caching of compiled plans. */
    public static PlanCachePolicy PLAN_CACHE_POLICY = PlanCachePolicy.CSLH;

    /** Java compiler backends; {@code AUTO} is resolved by {@link #setExecTypeSpecificJavaCompiler()}. */
    public enum CompilerType {
        AUTO,
        JAVAC,
        JANINO
    }

    public enum IntegrationType {
        HOPS,
        RUNTIME
    }

    /**
     * Bounded cache from cplan to compiled class, with least-recently-used eviction.
     *
     * <p>All accessors are {@code synchronized} on the cache instance. The map keeps insertion
     * order; a hit is re-inserted, so the first entry is always the least recently used one.
     * (Kept public for compatibility; it is only used internally.)
     */
    public static class PlanCache {
        private final LinkedHashMap<CNode, Class<?>> _plans = new LinkedHashMap<>();
        private final int _maxSize;

        /**
         * @param maxSize maximum number of cached plans (must be at least 1)
         * @throws IllegalArgumentException if {@code maxSize < 1}
         */
        public PlanCache(int maxSize) {
            if (maxSize < 1) {
                throw new IllegalArgumentException("Invalid plan cache size: " + maxSize);
            }
            _maxSize = maxSize;
        }

        /**
         * Looks up a compiled class and marks it as most recently used.
         *
         * @param key the cplan
         * @return the cached class, or {@code null} on a miss
         */
        public synchronized Class<?> getPlan(CNode key) {
            // remove + put moves the entry to the end (most recently used) in constant time
            Class<?> value = _plans.remove(key);
            if (value != null) {
                _plans.put(key, value);
            }
            return value;
        }

        /**
         * Inserts or replaces a compiled class. If a new key is inserted into a full cache, the
         * least recently used entry is evicted and its class is cleared from the class cache.
         *
         * @param key the cplan
         * @param value the compiled class
         */
        public synchronized void putPlan(CNode key, Class<?> value) {
            // Replacing an existing key (e.g., two threads compiled the same cplan) frees its
            // slot first, so no unrelated entry is evicted; it also refreshes the LRU position.
            _plans.remove(key);
            if (_plans.size() >= _maxSize) {
                Iterator<Map.Entry<CNode, Class<?>>> iter = _plans.entrySet().iterator();
                CodegenUtils.clearClassCache(iter.next().getValue());
                iter.remove();
            }
            _plans.put(key, value);
        }

        /** Removes all cached plans (does not clear the class cache). */
        public synchronized void clear() {
            _plans.clear();
        }
    }

    /** Policy controlling if and how compiled plans are cached. */
    public enum PlanCachePolicy {
        CONSTANT,
        CSLH,
        NONE;

        /**
         * Derives the policy from configuration flags.
         *
         * @param planCache whether plan caching is enabled
         * @param compileLiterals whether literals are compiled into the generated code
         * @return {@code NONE} if caching is disabled, else {@code CONSTANT} or {@code CSLH}
         */
        public static PlanCachePolicy get(boolean planCache, boolean compileLiterals) {
            if (!planCache) {
                return NONE;
            }
            return compileLiterals ? CONSTANT : CSLH;
        }
    }

    /** Plan selection policies; see {@link SpoofCompiler#createPlanSelector()}. */
    public enum PlanSelector {
        FUSE_ALL,
        FUSE_NO_REDUNDANCY,
        FUSE_COST_BASED,
        FUSE_COST_BASED_V2;

        /** @return true for the heuristic policies FUSE_ALL and FUSE_NO_REDUNDANCY */
        public boolean isHeuristic() {
            return this == FUSE_ALL || this == FUSE_NO_REDUNDANCY;
        }

        /** @return true for the cost-based policies FUSE_COST_BASED and FUSE_COST_BASED_V2 */
        public boolean isCostBased() {
            return this == FUSE_COST_BASED_V2 || this == FUSE_COST_BASED;
        }
    }

    /**
     * Applies codegen to all HOP DAGs of a parsed DML program.
     *
     * <p>Function statement blocks of every namespace are processed first, then the main
     * statement blocks.
     */
    public static void generateCode(DMLProgram dmlprog) throws LanguageException, HopsException, DMLRuntimeException {
        for (String namespaceKey : dmlprog.getNamespaces().keySet()) {
            for (String fname : dmlprog.getFunctionStatementBlocks(namespaceKey).keySet()) {
                generateCodeFromStatementBlock(dmlprog.getFunctionStatementBlock(namespaceKey, fname));
            }
        }
        for (int i = 0; i < dmlprog.getNumStatementBlocks(); i++) {
            generateCodeFromStatementBlock(dmlprog.getStatementBlock(i));
        }
    }

    /**
     * Applies codegen to a runtime program.
     *
     * <p>Instructions of all function program blocks, then of the main program blocks, are
     * regenerated from their HOP DAGs.
     */
    public static void generateCode(Program rtprog) throws LanguageException, HopsException, DMLRuntimeException, LopsException, IOException {
        for (FunctionProgramBlock fpb : rtprog.getFunctionProgramBlocks().values()) {
            generateCodeFromProgramBlock(fpb);
        }
        for (ProgramBlock pb : rtprog.getProgramBlocks()) {
            generateCodeFromProgramBlock(pb);
        }
    }

    /**
     * Recursively optimizes the HOP DAGs of a statement block.
     *
     * <p>Control-flow predicates and for-loop bounds are optimized individually. Last-level
     * blocks get their HOP DAG replaced and their recompilation flag updated.
     */
    public static void generateCodeFromStatementBlock(StatementBlock current) throws HopsException, DMLRuntimeException {
        if (current instanceof FunctionStatementBlock) {
            FunctionStatementBlock fsb = (FunctionStatementBlock) current;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            for (StatementBlock sb : fstmt.getBody()) {
                generateCodeFromStatementBlock(sb);
            }
        } else if (current instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) current;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            wsb.setPredicateHops(optimize(wsb.getPredicateHops(), false));
            for (StatementBlock sb : wstmt.getBody()) {
                generateCodeFromStatementBlock(sb);
            }
        } else if (current instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) current;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            isb.setPredicateHops(optimize(isb.getPredicateHops(), false));
            for (StatementBlock sb : istmt.getIfBody()) {
                generateCodeFromStatementBlock(sb);
            }
            for (StatementBlock sb : istmt.getElseBody()) {
                generateCodeFromStatementBlock(sb);
            }
        } else if (current instanceof ForStatementBlock) {
            // also covers subclasses of ForStatementBlock
            ForStatementBlock fsb = (ForStatementBlock) current;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            fsb.setFromHops(optimize(fsb.getFromHops(), false));
            fsb.setToHops(optimize(fsb.getToHops(), false));
            fsb.setIncrementHops(optimize(fsb.getIncrementHops(), false));
            for (StatementBlock sb : fstmt.getBody()) {
                generateCodeFromStatementBlock(sb);
            }
        } else {
            // generic last-level block
            current.setHops(generateCodeFromHopDAGs(current.getHops()));
            current.updateRecompilationFlag();
        }
    }

    /**
     * Recursively regenerates the instructions of a program block (and its children) from the
     * HOP DAGs of the associated statement blocks.
     */
    public static void generateCodeFromProgramBlock(ProgramBlock current) throws HopsException, DMLRuntimeException, LopsException, IOException {
        if (current instanceof FunctionProgramBlock) {
            for (ProgramBlock pb : ((FunctionProgramBlock) current).getChildBlocks()) {
                generateCodeFromProgramBlock(pb);
            }
        } else if (current instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) current;
            WhileStatementBlock wsb = (WhileStatementBlock) wpb.getStatementBlock();
            if (wsb != null && wsb.getPredicateHops() != null) {
                wpb.setPredicate(generateCodeFromHopDAGsToInst(wsb.getPredicateHops()));
            }
            for (ProgramBlock pb : wpb.getChildBlocks()) {
                generateCodeFromProgramBlock(pb);
            }
        } else if (current instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) current;
            IfStatementBlock isb = (IfStatementBlock) ipb.getStatementBlock();
            if (isb != null && isb.getPredicateHops() != null) {
                ipb.setPredicate(generateCodeFromHopDAGsToInst(isb.getPredicateHops()));
            }
            for (ProgramBlock pb : ipb.getChildBlocksIfBody()) {
                generateCodeFromProgramBlock(pb);
            }
            for (ProgramBlock pb : ipb.getChildBlocksElseBody()) {
                generateCodeFromProgramBlock(pb);
            }
        } else if (current instanceof ForProgramBlock) {
            // also covers subclasses of ForProgramBlock
            ForProgramBlock fpb = (ForProgramBlock) current;
            ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
            if (fsb != null && fsb.getFromHops() != null) {
                fpb.setFromInstructions(generateCodeFromHopDAGsToInst(fsb.getFromHops()));
            }
            if (fsb != null && fsb.getToHops() != null) {
                fpb.setToInstructions(generateCodeFromHopDAGsToInst(fsb.getToHops()));
            }
            if (fsb != null && fsb.getIncrementHops() != null) {
                fpb.setIncrementInstructions(generateCodeFromHopDAGsToInst(fsb.getIncrementHops()));
            }
            for (ProgramBlock pb : fpb.getChildBlocks()) {
                generateCodeFromProgramBlock(pb);
            }
        } else {
            // generic last-level block
            StatementBlock sb = current.getStatementBlock();
            current.setInstructions(generateCodeFromHopDAGsToInst(sb, sb.getHops()));
        }
    }

    /**
     * Optimizes a list of HOP DAG roots (no recompilation context) and resets the visit status
     * of both the input and the resulting DAGs.
     *
     * @param roots DAG roots; may be {@code null}
     * @return the optimized roots, or {@code null} if {@code roots} is {@code null}
     */
    public static ArrayList<Hop> generateCodeFromHopDAGs(ArrayList<Hop> roots) throws HopsException, DMLRuntimeException {
        if (roots == null) {
            return null;
        }
        ArrayList<Hop> optimized = optimize(roots, false);
        Hop.resetVisitStatus(roots);
        Hop.resetVisitStatus(optimized);
        return optimized;
    }

    /** Recompiles the HOP DAG of a statement block into instructions with an empty variable map. */
    public static ArrayList<Instruction> generateCodeFromHopDAGsToInst(StatementBlock sb, ArrayList<Hop> roots) throws DMLRuntimeException, HopsException, LopsException, IOException {
        return Recompiler.recompileHopsDag(sb, roots, new LocalVariableMap(), new RecompileStatus(true), false, false, 0L);
    }

    /** Recompiles a single-rooted HOP DAG (e.g., a predicate) into instructions with an empty variable map. */
    public static ArrayList<Instruction> generateCodeFromHopDAGsToInst(Hop root) throws DMLRuntimeException, HopsException, LopsException, IOException {
        return Recompiler.recompileHopsDag(root, new LocalVariableMap(), new RecompileStatus(true), false, false, 0L);
    }

    /**
     * Optimizes a single-rooted HOP DAG.
     *
     * @param root the DAG root; may be {@code null}
     * @param recompile whether called during dynamic recompilation
     * @return the (possibly replaced) root, or {@code null} if {@code root} is {@code null}
     */
    public static Hop optimize(Hop root, boolean recompile) throws DMLRuntimeException {
        if (root == null) {
            return null;
        }
        ArrayList<Hop> roots = new ArrayList<>(Collections.singletonList(root));
        return optimize(roots, recompile).get(0);
    }

    /**
     * Main codegen optimizer. Explores and selects fusion plans for the given DAG, compiles
     * them (or reuses cached classes), and splices the resulting fused operators into the DAG.
     *
     * @param roots DAG roots; returned unchanged if {@code null} or empty
     * @param recompile whether called during dynamic recompilation; when false (or under the
     *     CONSTANT cache policy) literals are compiled into the generated code
     * @return the roots of the modified DAG
     * @throws DMLRuntimeException wrapping any failure during optimization
     */
    public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean recompile) throws DMLRuntimeException {
        if (roots == null || roots.isEmpty()) {
            return roots;
        }
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        ArrayList<Hop> ret = roots;

        try {
            boolean compileLiterals = PLAN_CACHE_POLICY == PlanCachePolicy.CONSTANT || !recompile;

            // explore all candidate plans (bottom-up) and prune suboptimal ones
            CPlanMemoTable memo = new CPlanMemoTable();
            for (Hop hop : roots) {
                rExploreCPlans(hop, memo, compileLiterals);
            }
            memo.pruneSuboptimal(roots);

            // construct the selected cplans (insertion-ordered by discovery)
            HashMap<Long, Pair<Hop[], CNodeTpl>> cplans = new LinkedHashMap<>();
            HashSet<Long> visited = new HashSet<>();
            for (Hop hop : roots) {
                rConstructCPlans(hop, memo, cplans, compileLiterals, visited);
            }

            // simplify cplans and drop invalid or trivial ones
            HashMap<Long, Pair<Hop[], CNodeTpl>> cleaned = cleanupCPlans(memo, cplans);
            if (LOG.isTraceEnabled() && !cleaned.isEmpty()) {
                LOG.trace("Codegen EXPLAIN (before optimize): \n" + Explain.explainHops(roots));
            }

            // generate and compile source code, reusing cached classes where possible
            HashMap<Long, Pair<Hop[], Class<?>>> classes = new HashMap<>();
            for (Map.Entry<Long, Pair<Hop[], CNodeTpl>> entry : cleaned.entrySet()) {
                Pair<Hop[], CNodeTpl> cplan = entry.getValue();
                CNodeTpl tpl = cplan.getValue();
                Class<?> cla = planCache.getPlan(tpl);
                if (cla == null) {
                    String src = tpl.codegen(false);
                    if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isHopsType(recompile)) {
                        LOG.info("Codegen EXPLAIN (generated cplan for HopID: " + entry.getKey() + ", line " + tpl.getBeginLine() + ", hash=" + tpl.hashCode() + "):");
                        LOG.info(tpl.getClassname() + Explain.explainCPlan(tpl));
                    }
                    if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isRuntimeType(recompile)) {
                        LOG.info("Codegen EXPLAIN (generated code for HopID: " + entry.getKey() + ", line " + tpl.getBeginLine() + ", hash=" + tpl.hashCode() + "):");
                        LOG.info(src);
                    }
                    cla = CodegenUtils.compileClass("codegen." + tpl.getClassname(), src);
                    if (PLAN_CACHE_POLICY != PlanCachePolicy.NONE) {
                        planCache.putPlan(tpl, cla);
                    }
                } else if (DMLScript.STATISTICS) {
                    Statistics.incrementCodegenOpCacheHits();
                }
                if (cla != null) {
                    classes.put(entry.getKey(), new Pair<Hop[], Class<?>>(cplan.getKey(), cla));
                }
                if (DMLScript.STATISTICS) {
                    Statistics.incrementCodegenOpCacheTotal();
                }
            }

            // replace covered sub-DAGs with fused operators, then run CSE and related rewrites
            if (!cleaned.isEmpty()) {
                ret = constructModifiedHopDag(roots, cleaned, classes);
                ret = rewriteCSE.rewriteHopDAG(ret, new ProgramRewriteStatus());
                if (LOG.isTraceEnabled()) {
                    // explain the rewritten result rather than the original input list
                    LOG.trace("Codegen EXPLAIN (after optimize): \n" + Explain.explainHops(ret));
                }
            }

            if (DMLScript.STATISTICS) {
                Statistics.incrementCodegenDAGCompile();
                Statistics.incrementCodegenCompileTime(System.nanoTime() - t0);
            }
            Hop.resetVisitStatus(roots);
            return ret;
        } catch (Exception e) {
            LOG.error("Codegen failed to optimize the following HOP DAG: \n" + Explain.explainHops(roots));
            throw new DMLRuntimeException(e);
        }
    }

    /** Clears the class cache and the plan cache, unless plan caching is disabled. */
    public static void cleanupCodeGenerator() {
        if (PLAN_CACHE_POLICY != PlanCachePolicy.NONE) {
            CodegenUtils.clearClassCache();
            planCache.clear();
        }
    }

    /**
     * Creates a new plan selector for the current {@link #PLAN_SEL_POLICY}.
     *
     * @throws RuntimeException if the policy is not supported
     */
    public static PlanSelection createPlanSelector() {
        switch (PLAN_SEL_POLICY) {
            case FUSE_ALL:
                return new PlanSelectionFuseAll();
            case FUSE_NO_REDUNDANCY:
                return new PlanSelectionFuseNoRedundancy();
            case FUSE_COST_BASED:
                return new PlanSelectionFuseCostBased();
            case FUSE_COST_BASED_V2:
                return new PlanSelectionFuseCostBasedV2();
            default:
                throw new RuntimeException("Unsupported plan selector: " + PLAN_SEL_POLICY);
        }
    }

    /**
     * Sets {@link #PLAN_SEL_POLICY} from the DML config value {@code CODEGEN_OPTIMIZER}
     * (case-insensitive).
     */
    public static void setConfiguredPlanSelector() {
        String policy = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.CODEGEN_OPTIMIZER);
        // Locale.ROOT: avoid locale-sensitive case mapping (e.g., Turkish dotted/dotless i)
        PLAN_SEL_POLICY = PlanSelector.valueOf(policy.toUpperCase(Locale.ROOT));
    }

    /**
     * Sets {@link #JAVA_COMPILER} from the DML config value {@code CODEGEN_COMPILER}. The value
     * AUTO resolves to JANINO in Spark execution mode and to JAVAC otherwise.
     */
    public static void setExecTypeSpecificJavaCompiler() {
        String name = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.CODEGEN_COMPILER);
        // Locale.ROOT: "janino".toUpperCase() under a Turkish locale would yield "JANİNO"
        CompilerType configured = CompilerType.valueOf(name.toUpperCase(Locale.ROOT));
        if (configured != CompilerType.AUTO) {
            JAVA_COMPILER = configured;
        } else {
            JAVA_COMPILER = OptimizerUtils.isSparkExecutionMode() ? CompilerType.JANINO : CompilerType.JAVAC;
        }
    }

    /**
     * Bottom-up exploration of candidate plans for the given hop.
     *
     * <p>For this hop, the method performs these steps:
     * <ol>
     *   <li>Opens new plans for templates that can start here.</li>
     *   <li>Fuses the hop into the templates of its inputs.</li>
     *   <li>Closes plans, dropping those marked {@code CLOSED_INVALID}.</li>
     *   <li>Prunes redundant plans.</li>
     * </ol>
     */
    private static void rExploreCPlans(Hop hop, CPlanMemoTable memo, boolean compileLiterals) throws DMLException {
        // skip already explored hops
        if (memo.contains(hop.getHopID()) || memo.containsHop(hop)) {
            return;
        }

        // explore inputs first
        for (Hop input : hop.getInput()) {
            rExploreCPlans(input, memo, compileLiterals);
        }

        // open new plans at this hop
        for (TemplateBase tpl : TemplateUtils.TEMPLATES) {
            if (tpl.open(hop)) {
                memo.addAll(hop, enumPlans(hop, null, tpl, memo));
            }
        }

        // fuse this hop into plans of its inputs
        for (Hop input : hop.getInput()) {
            for (TemplateBase tpl : memo.getDistinctTemplates(input.getHopID())) {
                if (tpl.fuse(hop, input)) {
                    memo.addAll(hop, enumPlans(hop, input, tpl, memo));
                }
            }
        }

        // close plans; invalid closings are removed
        if (memo.contains(hop.getHopID())) {
            Iterator<CPlanMemoTable.MemoTableEntry> iter = memo.get(hop.getHopID()).iterator();
            while (iter.hasNext()) {
                CPlanMemoTable.MemoTableEntry me = iter.next();
                TemplateBase.CloseType ccode = TemplateUtils.createTemplate(me.type).close(hop);
                if (ccode == TemplateBase.CloseType.CLOSED_INVALID) {
                    iter.remove();
                }
                me.ctype = ccode;
            }
        }

        memo.pruneRedundant(hop.getHopID(), PLAN_SEL_POLICY.isHeuristic(), null);
        memo.addHop(hop);
    }

    /**
     * Enumerates plan variants of {@code tpl} at {@code hop}.
     *
     * <p>The method crosses in every other input ({@code != fusedInput}) that the template can
     * merge and for which the memo has a plan of the template's type or CELL.
     *
     * @param fusedInput the input this hop was fused into, or {@code null} when opening
     */
    private static CPlanMemoTable.MemoTableEntrySet enumPlans(Hop hop, Hop fusedInput, TemplateBase tpl, CPlanMemoTable memo) {
        CPlanMemoTable.MemoTableEntrySet plans = new CPlanMemoTable.MemoTableEntrySet(hop, fusedInput, tpl);
        for (int k = 0; k < hop.getInput().size(); k++) {
            Hop input = hop.getInput().get(k);
            if (input != fusedInput && tpl.merge(hop, input)
                    && memo.contains(input.getHopID(), true, tpl.getType(), TemplateBase.TemplateType.CELL)) {
                plans.crossProduct(k, -1L, Long.valueOf(input.getHopID()));
            }
        }
        return plans;
    }

    /**
     * Constructs cplans for all top-level memo entries reachable from {@code hop}.
     *
     * <p>Recursion continues into the cplan's inputs where a cplan was constructed, and into the
     * hop's inputs otherwise.
     */
    private static void rConstructCPlans(Hop hop, CPlanMemoTable memo, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, boolean compileLiterals, HashSet<Long> visited) throws DMLException {
        // cplan inputs may be null; skip them and already visited hops
        if (hop == null || visited.contains(hop.getHopID())) {
            return;
        }
        long hopId = hop.getHopID();

        if (memo.containsTopLevel(hopId)) {
            TemplateBase tpl = TemplateUtils.createTemplate(memo.getBest(hopId).type);
            cplans.put(hopId, tpl.constructCplan(hop, memo, compileLiterals));
            if (DMLScript.STATISTICS) {
                Statistics.incrementCodegenCPlanCompile(1L);
            }
        }

        if (cplans.containsKey(hopId)) {
            for (Hop input : cplans.get(hopId).getKey()) {
                rConstructCPlans(input, memo, cplans, compileLiterals, visited);
            }
        } else {
            for (Hop input : hop.getInput()) {
                rConstructCPlans(input, memo, cplans, compileLiterals, visited);
            }
        }
        visited.add(hopId);
    }

    /** Splices fused operators into the DAG (in place) and returns the same roots list. */
    private static ArrayList<Hop> constructModifiedHopDag(ArrayList<Hop> roots, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, HashMap<Long, Pair<Hop[], Class<?>>> classes) {
        HashSet<Long> memo = new HashSet<>();
        for (int i = 0; i < roots.size(); i++) {
            rConstructModifiedHopDag(roots.get(i), cplans, classes, memo);
        }
        return roots;
    }

    /**
     * Replaces a hop with a compiled class by a {@link SpoofFusedOp}, adding the output
     * adjustments required by the template type. It then recurses into the (possibly new)
     * node's inputs.
     */
    private static void rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, HashMap<Long, Pair<Hop[], Class<?>>> classes, HashSet<Long> memo) {
        if (memo.contains(hop.getHopID())) {
            return; // already processed
        }

        Hop hnew = hop;
        if (classes.containsKey(hop.getHopID())) {
            Pair<Hop[], Class<?>> cla = classes.get(hop.getHopID());
            CNodeTpl tpl = cplans.get(hop.getHopID()).getValue();
            Hop[] inHops = cla.getKey();
            SpoofFusedOp fused = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(), cla.getValue(), false, tpl.getOutputDimType());

            // Wire the inputs. For outer-product templates, the hop matching the cplan's third
            // input is transposed, unless hasTransposeParentUnderOuterProduct reports one.
            for (Hop in : inHops) {
                if (tpl instanceof CNodeOuterProduct
                        && in.getHopID() == ((CNodeData) tpl.getInput().get(2)).getHopID()
                        && !TemplateUtils.hasTransposeParentUnderOuterProduct(in)) {
                    fused.addInput(HopRewriteUtils.createTranspose(in));
                } else {
                    fused.addInput(in);
                }
            }
            HopRewriteUtils.setOutputParameters(fused, hop.getDim1(), hop.getDim2(), hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());
            hnew = fused;

            // template-specific output handling
            if (tpl instanceof CNodeOuterProduct && ((CNodeOuterProduct) tpl).isTransposeOutput()) {
                hnew = HopRewriteUtils.createTranspose(fused);
            } else if (tpl instanceof CNodeMultiAgg) {
                // one 1 x k matrix output; each original aggregate root is rewired to an indexed cell
                ArrayList<Hop> rootNodes = ((CNodeMultiAgg) tpl).getRootNodes();
                fused.setDataType(Expression.DataType.MATRIX);
                HopRewriteUtils.setOutputParameters(fused, 1L, rootNodes.size(), inHops[0].getRowsInBlock(), inHops[0].getColsInBlock(), -1L);
                for (int i = 0; i < rootNodes.size(); i++) {
                    Hop root = rootNodes.get(i);
                    Hop indexed;
                    if (root instanceof AggUnaryOp) {
                        indexed = HopRewriteUtils.createScalarIndexing(fused, 1L, i + 1);
                    } else {
                        indexed = HopRewriteUtils.createIndexingOp(fused, 1L, i + 1);
                    }
                    HopRewriteUtils.rewireAllParentChildReferences(root, indexed);
                }
            } else if (tpl instanceof CNodeCell && ((CNodeCell) tpl).requiredCastDtm()) {
                HopRewriteUtils.setOutputParametersForScalar(fused);
                hnew = HopRewriteUtils.createUnary(fused, Hop.OpOp1.CAST_AS_MATRIX);
            } else if (tpl instanceof CNodeRow
                    && (((CNodeRow) tpl).getRowType() == SpoofRowwise.RowType.NO_AGG_CONST
                        || ((CNodeRow) tpl).getRowType() == SpoofRowwise.RowType.COL_AGG_CONST)) {
                fused.setConstDim2(((CNodeRow) tpl).getConstDim2());
            }

            // multi-agg roots were rewired individually above
            if (!(tpl instanceof CNodeMultiAgg)) {
                HopRewriteUtils.rewireAllParentChildReferences(hop, hnew);
            }
            memo.add(hnew.getHopID());
        }

        // recurse with an index loop: rewiring may modify input lists
        for (int i = 0; i < hnew.getInput().size(); i++) {
            rConstructModifiedHopDag(hnew.getInput().get(i), cplans, classes, memo);
        }
        memo.add(hnew.getHopID());
    }

    /**
     * Simplifies constructed cplans (operator rewrites + CSE), trims unused input hops, and
     * removes plans that are invalid or not worth a fused operator.
     *
     * <p>Plans are removed for any of these reasons:
     * <ul>
     *   <li>null inputs</li>
     *   <li>unsupported rc1 lookups on the main input</li>
     *   <li>invalid row templates</li>
     *   <li>a single operation or no operation</li>
     *   <li>an empty output</li>
     * </ul>
     *
     * @return a new map with the surviving plans
     */
    private static HashMap<Long, Pair<Hop[], CNodeTpl>> cleanupCPlans(CPlanMemoTable memo, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans) {
        HashMap<Long, Pair<Hop[], CNodeTpl>> cplans2 = new HashMap<>();
        CPlanOpRewriter rewriter = new CPlanOpRewriter();
        CPlanCSERewriter cse = new CPlanCSERewriter();

        for (Map.Entry<Long, Pair<Hop[], CNodeTpl>> e : cplans.entrySet()) {
            Hop[] inHops = e.getValue().getKey();

            // skip invalid cplans with null inputs
            if (Arrays.stream(inHops).anyMatch(h -> h == null)) {
                continue;
            }

            // simplify and eliminate common subexpressions
            CNodeTpl tpl = cse.eliminateCommonSubexpressions(rewriter.simplifyCPlan(e.getValue().getValue()));

            // keep only the input hops still referenced by the simplified plan (order-preserving)
            HashSet<Long> inputHopIDs = tpl.getInputHopIDs(false);
            Hop[] inHops2 = Arrays.stream(inHops)
                .filter(h -> h != null && inputHopIDs.contains(h.getHopID()))
                .toArray(Hop[]::new);
            cplans2.put(e.getKey(), new Pair<>(inHops2, tpl));

            CNodeData in1 = (CNodeData) tpl.getInput().get(0);

            // remove cplans with unsupported rc1 lookups on the main input
            if (tpl instanceof CNodeCell || tpl instanceof CNodeRow) {
                boolean inclRC1 = !(tpl instanceof CNodeRow);
                if (rHasLookupRC1(tpl.getOutput(), in1, inclRC1) || isLookupRC1(tpl.getOutput(), in1, inclRC1)) {
                    removeAndTrace(cplans2, e.getKey(), "Removed cplan due to invalid rc1 indexing on main input.");
                }
            } else if (tpl instanceof CNodeMultiAgg) {
                for (CNode output : ((CNodeMultiAgg) tpl).getOutputs()) {
                    if (rHasLookupRC1(output, in1, true) || isLookupRC1(output, in1, true)) {
                        removeAndTrace(cplans2, e.getKey(), "Removed cplan due to invalid rc1 indexing on main input.");
                        break; // one invalid output suffices
                    }
                }
            }

            // remove lookups on the main input
            if (tpl instanceof CNodeMultiAgg) {
                rFindAndRemoveLookupMultiAgg((CNodeMultiAgg) tpl, in1);
            } else {
                rFindAndRemoveLookup(tpl.getOutput(), in1, !(tpl instanceof CNodeRow));
            }

            // remove invalid row templates
            if (tpl instanceof CNodeRow) {
                CNodeRow rowTpl = (CNodeRow) tpl;
                if (rowTpl.getRowType() == SpoofRowwise.RowType.NO_AGG && tpl.getOutput().getDataType().isScalar()) {
                    removeAndTrace(cplans2, e.getKey(), "Removed invalid row cplan w/o agg on column vector.");
                } else if (OptimizerUtils.isSparkExecutionMode()) {
                    Hop hop = memo.getHopRefs().get(e.getKey());
                    boolean isSpark = DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.SPARK
                        || OptimizerUtils.getTotalMemEstimate(inHops2, hop, true) > OptimizerUtils.getLocalMemBudget();
                    // Invalid if any matrix has more columns than fit in one block (ncol > ncolpb);
                    // for a transposed output, its rows become columns. (Previously this check
                    // was inverted, using <= and thereby removing valid plans.)
                    boolean invalidNcol = hop.getDataType().isMatrix()
                        && (HopRewriteUtils.isTransposeOperation(hop)
                            ? hop.getDim1() > hop.getRowsInBlock()
                            : hop.getDim2() > hop.getColsInBlock());
                    for (Hop in : inHops2) {
                        invalidNcol |= in.getDataType().isMatrix() && in.getDim2() > in.getColsInBlock();
                    }
                    if (isSpark && invalidNcol) {
                        removeAndTrace(cplans2, e.getKey(), "Removed invalid row cplan w/ ncol>ncolpb.");
                    }
                }
            }

            // remove cplans that would only wrap a single operation (or none)
            boolean singleOpCell = tpl instanceof CNodeCell
                && ((CNodeCell) tpl).getCellType() == SpoofCellwise.CellType.NO_AGG
                && TemplateUtils.hasSingleOperation(tpl);
            boolean singleOpRow = tpl instanceof CNodeRow
                && (((CNodeRow) tpl).getRowType() == SpoofRowwise.RowType.NO_AGG
                    || ((CNodeRow) tpl).getRowType() == SpoofRowwise.RowType.NO_AGG_B1
                    || ((CNodeRow) tpl).getRowType() == SpoofRowwise.RowType.ROW_AGG)
                && TemplateUtils.hasSingleOperation(tpl);
            if (singleOpCell || singleOpRow || TemplateUtils.hasNoOperation(tpl)) {
                removeAndTrace(cplans2, e.getKey(), "Removed cplan with single operation.");
            }

            // remove empty cplans whose output is just an input
            if (tpl.getOutput() instanceof CNodeData) {
                removeAndTrace(cplans2, e.getKey(), "Removed empty cplan.");
            }

            tpl.renameInputs();
        }
        return cplans2;
    }

    /** Removes {@code key} from {@code cplans} and emits {@code msg} at trace level. */
    private static void removeAndTrace(HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, Long key, String msg) {
        cplans.remove(key);
        if (LOG.isTraceEnabled()) {
            LOG.trace(msg);
        }
    }

    /**
     * Removes lookups on the main input inside all outputs of a multi-aggregate template.
     * Outputs that are themselves such lookups are replaced by the main input node.
     */
    private static void rFindAndRemoveLookupMultiAgg(CNodeMultiAgg node, CNodeData mainInput) {
        for (CNode output : node.getOutputs()) {
            rFindAndRemoveLookup(output, mainInput, true);
        }
        for (int i = 0; i < node.getOutputs().size(); i++) {
            CNode output = node.getOutputs().get(i);
            if (TemplateUtils.isLookup(output, true)
                    && output.getInput().get(0) instanceof CNodeData
                    && ((CNodeData) output.getInput().get(0)).getHopID() == mainInput.getHopID()) {
                node.getOutputs().set(i, output.getInput().get(0));
            }
        }
    }

    /**
     * Recursively replaces lookup nodes on {@code mainInput} with the main input node itself.
     * Recursion does not descend into a replaced lookup.
     */
    private static void rFindAndRemoveLookup(CNode node, CNodeData mainInput, boolean includeRC1) {
        for (int i = 0; i < node.getInput().size(); i++) {
            CNode child = node.getInput().get(i);
            if (TemplateUtils.isLookup(child, includeRC1)
                    && child.getInput().get(0) instanceof CNodeData
                    && ((CNodeData) child.getInput().get(0)).getHopID() == mainInput.getHopID()) {
                node.getInput().set(i, child.getInput().get(0));
            } else {
                rFindAndRemoveLookup(child, mainInput, includeRC1);
            }
        }
    }

    /** Returns true if any descendant of {@code node} is an rc1 lookup on {@code mainInput}. */
    private static boolean rHasLookupRC1(CNode node, CNodeData mainInput, boolean includeRC1) {
        for (int i = 0; i < node.getInput().size(); i++) {
            CNode child = node.getInput().get(i);
            if (isLookupRC1(child, mainInput, includeRC1) || rHasLookupRC1(child, mainInput, includeRC1)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if {@code node} is a ternary LOOKUP_RVECT1 lookup on {@code mainInput}. A
     * LOOKUP_RC1 lookup also counts when {@code includeRC1} is set.
     */
    private static boolean isLookupRC1(CNode node, CNodeData mainInput, boolean includeRC1) {
        if (!(node instanceof CNodeTernary)) {
            return false;
        }
        CNodeTernary ternary = (CNodeTernary) node;
        boolean isLookupType = (ternary.getType() == CNodeTernary.TernaryType.LOOKUP_RC1 && includeRC1)
            || ternary.getType() == CNodeTernary.TernaryType.LOOKUP_RVECT1;
        return isLookupType
            && node.getInput().get(0) instanceof CNodeData
            && ((CNodeData) node.getInput().get(0)).getHopID() == mainInput.getHopID();
    }

    static {
        if (LDEBUG) {
            Logger.getLogger("org.apache.sysml.hops.codegen").setLevel(Level.TRACE);
        }
        planCache = new PlanCache(PLAN_CACHE_SIZE);
        rewriteCSE = new ProgramRewriter(new RewriteCommonSubexpressionElimination(true), new RewriteRemoveUnnecessaryCasts());
    }
}
