package org.apache.sysml.hops.rewrite;

import java.io.IOException;
import java.util.ArrayList;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.compile.Dag;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;

/* loaded from: input_file:org/apache/sysml/hops/rewrite/RewriteConstantFolding.class */
/**
 * HOP rewrite that folds constant sub-expressions into literal operators, bottom-up.
 *
 * <p>Three kinds of folding are applied:
 * <ul>
 *   <li>A scalar binary operation whose two inputs are literals is evaluated. So is a scalar
 *       unary operation whose input is a literal. {@code CBIND}/{@code RBIND} and
 *       {@code PRINT}/{@code STOP} are excluded. Evaluation compiles the hop to runtime
 *       instructions and executes them in a private program block and execution context.</li>
 *   <li>A scalar {@code AND} with at least one literal {@code false} input becomes
 *       {@code false}.</li>
 *   <li>A scalar {@code OR} with at least one literal {@code true} input becomes
 *       {@code true}.</li>
 * </ul>
 *
 * <p>NOTE(review): the temporary program block and execution context are cached in static
 * fields that are lazily initialized without synchronization. Concurrent use of this rule
 * from several threads therefore does not look safe as written.
 */
public class RewriteConstantFolding extends HopRewriteRule {
    /** Name of the transient variable that receives the value of an evaluated expression. */
    private static final String TMP_VARNAME = "__cf_tmp";

    /** Lazily created program block, reused to execute folding instructions. */
    private static ProgramBlock _tmpPB = null;

    /** Lazily created execution context, reused to hold the folding result variable. */
    private static ExecutionContext _tmpEC = null;

    /**
     * Applies constant folding to every root in {@code roots}. A root that folds completely
     * is replaced in the list by its literal value.
     *
     * @param roots DAG roots; may be {@code null}
     * @param state rewrite status (not consulted by this rule)
     * @return the same list with folded roots substituted, or {@code null} if {@code roots} is null
     * @throws HopsException if traversing the DAG fails
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
            throws HopsException {
        if (roots == null) {
            return null;
        }
        for (int i = 0; i < roots.size(); i++) {
            roots.set(i, rule_ConstantFolding(roots.get(i)));
        }
        return roots;
    }

    /**
     * Applies constant folding to the DAG rooted at {@code root}.
     *
     * @param root DAG root; may be {@code null}
     * @param state rewrite status (not consulted by this rule)
     * @return the (possibly replaced) root, or {@code null} if {@code root} is null
     * @throws HopsException if traversing the DAG fails
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
        if (root == null) {
            return null;
        }
        return rule_ConstantFolding(root);
    }

    /** Entry point of the folding rule for a single, non-null root. */
    private Hop rule_ConstantFolding(Hop root) throws HopsException {
        return rConstantFoldingExpression(root);
    }

    /**
     * Recursively folds {@code root} and its inputs, and marks every processed hop DONE.
     *
     * <p>If {@code root} folds and has parents, each parent is rewired to the new literal.
     * If it folds and has no parents, the literal itself is returned so the caller can
     * substitute it.
     *
     * @param root hop to process
     * @return {@code root}, or its replacement literal if {@code root} had no parents and folded
     * @throws HopsException if inspecting or rewiring hops fails
     */
    private Hop rConstantFoldingExpression(Hop root) throws HopsException {
        if (root.getVisited() == Hop.VisitStatus.DONE) {
            return root;
        }

        // Fold children first. An index loop (not an iterator) is required because a folded
        // child replaces itself inside root's input list during the recursive call.
        for (int i = 0; i < root.getInput().size(); i++) {
            rConstantFoldingExpression(root.getInput().get(i));
        }

        LiteralOp literal = null;
        if (root.getDataType() == Expression.DataType.SCALAR
                && (isApplicableBinaryOp(root) || isApplicableUnaryOp(root))) {
            // Best effort: a failed evaluation only means this hop stays unfolded.
            try {
                literal = evalScalarOperation(root);
            } catch (Exception e) {
                LOG.error("Failed to execute constant folding instructions. No abort.", e);
            }
        } else if (isApplicableFalseConjunctivePredicate(root)) {
            literal = new LiteralOp(false);
        } else if (isApplicableTrueDisjunctivePredicate(root)) {
            literal = new LiteralOp(true);
        }

        if (literal != null) {
            if (root.getParent().size() > 0) {
                replaceInParents(root, literal);
            } else {
                // No parents to rewire (presumably a standalone root such as a predicate DAG),
                // so hand the literal back to the caller instead.
                root = literal;
            }
        }

        root.setVisited(Hop.VisitStatus.DONE);
        return root;
    }

    /**
     * Makes every parent of {@code hop} reference {@code literal} instead, at the same input
     * positions. Afterwards the parent list of {@code hop} is cleared.
     */
    private static void replaceInParents(Hop hop, LiteralOp literal) throws HopsException {
        for (int p = 0; p < hop.getParent().size(); p++) {
            Hop parent = hop.getParent().get(p);
            ArrayList<Hop> inputs = parent.getInput();
            // A parent may reference hop more than once (e.g. x + x). Each occurrence is
            // replaced in place; remove+insert keeps the size stable, so indexing stays valid.
            for (int j = 0; j < inputs.size(); j++) {
                if (inputs.get(j) == hop) {
                    inputs.remove(j);
                    HopRewriteUtils.addChildReference(parent, literal, j);
                }
            }
        }
        hop.getParent().clear();
    }

