package org.apache.sysml.runtime.instructions.spark.functions;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.class */
/**
 * Base class for Spark functions that take a pair of row-aligned blocks (a group block and a
 * target block) and emit {@link WeightedCell}s keyed by (group row, global column index).
 *
 * <p>Subclasses differ only in how the group block is obtained: {@link ExtractGroupBroadcast}
 * looks it up in a {@link PartitionedBroadcastMatrix}, {@link ExtractGroupJoin} receives it
 * as the left side of a joined tuple.
 */
public abstract class ExtractGroup implements Serializable {
    private static final long serialVersionUID = -7059358143841229966L;

    /** Column block size; multiplied with the 0-based block column index to get a global column offset. */
    protected long _bclen;

    /** Number of groups; block-local pre-aggregation is only attempted when this is {@code > 0}. */
    protected long _ngroups;

    /** Operator handed to the grouped aggregation; pre-aggregation requires an {@link AggregateOperator}. */
    protected Operator _op;

    /**
     * @param bclen   column block size of the target matrix
     * @param ngroups number of groups (values {@code <= 0} disable pre-aggregation)
     * @param op      operator used for block-local grouped aggregation
     */
    public ExtractGroup(long bclen, long ngroups, Operator op) {
        _bclen = bclen;
        _ngroups = ngroups;
        _op = op;
    }

    /**
     * Produces the (output index, cell) pairs for one pair of aligned group/target blocks.
     * Every emitted cell has weight 1.
     *
     * <p>If {@code _op} is an {@link AggregateOperator}, {@code _ngroups > 0} and
     * {@code OptimizerUtils.isValidCPDimensions(_ngroups, target.getNumColumns())} holds, the
     * target is aggregated per group inside this block and only non-zero results are emitted.
     * Otherwise every target value is emitted under the group id read from column 0 of the
     * corresponding group row.
     *
     * @param ix     index of the target block; its column index determines the global column offset
     * @param group  group block, row-aligned with {@code target}
     * @param target block of values to be grouped
     * @return the emitted (output index, cell) pairs
     * @throws Exception if the blocks have different row counts, or (without pre-aggregation)
     *                   if a group id is smaller than 1
     */
    protected Iterable<Tuple2<MatrixIndexes, WeightedCell>> execute(MatrixIndexes ix, MatrixBlock group,
            MatrixBlock target) throws Exception {
        if (group.getNumRows() != target.getNumRows()) {
            throw new Exception("The blocksize for group and target blocks are mismatched: "
                    + group.getNumRows() + " != " + target.getNumRows());
        }

        // Global column offset of this target block's first column (0-based).
        long coloff = (ix.getColumnIndex() - 1) * _bclen;

        boolean preAggregate = _op instanceof AggregateOperator && _ngroups > 0
                && OptimizerUtils.isValidCPDimensions(_ngroups, target.getNumColumns());

        return preAggregate
                ? aggregateLocally(group, target, coloff)
                : expandCells(group, target, coloff);
    }

    /** Aggregates {@code target} per group within this block and emits only non-zero results. */
    private List<Tuple2<MatrixIndexes, WeightedCell>> aggregateLocally(MatrixBlock group, MatrixBlock target,
            long coloff) throws Exception {
        // NOTE(review): the int cast presumably relies on isValidCPDimensions bounding _ngroups - confirm.
        MatrixBlock agg = group.groupedAggOperations(target, null, new MatrixBlock(), (int) _ngroups, _op);

        List<Tuple2<MatrixIndexes, WeightedCell>> out = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>();
        for (int i = 0; i < agg.getNumRows(); i++) {
            for (int j = 0; j < agg.getNumColumns(); j++) {
                double val = agg.quickGetValue(i, j);
                // Zero results are not emitted (sparse output).
                if (val != 0) {
                    out.add(new Tuple2<MatrixIndexes, WeightedCell>(
                            new MatrixIndexes(i + 1, coloff + j + 1), newCell(val)));
                }
            }
        }
        return out;
    }

    /** Emits every target value keyed by the group id stored in column 0 of the same group row. */
    private List<Tuple2<MatrixIndexes, WeightedCell>> expandCells(MatrixBlock group, MatrixBlock target,
            long coloff) throws Exception {
        List<Tuple2<MatrixIndexes, WeightedCell>> out = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>();
        for (int i = 0; i < group.getNumRows(); i++) {
            long groupId = UtilFunctions.toLong(group.quickGetValue(i, 0));
            if (groupId < 1) {
                throw new Exception("Expected group values to be greater than equal to 1 but found " + groupId);
            }
            for (int j = 0; j < target.getNumColumns(); j++) {
                out.add(new Tuple2<MatrixIndexes, WeightedCell>(
                        new MatrixIndexes(groupId, coloff + j + 1), newCell(target.quickGetValue(i, j))));
            }
        }
        return out;
    }

    /** Creates a cell holding {@code value} with weight 1. */
    private static WeightedCell newCell(double value) {
        WeightedCell cell = new WeightedCell();
        cell.setValue(value);
        cell.setWeight(1.0);
        return cell;
    }

    /**
     * Variant whose group block is fetched from a broadcast matrix at the target block's row
     * index and block column 1.
     */
    public static class ExtractGroupBroadcast extends ExtractGroup
            implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, WeightedCell> {
        private static final long serialVersionUID = 5709955602290131093L;

        private final PartitionedBroadcastMatrix _pbm;

        /**
         * @param pbm     broadcast holding the group blocks
         * @param bclen   column block size of the target matrix
         * @param ngroups number of groups (values {@code <= 0} disable pre-aggregation)
         * @param op      operator used for block-local grouped aggregation
         */
        public ExtractGroupBroadcast(PartitionedBroadcastMatrix pbm, long bclen, long ngroups, Operator op) {
            super(bclen, ngroups, op);
            _pbm = pbm;
        }

        @Override
        public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
                throws Exception {
            MatrixIndexes ix = arg0._1();
            // presumably the group vector has a single block column - confirm against the caller
            MatrixBlock group = _pbm.getMatrixBlock((int) ix.getRowIndex(), 1);
            return execute(ix, group, arg0._2());
        }
    }

    /** Variant that receives (index, (group block, target block)) tuples from a join. */
    public static class ExtractGroupJoin extends ExtractGroup
            implements PairFlatMapFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, WeightedCell> {
        private static final long serialVersionUID = 8890978615936560266L;

        /**
         * @param bclen   column block size of the target matrix
         * @param ngroups number of groups (values {@code <= 0} disable pre-aggregation)
         * @param op      operator used for block-local grouped aggregation
         */
        public ExtractGroupJoin(long bclen, long ngroups, Operator op) {
            super(bclen, ngroups, op);
        }

        @Override
        public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
                Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) throws Exception {
            Tuple2<MatrixBlock, MatrixBlock> blocks = arg0._2();
            return execute(arg0._1(), blocks._1(), blocks._2());
        }
    }
}
