package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;

/* loaded from: input_file:org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.class */
/**
 * Statement-block rewrite that marks loop-carried matrix variables as update-in-place candidates.
 *
 * <p>For every {@code while} or {@code for} statement block, a variable is recorded through
 * {@link StatementBlock#setUpdateInPlaceVars} when all of the following hold:
 * <ul>
 *   <li>it is a matrix updated inside the loop;</li>
 *   <li>it is live after the loop;</li>
 *   <li>every statement block of the loop body that reads or updates it passes the checks in
 *       {@link #rIsApplicableForUpdateInPlace}.</li>
 * </ul>
 *
 * <p>The rule does nothing when the runtime platform is HADOOP or SPARK.
 */
public class RewriteMarkLoopVariablesUpdateInPlace extends StatementBlockRewriteRule {

    /** This rule never requires split hop DAGs. */
    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Computes and stores the update-in-place candidates of a while/for statement block.
     *
     * @param statementBlock block to inspect; blocks other than while/for are returned untouched
     * @param programRewriteStatus rewrite status (not consulted by this rule)
     * @return a one-element list holding {@code statementBlock} itself
     * @throws HopsException if inspecting the loop body fails
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) throws HopsException {
        boolean distributedOnly = DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.HADOOP
                || DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.SPARK;
        if (distributedOnly) {
            return Arrays.asList(statementBlock);
        }

        boolean isLoop = statementBlock instanceof WhileStatementBlock
                || statementBlock instanceof ForStatementBlock;
        if (!isLoop) {
            return Arrays.asList(statementBlock);
        }

        ArrayList<String> candidates = new ArrayList<>();
        VariableSet updated = statementBlock.variablesUpdated();
        VariableSet liveOut = statementBlock.liveOut();
        for (String varName : updated.getVariableNames()) {
            // Only matrices that survive the loop are of interest; loop-local temporaries are skipped.
            boolean liveOutMatrix = updated.getVariable(varName).getDataType() == Expression.DataType.MATRIX
                    && liveOut.containsVariable(varName);
            // The body is fetched lazily, i.e. only for variables that passed the cheap filter above.
            if (liveOutMatrix && rIsApplicableForUpdateInPlace(loopBodyOf(statementBlock), varName)) {
                candidates.add(varName);
            }
        }
        statementBlock.setUpdateInPlaceVars(candidates);
        return Arrays.asList(statementBlock);
    }

    /**
     * Returns the body of a while or for statement block.
     *
     * @param loop a {@link WhileStatementBlock} or {@link ForStatementBlock}; any non-while block is
     *     treated as a for block
     */
    private static List<StatementBlock> loopBodyOf(StatementBlock loop) {
        return (loop instanceof WhileStatementBlock)
                ? ((WhileStatement) loop.getStatement(0)).getBody()
                : ((ForStatement) loop.getStatement(0)).getBody();
    }

    /**
     * Checks whether every block in {@code blocks} that reads or updates {@code varName} permits
     * updating it in place. Blocks that do not touch the variable are ignored.
     *
     * @return {@code false} as soon as one touching block is not applicable, else {@code true}
     * @throws HopsException if inspecting a nested block fails
     */
    private static boolean rIsApplicableForUpdateInPlace(List<StatementBlock> blocks, String varName) throws HopsException {
        for (StatementBlock block : blocks) {
            boolean touchesVar = block.variablesRead().containsVariable(varName)
                    || block.variablesUpdated().containsVariable(varName);
            if (touchesVar && !isBlockApplicable(block, varName)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Applicability check for a single block known to read or update {@code varName}.
     *
     * <ul>
     *   <li>Nested while/for: defers to that loop's own stored candidate list.</li>
     *   <li>If: the if-body, and the else-body when present, must both be applicable.</li>
     *   <li>Any other block: every root hop must pass {@link #isApplicableForUpdateInPlace}.
     *       A block without hops is accepted.</li>
     * </ul>
     */
    private static boolean isBlockApplicable(StatementBlock block, String varName) throws HopsException {
        if (block instanceof WhileStatementBlock || block instanceof ForStatementBlock) {
            // NOTE(review): assumes the nested loop's candidate list was already computed by this
            // rule (inner-before-outer traversal) -- confirm against the rewriter's visiting order.
            return block.getUpdateInPlaceVars().contains(varName);
        }
        if (block instanceof IfStatementBlock) {
            IfStatement ifStatement = (IfStatement) block.getStatement(0);
            if (!rIsApplicableForUpdateInPlace(ifStatement.getIfBody(), varName)) {
                return false;
            }
            return ifStatement.getElseBody() == null
                    || rIsApplicableForUpdateInPlace(ifStatement.getElseBody(), varName);
        }
        if (block.getHops() == null) {
            return true;
        }
        // Every root is inspected (no early exit), so all roots are visited even after a failure.
        boolean allRootsApplicable = true;
        for (Hop root : block.getHops()) {
            allRootsApplicable &= isApplicableForUpdateInPlace(root, varName);
        }
        return allRootsApplicable;
    }

    /**
     * Checks a single root hop.
     *
     * <p>Hops with a different name are always accepted. A hop named {@code str} is accepted only if
     * all of the following hold:
     * <ul>
     *   <li>it is a matrix {@link DataOp};</li>
     *   <li>its first input is a matrix {@link LeftIndexingOp};</li>
     *   <li>that op's first input is a {@link DataOp} with the same name;</li>
     *   <li>that inner {@link DataOp} is consumed only by the left-indexing op or by
     *       {@code NROW}/{@code NCOL} unary ops.</li>
     * </ul>
     */
    private static boolean isApplicableForUpdateInPlace(Hop hop, String str) {
        if (!hop.getName().equals(str)) {
            return true;
        }
        // Guards are evaluated in the same order as the getInput().get(0) chain they protect.
        if (!(hop instanceof DataOp) || !hop.isMatrix()) {
            return false;
        }
        Hop leftIndexing = hop.getInput().get(0);
        if (!leftIndexing.isMatrix() || !(leftIndexing instanceof LeftIndexingOp)) {
            return false;
        }
        Hop indexedInput = leftIndexing.getInput().get(0);
        if (!(indexedInput instanceof DataOp) || !indexedInput.getName().equals(str)) {
            return false;
        }
        for (Hop consumer : indexedInput.getParent()) {
            if (consumer != leftIndexing && !isRowOrColumnCount(consumer)) {
                return false;
            }
        }
        return true;
    }

    /** True if {@code consumer} is a unary {@code NROW} or {@code NCOL} operation. */
    private static boolean isRowOrColumnCount(Hop consumer) {
        if (!(consumer instanceof UnaryOp)) {
            return false;
        }
        Hop.OpOp1 op = ((UnaryOp) consumer).getOp();
        return op == Hop.OpOp1.NROW || op == Hop.OpOp1.NCOL;
    }

    /** Multi-block rewriting is a no-op for this rule; the input list is returned as-is. */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) throws HopsException {
        return list;
    }
}