    /**
     * Evaluates the scalar expression rooted at {@code hop} and returns its value as a literal.
     *
     * <p>The hop is wrapped in a temporary transient write to {@link #TMP_VARNAME}, compiled
     * to instructions and executed in the shared program block and execution context. The
     * temporary write, the instructions and the context variables are always cleaned up,
     * even when compilation or execution fails, so a failed attempt leaves no dangling
     * parent on {@code hop}.
     *
     * @param hop scalar hop whose inputs are all literals
     * @return a literal holding the computed value, with scalar dimensions and block sizes
     * @throws HopsException if the result is missing, is not a scalar, or has an unsupported type
     */
    private LiteralOp evalScalarOperation(Hop hop)
            throws LopsException, DMLRuntimeException, IOException, HopsException {
        // Obtain the shared runtime objects before wiring the temporary write into the DAG, so
        // that a failure here cannot leave the write attached to hop.
        ExecutionContext ec = getExecutionContext();
        ProgramBlock pb = getProgramBlock();

        // Constructing the write attaches it to hop. The finally block detaches it again.
        DataOp tmpWrite = new DataOp(TMP_VARNAME, hop.getDataType(), hop.getValueType(), hop,
                Hop.DataOpTypes.TRANSIENTWRITE, TMP_VARNAME);
        try {
            Dag<Lop> dag = new Dag<>();
            Recompiler.rClearLops(tmpWrite);
            tmpWrite.constructLops().addToDag(dag);
            ArrayList<Instruction> insts = dag.getJobs(null, ConfigurationManager.getDMLConfig());

            pb.setInstructions(insts);
            pb.execute(ec);

            Object result = ec.getVariable(TMP_VARNAME);
            if (!(result instanceof ScalarObject)) {
                throw new HopsException(
                        "Constant folding produced no scalar result in variable " + TMP_VARNAME);
            }
            LiteralOp literal = toLiteral((ScalarObject) result);
            literal.setDim1(0L);
            literal.setDim2(0L);
            literal.setRowsInBlock(-1L);
            literal.setColsInBlock(-1L);
            return literal;
        } finally {
            tmpWrite.getInput().clear();
            hop.getParent().remove(tmpWrite);
            pb.setInstructions(null);
            ec.getVariables().removeAll();
        }
    }

    /**
     * Converts a runtime scalar into a literal hop of the matching value type.
     *
     * @throws HopsException if the scalar's value type has no literal representation here
     */
    private static LiteralOp toLiteral(ScalarObject scalar) throws HopsException {
        switch (scalar.getValueType()) {
            case DOUBLE:
                return new LiteralOp(scalar.getDoubleValue());
            case INT:
                return new LiteralOp(scalar.getLongValue());
            case BOOLEAN:
                return new LiteralOp(scalar.getBooleanValue());
            case STRING:
                return new LiteralOp(scalar.getStringValue());
            default:
                throw new HopsException("Unsupported literal value type: " + scalar.getValueType());
        }
    }

    /** Returns the shared program block, creating it on first use. */
    private static ProgramBlock getProgramBlock() throws DMLRuntimeException {
        if (_tmpPB == null) {
            _tmpPB = new ProgramBlock(new Program());
        }
        return _tmpPB;
    }

    /** Returns the shared execution context, creating it on first use. */
    private static ExecutionContext getExecutionContext() {
        if (_tmpEC == null) {
            _tmpEC = ExecutionContextFactory.createContext();
        }
        return _tmpEC;
    }

    /** True for a binary op over two literal inputs, excluding {@code CBIND}/{@code RBIND}. */
    private static boolean isApplicableBinaryOp(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        Hop.OpOp2 op = ((BinaryOp) hop).getOp();
        ArrayList<Hop> in = hop.getInput();
        return in.get(0) instanceof LiteralOp
                && in.get(1) instanceof LiteralOp
                && op != Hop.OpOp2.CBIND
                && op != Hop.OpOp2.RBIND;
    }

    /**
     * True for a scalar unary op over a literal input, excluding the side-effecting
     * {@code PRINT}/{@code STOP}.
     */
    private static boolean isApplicableUnaryOp(Hop hop) {
        if (!(hop instanceof UnaryOp)) {
            return false;
        }
        Hop.OpOp1 op = ((UnaryOp) hop).getOp();
        return hop.getInput().get(0) instanceof LiteralOp
                && op != Hop.OpOp1.PRINT
                && op != Hop.OpOp1.STOP
                && hop.getDataType() == Expression.DataType.SCALAR;
    }

    /**
     * True for a scalar {@code AND} with at least one literal {@code false} input.
     * The scalar check prevents a non-scalar {@code AND} from being replaced by a scalar literal.
     */
    private static boolean isApplicableFalseConjunctivePredicate(Hop hop) throws HopsException {
        if (!(hop instanceof BinaryOp) || ((BinaryOp) hop).getOp() != Hop.OpOp2.AND
                || hop.getDataType() != Expression.DataType.SCALAR) {
            return false;
        }
        ArrayList<Hop> in = hop.getInput();
        return isBooleanLiteral(in.get(0), false) || isBooleanLiteral(in.get(1), false);
    }

    /**
     * True for a scalar {@code OR} with at least one literal {@code true} input.
     * The scalar check prevents a non-scalar {@code OR} from being replaced by a scalar literal.
     */
    private static boolean isApplicableTrueDisjunctivePredicate(Hop hop) throws HopsException {
        if (!(hop instanceof BinaryOp) || ((BinaryOp) hop).getOp() != Hop.OpOp2.OR
                || hop.getDataType() != Expression.DataType.SCALAR) {
            return false;
        }
        ArrayList<Hop> in = hop.getInput();
        return isBooleanLiteral(in.get(0), true) || isBooleanLiteral(in.get(1), true);
    }

    /** True if {@code hop} is a literal whose boolean interpretation equals {@code value}. */
    private static boolean isBooleanLiteral(Hop hop, boolean value) throws HopsException {
        return hop instanceof LiteralOp && ((LiteralOp) hop).getBooleanValue() == value;
    }
}
