package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.parser.Expression;

/**
 * Hop DAG rewrite that vectorizes groups of single-cell indexing operations that all address the
 * same row (or the same column) of one matrix.
 *
 * <p>Instead of indexing the matrix cell by cell, the rewrite extracts the shared row/column slice
 * once, using a right {@code IndexingOp} named {@code "tmp"} or {@code "tmp1"}. It then re-indexes
 * every affected operation at position 1..1 of that slice. For left indexing, the updated slice is
 * afterwards written back into the original matrix with a single {@code LeftIndexingOp} named
 * {@code "tmp2"}.
 *
 * <p>Row vectorization is attempted first. Column vectorization is only attempted when the row
 * variant did not apply. Only the left indexing variant is wired into the DAG traversal; the right
 * indexing variant is kept but not invoked.
 *
 * <p>Index expressions are matched by object identity ({@code ==}), so equal but distinct
 * expression hops are not grouped. NOTE(review): this presumably relies on an earlier common
 * subexpression elimination pass having merged equal expressions; confirm.
 */
public class RewriteIndexingVectorization extends HopRewriteRule {
    private static final Log LOG = LogFactory.getLog(RewriteIndexingVectorization.class.getName());

    // Input layout assumed by this rewrite (derived from how the operators are built below):
    //   IndexingOp:     0 = input,  1/2 = row lower/upper, 3/4 = column lower/upper
    //   LeftIndexingOp: 0 = target, 1 = written value, 2/3 = row lower/upper, 4/5 = column lower/upper
    // The matching upper-bound input always sits at lower position + 1.
    private static final int RIX_ROW_LOWER_POS = 1;
    private static final int RIX_COL_LOWER_POS = 3;
    private static final int LIX_ROW_LOWER_POS = 2;
    private static final int LIX_COL_LOWER_POS = 4;

