package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.commons.cli.HelpFormatter;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.matrix.data.Pair;

/* loaded from: input_file:org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.class */
/**
 * Statement-block rewrite that splits a hop DAG at data-dependent operators. It handles:
 * <ul>
 * <li>removeEmpty and ctable (fewer than four inputs) operators whose dimensions are unknown,
 *     unless they are only consumed by writes;</li>
 * <li>non-literal, non-DataOp by/decreasing inputs of sort.</li>
 * </ul>
 *
 * <p>These operators, together with the sub-DAGs computing them, are placed into a new statement
 * block that writes each result to a transient variable. The original block reads those variables
 * back through transient reads. The returned list contains the new block(s) first, followed by the
 * original block. NOTE(review): the purpose is presumably to allow recompilation of the remaining
 * DAG once actual sizes are known; confirm against the recompiler.
 */
public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule {
    /** Prefix of the variable names generated to connect the split statement blocks. */
    private static final String _varnamePrefix = "_sbcvar";

    /** Source of unique suffixes for generated variable names; shared by all instances of this rule. */
    private static final IDSequence _seq = new IDSequence();

    /**
     * Splits {@code statementBlock} at its data-dependent operators, if it has any.
     *
     * @param statementBlock the block to rewrite; its hop DAG roots and live-in set are modified in place
     * @param programRewriteStatus rewrite status, passed through to the recursive rewrite of the new block
     * @return the new statement block(s) computing the data-dependent operators (rewritten recursively),
     *         followed by {@code statementBlock}; or only {@code statementBlock} if there is nothing to split
     * @throws HopsException if the split fails; any exception raised while splitting is wrapped
     */
    @Override
    public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) throws HopsException {
        ArrayList<StatementBlock> ret = new ArrayList<>();

        // collect the data-dependent operators (candidates) of the DAG
        ArrayList<Hop> cand = new ArrayList<>();
        collectDataDependentOperators(statementBlock.get_hops(), cand);
        Hop.resetVisitStatus(statementBlock.get_hops());

        if (cand.isEmpty()) {
            // nothing to split: keep the block unchanged
            ret.add(statementBlock);
            return ret;
        }

        // hops below a candidate; as consumers of a candidate they are never rewired (anomaly filter)
        HashSet<Hop> candChilds = new HashSet<>();
        collectCandidateChildOperators(cand, candChilds);

        try {
            // new statement block computing the candidates, positioned like the original block
            StatementBlock sb1 = new StatementBlock();
            sb1.setDMLProg(statementBlock.getDMLProg());
            sb1.setAllPositions(statementBlock.getFilename(), statementBlock.getBeginLine(), statementBlock.getBeginColumn(), statementBlock.getEndLine(), statementBlock.getEndColumn());
            sb1.setLiveIn(new VariableSet());
            sb1.setLiveOut(new VariableSet());

            // roots (transient writes) of the new block's DAG
            ArrayList<Hop> sb1hops = new ArrayList<>();
            for (Hop c : cand) {
                // reuse an existing transient write only if c has a simple read chain to that variable
                Hop existingWrite = getFirstTransientWriteParent(c);
                boolean reuseWrite = existingWrite != null && HopRewriteUtils.rHasSimpleReadChain(c, existingWrite.getName());

                // size metadata of the candidate, propagated to the transient read/write and variable
                long rlen = c.getDim1();
                long clen = c.getDim2();
                long nnz = c.getNnz();
                long brlen = c.getRowsInBlock();
                long bclen = c.getColsInBlock();

                String varname;
                if (reuseWrite) {
                    varname = existingWrite.getName();
                    DataOp tread = createTransientRead(varname, c, rlen, clen, nnz, brlen, bclen);
                    // iterate over a copy because rewiring modifies c's parent list
                    for (Hop parent : new ArrayList<>(c.getParent())) {
                        if (candChilds.contains(parent))
                            continue; // anomaly filter
                        if (parent == existingWrite)
                            statementBlock.get_hops().remove(parent); // the write moves to the new block
                        else
                            replaceChild(parent, c, tread);
                    }
                    sb1hops.add(existingWrite);
                }
                else {
                    varname = _varnamePrefix + _seq.getNextID();
                    DataOp tread = createTransientRead(varname, c, rlen, clen, nnz, brlen, bclen);
                    // iterate over a copy because rewiring modifies c's parent list
                    for (Hop parent : new ArrayList<>(c.getParent())) {
                        if (!candChilds.contains(parent))
                            replaceChild(parent, c, tread);
                    }
                    // NOTE: created only after the rewiring above. Presumably the write registers
                    // itself as a parent of c, and it must not be redirected to the read.
                    sb1hops.add(createTransientWrite(varname, c, rlen, clen, nnz, brlen, bclen));
                }

                addLiveVariable(varname, c, rlen, clen, brlen, bclen, sb1.liveOut(), statementBlock.liveIn());
            }

            // operators referenced by both DAGs are passed through additional transient variables
            handleReplicatedOperators(sb1hops, statementBlock.get_hops(), sb1.liveOut(), statementBlock.liveIn());
            sb1.set_hops(Recompiler.deepCopyHopsDag(sb1hops));
            sb1.updateRecompilationFlag();

            // the collection does not descend into a candidate's inputs, so the new block may hold
            // further candidates below the extracted ones: split it recursively
            ret.addAll(rewriteStatementBlock(sb1, programRewriteStatus));
            ret.add(statementBlock);
            LOG.debug("Applied splitDagDataDependentOperators (lines " + statementBlock.getBeginLine() + "-" + statementBlock.getEndLine() + ").");
        }
        catch (Exception e) {
            throw new HopsException("Failed to split hops dag for data dependent operators with unknown size.", e);
        }
        return ret;
    }

    /**
     * Collects the data-dependent operators of the DAG given by {@code roots} into {@code cand}.
     * Visit status is reset before the traversal. A {@code null} root list is ignored.
     */
    private void collectDataDependentOperators(ArrayList<Hop> roots, ArrayList<Hop> cand) {
        if (roots == null)
            return;

        Hop.resetVisitStatus(roots);
        for (Hop root : roots)
            rCollectDataDependentOperators(root, cand);
    }

    /**
     * Recursively collects data-dependent operators at or below {@code hop} into {@code cand}, and
     * adjusts output flags of the collected removeEmpty/ctable operators. The inputs of a hop that
     * yielded a candidate are not traversed further.
     */
    private void rCollectDataDependentOperators(Hop hop, ArrayList<Hop> cand) {
        if (hop.getVisited() == Hop.VisitStatus.DONE)
            return;

        // known sizes, or consumption by writes only, make a split unnecessary
        boolean noSplitRequired = hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true);
        boolean investigateChilds = true;

        // case 1: removeEmpty, unless it is the single input of a ternary op with applicable
        // ignore-zero rewrite
        if (hop instanceof ParameterizedBuiltinOp
            && ((ParameterizedBuiltinOp) hop).getOp() == Hop.ParamBuiltinOp.RMEMPTY
            && !noSplitRequired
            && !(hop.getParent().size() == 1
                && hop.getParent().get(0) instanceof TernaryOp
                && ((TernaryOp) hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable()))
        {
            ParameterizedBuiltinOp rmEmpty = (ParameterizedBuiltinOp) hop;
            cand.add(rmEmpty);
            investigateChilds = false;

            boolean targetDiagInput = rmEmpty.isTargetDiagInput();
            boolean onlyLeftMMOrNrowParents = true;
            boolean onlyLeftMMParents = true;
            for (Hop parent : hop.getParent()) {
                boolean leftMMInput = parent instanceof AggBinaryOp && hop == parent.getInput().get(0);
                boolean nrow = parent instanceof UnaryOp && ((UnaryOp) parent).getOp() == Hop.OpOp1.NROW;
                onlyLeftMMOrNrowParents &= leftMMInput || nrow;
                onlyLeftMMParents &= leftMMInput;
            }

            // empty output blocks are only required if some consumer is neither a left
            // matrix-multiply input nor nrow
            rmEmpty.setOutputEmptyBlocks(!onlyLeftMMOrNrowParents);

            if (onlyLeftMMParents && targetDiagInput) {
                if (OptimizerUtils.ALLOW_DYN_RECOMPILATION)
                    rmEmpty.setOutputPermutationMatrix(true);
                // safe cast: every parent was verified to be an AggBinaryOp above
                for (Hop parent : hop.getParent())
                    ((AggBinaryOp) parent).setHasLeftPMInput(true);
            }
        }

        // case 2: ctable with fewer than four inputs
        if (hop instanceof TernaryOp
            && ((TernaryOp) hop).getOp() == Hop.OpOp3.CTABLE
            && hop.getInput().size() < 4
            && !noSplitRequired)
        {
            cand.add(hop);
            investigateChilds = false;

            boolean onlyLeftMMParents = true;
            for (Hop parent : hop.getParent())
                onlyLeftMMParents &= parent instanceof AggBinaryOp && hop == parent.getInput().get(0);

            if (onlyLeftMMParents && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0)))
                hop.setOutputEmptyBlocks(false);
        }

        // case 3: sort, whose by (input 2) and decreasing (input 3) arguments become candidates
        // unless they are literals or data ops
        if (hop instanceof ReorgOp && ((ReorgOp) hop).getOp() == Hop.ReOrgOp.SORT) {
            for (int pos = 2; pos <= 3; pos++) {
                Hop arg = hop.getInput().get(pos);
                if (!(arg instanceof LiteralOp || arg instanceof DataOp)) {
                    cand.add(arg);
                    arg.setVisited(Hop.VisitStatus.DONE);
                    investigateChilds = false;
                }
            }
        }

        if (investigateChilds && hop.getInput() != null) {
            for (Hop child : hop.getInput())
                rCollectDataDependentOperators(child, cand);
        }

        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Returns the first parent of {@code hop} that is a transient write, or {@code null} if there is none.
     */
    private Hop getFirstTransientWriteParent(Hop hop) {
        for (Hop parent : hop.getParent()) {
            if (parent instanceof DataOp && ((DataOp) parent).getDataOpType() == Hop.DataOpTypes.TRANSIENTWRITE)
                return parent;
        }
        return null;
    }

    /**
     * Decouples operators that are shared between the new DAG ({@code rootsSB1}) and the original DAG
     * ({@code rootsSB2}).
     *
     * <p>For every edge of the original DAG that points to a hop reachable from the new DAG (excluding
     * non-persistent DataOps and literals), the edge is redirected to a transient read of a fresh
     * variable. A transient write of the shared hop is appended to {@code rootsSB1}. The variable is
     * registered in {@code sb1LiveOut} and {@code sb2LiveIn}.
     */
    private void handleReplicatedOperators(ArrayList<Hop> rootsSB1, ArrayList<Hop> rootsSB2, VariableSet sb1LiveOut, VariableSet sb2LiveIn) {
        // hops reachable from the new DAG
        HashSet<Hop> probeSet = new HashSet<>();
        Hop.resetVisitStatus(rootsSB1);
        for (Hop root : rootsSB1)
            rAddHopsToProbeSet(root, probeSet);

        // (consumer, shared input) edges of the original DAG
        HashSet<Pair<Hop, Hop>> candSet = new HashSet<>();
        Hop.resetVisitStatus(rootsSB2);
        for (Hop root : rootsSB2)
            rProbeAndAddHopsToCandidateSet(root, probeSet, candSet);

        for (Pair<Hop, Hop> edge : candSet) {
            String varname = _varnamePrefix + _seq.getNextID();
            Hop consumer = edge.getKey();
            Hop shared = edge.getValue();

            long rlen = shared.getDim1();
            long clen = shared.getDim2();
            long nnz = shared.getNnz();
            long brlen = shared.getRowsInBlock();
            long bclen = shared.getColsInBlock();

            DataOp tread = createTransientRead(varname, shared, rlen, clen, nnz, brlen, bclen);
            DataOp twrite = createTransientWrite(varname, shared, rlen, clen, nnz, brlen, bclen);
            replaceChild(consumer, shared, tread);
            addLiveVariable(varname, shared, rlen, clen, brlen, bclen, sb1LiveOut, sb2LiveIn);
            rootsSB1.add(twrite);
        }
    }

    /**
     * Adds {@code hop} and all its unvisited descendants to {@code probeSet}. Non-persistent DataOps
     * and literals are skipped, but their inputs are still traversed.
     */
    private void rAddHopsToProbeSet(Hop hop, HashSet<Hop> probeSet) {
        if (hop.getVisited() == Hop.VisitStatus.DONE)
            return;

        boolean nonPersistentData = hop instanceof DataOp && !((DataOp) hop).isPersistentReadWrite();
        if (!nonPersistentData && !(hop instanceof LiteralOp))
            probeSet.add(hop);

        if (hop.getInput() != null) {
            for (Hop child : hop.getInput())
                rAddHopsToProbeSet(child, probeSet);
        }

        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Records every (consumer, input) edge below {@code hop} whose input is contained in
     * {@code probeSet}. The traversal does not descend into such inputs.
     */
    private void rProbeAndAddHopsToCandidateSet(Hop hop, HashSet<Hop> probeSet, HashSet<Pair<Hop, Hop>> candSet) {
        if (hop.getVisited() == Hop.VisitStatus.DONE)
            return;

        if (hop.getInput() != null) {
            for (Hop child : hop.getInput()) {
                if (probeSet.contains(child))
                    candSet.add(new Pair<>(hop, child));
                else
                    rProbeAndAddHopsToCandidateSet(child, probeSet, candSet);
            }
        }

        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Collects into {@code candChilds} all hops reachable from the inputs of the candidates.
     * Visit status is reset before and after the traversal.
     */
    private void collectCandidateChildOperators(ArrayList<Hop> cand, HashSet<Hop> candChilds) {
        Hop.resetVisitStatus(cand);
        if (cand != null) {
            for (Hop c : cand)
                rCollectCandidateChildOperators(c, cand, candChilds, false);
        }
        Hop.resetVisitStatus(cand);
    }

    /**
     * Recursive part of {@link #collectCandidateChildOperators}: adds {@code hop} if it lies below a
     * candidate ({@code belowCandidate}), and traverses its inputs.
     */
    private void rCollectCandidateChildOperators(Hop hop, ArrayList<Hop> cand, HashSet<Hop> candChilds, boolean belowCandidate) {
        if (hop.getVisited() == Hop.VisitStatus.DONE)
            return;

        if (belowCandidate)
            candChilds.add(hop);

        // inputs lie below a candidate if this hop does, or is a candidate itself
        boolean childsBelowCandidate = belowCandidate || cand.contains(hop);
        if (hop.getInput() != null) {
            for (Hop child : hop.getInput())
                rCollectCandidateChildOperators(child, cand, candChilds, childsBelowCandidate);
        }

        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Creates a transient read of {@code varname} with the data/value type of {@code src} and the
     * given size metadata. The read is marked visited and takes over the line numbers of {@code src}.
     */
    private static DataOp createTransientRead(String varname, Hop src, long rlen, long clen, long nnz, long brlen, long bclen) {
        DataOp tread = new DataOp(varname, src.getDataType(), src.getValueType(), Hop.DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, brlen, bclen);
        tread.setVisited(Hop.VisitStatus.DONE);
        HopRewriteUtils.copyLineNumbers(src, tread);
        return tread;
    }

    /**
     * Creates a transient write of {@code varname} with input {@code src}, the data/value type of
     * {@code src}, and the given output size metadata. The write is marked visited and takes over the
     * line numbers of {@code src}.
     */
    private static DataOp createTransientWrite(String varname, Hop src, long rlen, long clen, long nnz, long brlen, long bclen) {
        DataOp twrite = new DataOp(varname, src.getDataType(), src.getValueType(), src, Hop.DataOpTypes.TRANSIENTWRITE, (String) null);
        twrite.setVisited(Hop.VisitStatus.DONE);
        twrite.setOutputParams(rlen, clen, nnz, brlen, bclen);
        HopRewriteUtils.copyLineNumbers(src, twrite);
        return twrite;
    }

    /** Replaces input {@code oldChild} of {@code parent} by {@code newChild} at the same input position. */
    private static void replaceChild(Hop parent, Hop oldChild, Hop newChild) {
        int pos = HopRewriteUtils.getChildReferencePos(parent, oldChild);
        HopRewriteUtils.removeChildReferenceByPos(parent, oldChild, pos);
        HopRewriteUtils.addChildReference(parent, newChild, pos);
    }

    /**
     * Registers {@code varname} in both {@code liveOut} and {@code liveIn}, each with its own copy of
     * a {@link DataIdentifier}. The identifier carries the data/value type of {@code src} and the
     * given dimensions and block sizes.
     */
    private static void addLiveVariable(String varname, Hop src, long rlen, long clen, long brlen, long bclen, VariableSet liveOut, VariableSet liveIn) {
        DataIdentifier id = new DataIdentifier(varname);
        id.setDimensions(rlen, clen);
        id.setBlockDimensions(brlen, bclen);
        id.setDataType(src.getDataType());
        id.setValueType(src.getValueType());
        liveOut.addVariable(varname, new DataIdentifier(id));
        liveIn.addVariable(varname, new DataIdentifier(id));
    }
}
