package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IndexedIdentifier;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.parser.WhileStatementBlock;

/* loaded from: input_file:org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.class */
/**
 * Statement-block rewrite that, in Spark execution mode, injects checkpoint hops for matrices a
 * loop reads but does not update.
 *
 * <p>For a {@link WhileStatementBlock} or {@link ForStatementBlock} (including any subclasses), the
 * rewrite collects every MATRIX variable in the loop's read set that is absent from its update set.
 * If there is at least one such variable, a new statement block is emitted directly before the
 * loop. For each candidate variable, the new block holds a transient read flagged with
 * {@code setRequiresCheckpoint(true)} that feeds a transient write of the same name.
 *
 * <p>NOTE(review): presumably this lets the Spark backend persist loop-invariant inputs once
 * instead of recomputing them per iteration. Confirm against the backend's checkpoint handling.
 */
public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRule {
    /** When true, loops are skipped while {@link ProgramRewriteStatus#isInParforContext()} holds. */
    private final boolean _checkCtx;

    /**
     * Creates the rewrite.
     *
     * @param checkParForContext if true, the rewrite is not applied to loops encountered while the
     *     rewrite status reports a parfor context
     */
    public RewriteInjectSparkLoopCheckpointing(boolean checkParForContext) {
        this._checkCtx = checkParForContext;
    }

    /**
     * Always returns {@code true}. The block this rule injects is marked with
     * {@code setSplitDag(true)}.
     */
    @Override // org.apache.sysml.hops.rewrite.StatementBlockRewriteRule
    public boolean createsSplitDag() {
        return true;
    }

    /**
     * Applies the rewrite to a single statement block.
     *
     * @param statementBlock the block to rewrite
     * @param programRewriteStatus rewrite status. It supplies the blocksize and parfor-context flag
     *     and is notified via {@code setInjectedCheckpoints()} when a checkpoint block is added.
     * @return outside Spark execution mode, a fixed-size list containing only
     *     {@code statementBlock}. Otherwise, a list whose last element is {@code statementBlock},
     *     optionally preceded by the injected checkpoint block.
     * @throws HopsException if hop construction fails
     */
    @Override // org.apache.sysml.hops.rewrite.StatementBlockRewriteRule
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) throws HopsException {
        // Checkpoint injection only applies to the Spark backend; otherwise pass the block through.
        if (!OptimizerUtils.isSparkExecutionMode()) {
            return Arrays.asList(statementBlock);
        }
        List<StatementBlock> result = new ArrayList<>();
        int blocksize = programRewriteStatus.getBlocksize();
        boolean isLoop = (statementBlock instanceof WhileStatementBlock) || (statementBlock instanceof ForStatementBlock);
        // Short-circuit: the parfor context is only consulted when context checking is enabled.
        if (isLoop && (!_checkCtx || !programRewriteStatus.isInParforContext())) {
            VariableSet read = statementBlock.variablesRead();
            List<String> candidates = collectReadOnlyMatrices(read, statementBlock.variablesUpdated());
            if (!candidates.isEmpty()) {
                result.add(createCheckpointBlock(statementBlock, read, candidates, blocksize));
                programRewriteStatus.setInjectedCheckpoints();
            }
        }
        // The original block always comes last, after any injected checkpoint block.
        result.add(statementBlock);
        return result;
    }

    /** Returns the names of MATRIX variables in {@code read} that do not occur in {@code updated}. */
    private static List<String> collectReadOnlyMatrices(VariableSet read, VariableSet updated) {
        List<String> candidates = new ArrayList<>();
        for (String name : read.getVariableNames()) {
            if (!updated.containsVariable(name) && read.getVariable(name).getDataType() == Expression.DataType.MATRIX) {
                candidates.add(name);
            }
        }
        return candidates;
    }

    /**
     * Builds a split-DAG statement block with one checkpointed transient-read/transient-write pair
     * per candidate.
     *
     * <p>Live-in and live-out of the new block are both exactly the candidate variables.
     */
    private static StatementBlock createCheckpointBlock(StatementBlock loop, VariableSet read, List<String> candidates, int blocksize) throws HopsException {
        StatementBlock checkpointBlock = new StatementBlock();
        checkpointBlock.setDMLProg(loop.getDMLProg());
        checkpointBlock.setParseInfo(loop);
        ArrayList<Hop> hops = new ArrayList<>();
        VariableSet liveIn = new VariableSet();
        VariableSet liveOut = new VariableSet();
        for (String name : candidates) {
            DataIdentifier var = read.getVariable(name);
            // Indexed identifiers carry the dimensions of the whole matrix as "orig" dims.
            long dim1 = var instanceof IndexedIdentifier ? ((IndexedIdentifier) var).getOrigDim1() : var.getDim1();
            long dim2 = var instanceof IndexedIdentifier ? ((IndexedIdentifier) var).getOrigDim2() : var.getDim2();
            DataOp tread = new DataOp(name, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, Hop.DataOpTypes.TRANSIENTREAD, var.getFilename(), dim1, dim2, var.getNnz(), blocksize, blocksize);
            tread.setRequiresCheckpoint(true);
            DataOp twrite = new DataOp(name, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, tread, Hop.DataOpTypes.TRANSIENTWRITE, (String) null);
            HopRewriteUtils.setOutputParameters(twrite, dim1, dim2, blocksize, blocksize, var.getNnz());
            hops.add(twrite);
            liveIn.addVariable(name, var);
            liveOut.addVariable(name, var);
        }
        checkpointBlock.setHops(hops);
        checkpointBlock.setLiveIn(liveIn);
        checkpointBlock.setLiveOut(liveOut);
        checkpointBlock.setSplitDag(true);
        return checkpointBlock;
    }

    /** Performs no multi-block rewriting; returns {@code list} unchanged. */
    @Override // org.apache.sysml.hops.rewrite.StatementBlockRewriteRule
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) throws HopsException {
        return list;
    }
}