    /**
     * Applies indexing vectorization below every root of a DAG.
     *
     * <p>Roots are never replaced, because rewrites are only applied to inputs of visited hops.
     * The given list is therefore returned as-is.
     *
     * @param roots DAG roots to process; may be {@code null}
     * @param state rewrite status (unused by this rule)
     * @return {@code roots}, or {@code null} if {@code roots} is {@code null}
     * @throws HopsException if a hop manipulation fails
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
        if (roots == null) {
            return null;
        }
        for (Hop root : roots) {
            ruleIndexingVectorization(root);
        }
        return roots;
    }

    /**
     * Applies indexing vectorization below a single DAG root.
     *
     * @param root DAG root to process; may be {@code null}
     * @param state rewrite status (unused by this rule)
     * @return {@code root}, or {@code null} if {@code root} is {@code null}
     * @throws HopsException if a hop manipulation fails
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
        if (root == null) {
            return null;
        }
        ruleIndexingVectorization(root);
        return root;
    }

    /**
     * Depth-first traversal that attempts left indexing vectorization on each input of
     * {@code hop} before descending into that input.
     *
     * <p>Hops already marked DONE are skipped. A hop is marked DONE after all of its inputs have
     * been processed.
     */
    private void ruleIndexingVectorization(Hop hop) throws HopsException {
        if (hop.getVisited() == Hop.VisitStatus.DONE) {
            return;
        }
        // The input count is re-read on every iteration because the rewrite rewires inputs in place.
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop child = hop.getInput().get(i);
            // Only left indexing is vectorized here; vectorizeRightIndexing is intentionally not called.
            vectorizeLeftIndexing(child);
            // NOTE(review): if the call above swapped a new "tmp2" LeftIndexingOp into this input slot,
            // the recursion below still descends into the original child. In that case the replacement
            // hop itself is never marked DONE by this pass. If Hop.resetVisitStatus stops at NOT_VISITED
            // hops, the original subtree could stay DONE for later rules; confirm.
            ruleIndexingVectorization(child);
        }
        hop.setVisited(Hop.VisitStatus.DONE);
    }

    /**
     * Groups sibling single-cell right indexing operations that read the same row (or column) of
     * the same input, so they all index into one shared slice.
     *
     * <p>Not invoked by {@link #ruleIndexingVectorization(Hop)}; kept for reference.
     */
    @SuppressWarnings("unused")
    private void vectorizeRightIndexing(Hop hop) throws HopsException {
        if (!(hop instanceof IndexingOp)) {
            return;
        }
        IndexingOp rix = (IndexingOp) hop;
        // Only single-cell reads are candidates.
        if (!rix.getRowLowerEqualsUpper() || !rix.getColLowerEqualsUpper()) {
            return;
        }
        // Column grouping is only tried if row grouping did not apply.
        if (!tryVectorizeRightIndexing(rix, true)) {
            tryVectorizeRightIndexing(rix, false);
        }
    }

    /**
     * Attempts the row ({@code byRow == true}) or column variant of right indexing vectorization
     * for the single-cell read {@code rix}.
     *
     * @return {@code true} if at least one sibling was found and the DAG was rewritten
     */
    private static boolean tryVectorizeRightIndexing(IndexingOp rix, boolean byRow) throws HopsException {
        final int lowerPos = byRow ? RIX_ROW_LOWER_POS : RIX_COL_LOWER_POS;
        final Hop input = rix.getInput().get(0);
        final Hop indexExpr = rix.getInput().get(lowerPos);

        // Candidates: rix plus every other consumer of the same input that is an IndexingOp,
        // is single along this dimension, and uses the identical index hop.
        ArrayList<Hop> candidates = new ArrayList<Hop>();
        candidates.add(rix);
        for (Hop sibling : input.getParent()) {
            if (sibling != rix
                    && sibling instanceof IndexingOp
                    && sibling.getInput().get(0) == input
                    && (byRow ? ((IndexingOp) sibling).getRowLowerEqualsUpper()
                              : ((IndexingOp) sibling).getColLowerEqualsUpper())
                    && sibling.getInput().get(lowerPos) == indexExpr) {
                candidates.add(sibling);
            }
        }
        if (candidates.size() < 2) {
            return false;
        }

        IndexingOp slice = createSliceIndexing("tmp", input, indexExpr, byRow);
        // Rewire every candidate to read position 1..1 of the shared slice instead of the input.
        for (Hop candidate : candidates) {
            HopRewriteUtils.removeChildReference(candidate, input);
            HopRewriteUtils.addChildReference(candidate, slice, 0);
            resetIndexRangeToOne(candidate, lowerPos);
            candidate.refreshSizeInformation();
        }
        LOG.debug(byRow ? "Applied vectorizeRightIndexingRow" : "Applied vectorizeRightIndexingCol");
        return true;
    }

    /**
     * Vectorizes a chain of single-cell left indexing operations that write into the same row
     * (or column) of one target.
     */
    private void vectorizeLeftIndexing(Hop hop) throws HopsException {
        if (!(hop instanceof LeftIndexingOp)) {
            return;
        }
        LeftIndexingOp head = (LeftIndexingOp) hop;
        // Only single-cell writes start a chain.
        if (!head.getRowLowerEqualsUpper() || !head.getColLowerEqualsUpper()) {
            return;
        }
        // Column vectorization is only tried if row vectorization did not apply.
        if (!tryVectorizeLeftIndexing(head, true)) {
            tryVectorizeLeftIndexing(head, false);
        }
    }

    /**
     * Attempts the row ({@code byRow == true}) or column variant of left indexing vectorization
     * for the chain starting at {@code head}.
     *
     * <p>The chain follows input 0 downwards while each next {@code LeftIndexingOp}:
     * <ul>
     *   <li>has a single consumer;</li>
     *   <li>is single along the dimension;</li>
     *   <li>uses the identical index hop as {@code head};</li>
     *   <li>has a target whose extent across the other dimension is greater than 1. This check
     *       also stops the chain when that extent is unknown (&lt;= 1).</li>
     * </ul>
     *
     * <p>If the chain has at least two operations, the DAG is rewritten in three steps:
     * <ol>
     *   <li>The bottom operation now writes into an extracted slice of the original target.</li>
     *   <li>Every chain member's index range on this dimension becomes literal 1..1.</li>
     *   <li>All former consumers of {@code head} instead consume a new left indexing op that
     *       writes the updated slice back into the original target.</li>
     * </ol>
     *
     * @return {@code true} if the DAG was rewritten
     */
    private static boolean tryVectorizeLeftIndexing(LeftIndexingOp head, boolean byRow) throws HopsException {
        final int lowerPos = byRow ? LIX_ROW_LOWER_POS : LIX_COL_LOWER_POS;
        // Captured before the chain's index inputs are reset to literals below.
        final Hop indexExpr = head.getInput().get(lowerPos);

        ArrayList<LeftIndexingOp> chain = new ArrayList<LeftIndexingOp>();
        chain.add(head);
        LeftIndexingOp bottom = head;
        while (bottom.getInput().get(0) instanceof LeftIndexingOp) {
            LeftIndexingOp next = (LeftIndexingOp) bottom.getInput().get(0);
            if (next.getParent().size() > 1
                    || !(byRow ? next.getRowLowerEqualsUpper() : next.getColLowerEqualsUpper())
                    || next.getInput().get(lowerPos) != indexExpr
                    || (byRow ? next.getInput().get(0).getDim2() : next.getInput().get(0).getDim1()) <= 1) {
                break;
            }
            chain.add(next);
            bottom = next;
        }
        if (chain.size() < 2) {
            return false;
        }

        // 1) Let the bottom of the chain write into the extracted slice instead of the full target.
        Hop target = bottom.getInput().get(0);
        IndexingOp slice = createSliceIndexing("tmp1", target, indexExpr, byRow);
        HopRewriteUtils.removeChildReference(bottom, target);
        HopRewriteUtils.addChildReference(bottom, slice, 0);

        // 2) Re-index every chain member at 1..1 within the slice, bottom-up so sizes refresh in order.
        for (int i = chain.size() - 1; i >= 0; i--) {
            LeftIndexingOp lix = chain.get(i);
            resetIndexRangeToOne(lix, lowerPos);
            if (byRow) {
                lix.setRowLowerEqualsUpper(true);
            } else {
                lix.setColLowerEqualsUpper(true);
            }
            lix.refreshSizeInformation();
        }

        // 3) Detach all consumers of head. This must happen before the write-back op is created,
        //    because that op becomes a new parent of head. Remember each input position.
        ArrayList<Hop> parents = new ArrayList<Hop>(head.getParent());
        int[] positions = new int[parents.size()];
        for (int i = 0; i < positions.length; i++) {
            Hop parent = parents.get(i);
            positions[i] = HopRewriteUtils.getChildReferencePos(parent, head);
            HopRewriteUtils.removeChildReferenceByPos(parent, head, positions[i]);
        }

        // Then write the updated slice back and reattach the consumers at their original positions.
        LeftIndexingOp writeBack = createSliceLeftIndexing("tmp2", target, head, indexExpr, byRow);
        for (int i = 0; i < positions.length; i++) {
            HopRewriteUtils.addChildReference(parents.get(i), writeBack, positions[i]);
        }
        LOG.debug(byRow ? "Applied vectorizeLeftIndexingRow" : "Applied vectorizeLeftIndexingCol");
        return true;
    }

    /**
     * Creates a right indexing op that extracts row {@code indexExpr} (if {@code byRow}) or
     * column {@code indexExpr} of {@code input} over the full extent of the other dimension.
     *
     * <p>The extent is {@code HopRewriteUtils.createValueHop(input, !byRow)}, presumably ncol for
     * row slices and nrow for column slices; confirm against HopRewriteUtils. The literal is
     * created before the value hop to keep the original hop creation order.
     */
    private static IndexingOp createSliceIndexing(String name, Hop input, Hop indexExpr, boolean byRow)
            throws HopsException {
        Hop one = new LiteralOp(1L);
        Hop extent = HopRewriteUtils.createValueHop(input, !byRow);
        IndexingOp slice = byRow
                ? new IndexingOp(name, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE,
                        input, indexExpr, indexExpr, one, extent, true, false)
                : new IndexingOp(name, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE,
                        input, one, extent, indexExpr, indexExpr, false, true);
        HopRewriteUtils.setOutputParameters(slice, -1L, -1L, input.getRowsInBlock(), input.getColsInBlock(), -1L);
        slice.refreshSizeInformation();
        return slice;
    }

    /**
     * Creates a left indexing op that writes {@code value} into row (if {@code byRow}) or column
     * {@code indexExpr} of {@code target}, over the full extent of the other dimension.
     *
     * <p>Hop creation order matches {@link #createSliceIndexing}.
     */
    private static LeftIndexingOp createSliceLeftIndexing(String name, Hop target, Hop value, Hop indexExpr,
            boolean byRow) throws HopsException {
        Hop one = new LiteralOp(1L);
        Hop extent = HopRewriteUtils.createValueHop(target, !byRow);
        LeftIndexingOp lix = byRow
                ? new LeftIndexingOp(name, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE,
                        target, value, indexExpr, indexExpr, one, extent, true, false)
                : new LeftIndexingOp(name, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE,
                        target, value, one, extent, indexExpr, indexExpr, false, true);
        HopRewriteUtils.setOutputParameters(lix, -1L, -1L, target.getRowsInBlock(), target.getColsInBlock(), -1L);
        lix.refreshSizeInformation();
        return lix;
    }

    /**
     * Replaces the index inputs at {@code lowerPos} and {@code lowerPos + 1} of {@code hop} with
     * fresh literal-1 hops. The caller is responsible for refreshing size information.
     */
    private static void resetIndexRangeToOne(Hop hop, int lowerPos) throws HopsException {
        final int upperPos = lowerPos + 1;
        HopRewriteUtils.removeChildReferenceByPos(hop, hop.getInput().get(lowerPos), lowerPos);
        HopRewriteUtils.addChildReference(hop, new LiteralOp(1L), lowerPos);
        HopRewriteUtils.removeChildReferenceByPos(hop, hop.getInput().get(upperPos), upperPos);
        HopRewriteUtils.addChildReference(hop, new LiteralOp(1L), upperPos);
    }
}
