package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.Operator;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/MatrixAppendMSPInstruction.class */
/**
 * Spark instruction that appends (cbind or rbind) a broadcast matrix to a matrix held as a
 * blocked RDD, performing the append on the map side.
 *
 * <p>The first input is read as an RDD of {@code (MatrixIndexes, MatrixBlock)} pairs and the
 * second input as a {@link PartitionedBroadcast}. Only the blocks of the first input that lie on
 * the append boundary (the last block column for cbind, the last block row for rbind) are
 * combined with the broadcast; all other blocks are passed through unchanged.
 */
public class MatrixAppendMSPInstruction extends AppendMSPInstruction {

    /**
     * Flat-map function used when the append may change the number of blocks along the append
     * dimension. A boundary block may therefore produce one or two output blocks, or the
     * broadcast block may be emitted as a new block next to it.
     */
    private static class MapSideAppendFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2738541014432173450L;

        private final PartitionedBroadcast<MatrixBlock> _pm;
        private final boolean _cbind;
        private final int _brlen;
        private final int _bclen;
        /** 1-based index (column index for cbind, row index for rbind) of the boundary block. */
        private final long _lastBlockIndex;

        /**
         * @param binput broadcast of the second (appended) input
         * @param cbind  {@code true} for column append, {@code false} for row append
         * @param offset size of the first input along the append dimension, used to locate the
         *               boundary block
         * @param brlen  rows per block
         * @param bclen  columns per block
         */
        public MapSideAppendFunction(PartitionedBroadcast<MatrixBlock> binput, boolean cbind, long offset, int brlen, int bclen) {
            _pm = binput;
            _cbind = cbind;
            _brlen = brlen;
            _bclen = bclen;
            _lastBlockIndex = getNumBlocks(offset, cbind ? bclen : brlen);
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            IndexedMatrixValue in1 = SparkUtils.toIndexedMatrixBlock(kv);
            MatrixIndexes ix = in1.getIndexes();

            // Case 1: not a boundary block -> pass through unchanged.
            long blockIx = _cbind ? ix.getColumnIndex() : ix.getRowIndex();
            if (blockIx != _lastBlockIndex) {
                ret.add(kv);
                return ret.iterator();
            }

            // Case 2: the boundary block is full, so the broadcast block starts a new block.
            // Emit the lhs block as-is and the rhs block at the next index.
            boolean boundaryBlockFull = _cbind
                ? in1.getValue().getNumColumns() == _bclen
                : in1.getValue().getNumRows() == _brlen;
            if (boundaryBlockFull) {
                ret.add(kv);
                if (_cbind) {
                    ret.add(new Tuple2<>(
                        new MatrixIndexes(ix.getRowIndex(), ix.getColumnIndex() + 1),
                        _pm.getBlock((int) ix.getRowIndex(), 1)));
                } else {
                    ret.add(new Tuple2<>(
                        new MatrixIndexes(ix.getRowIndex() + 1, ix.getColumnIndex()),
                        _pm.getBlock(1, (int) ix.getColumnIndex())));
                }
                return ret.iterator();
            }

            // Case 3: partial boundary block -> real append. Reserve a second output slot when the
            // combined size exceeds the block size, so performAppend can spill into it.
            ArrayList<IndexedMatrixValue> outlist = new ArrayList<>(2);
            outlist.add(new IndexedMatrixValue(new MatrixIndexes(ix), new MatrixBlock()));
            MatrixBlock in2;
            if (_cbind) {
                in2 = _pm.getBlock((int) ix.getRowIndex(), 1);
                if (in1.getValue().getNumColumns() + in2.getNumColumns() > _bclen) {
                    outlist.add(new IndexedMatrixValue(
                        new MatrixIndexes(ix.getRowIndex(), ix.getColumnIndex() + 1), new MatrixBlock()));
                }
            } else {
                in2 = _pm.getBlock(1, (int) ix.getColumnIndex());
                if (in1.getValue().getNumRows() + in2.getNumRows() > _brlen) {
                    outlist.add(new IndexedMatrixValue(
                        new MatrixIndexes(ix.getRowIndex() + 1, ix.getColumnIndex()), new MatrixBlock()));
                }
            }
            OperationsOnMatrixValues.performAppend(in1.getValue(), in2, outlist, _brlen, _bclen, _cbind, true, 0);
            ret.addAll(SparkUtils.fromIndexedMatrixBlock(outlist));
            return ret.iterator();
        }
    }

    /**
     * Partition-wise function used when the append does not change the number of blocks along the
     * append dimension. Each boundary block is replaced in place by its append with the matching
     * broadcast block, so keys (and therefore partitioning) are preserved.
     */
    public static class MapSideAppendPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 5767240739761027220L;

        private final PartitionedBroadcast<MatrixBlock> _pm;
        private final boolean _cbind;
        /** 1-based index (column index for cbind, row index for rbind) of the boundary block. */
        private final long _lastBlockIndex;

        /** Lazily maps each input block, appending the broadcast block to boundary blocks. */
        public class MapAppendPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public MapAppendPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) {
                super(it);
            }

            @Override
            public Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                MatrixIndexes ix = arg._1();
                MatrixBlock in1 = arg._2();

                // Non-boundary blocks are passed through unchanged.
                long blockIx = _cbind ? ix.getColumnIndex() : ix.getRowIndex();
                if (blockIx != _lastBlockIndex) {
                    return arg;
                }

                // The broadcast has a single block column (cbind) or a single block row (rbind).
                int rowIx = _cbind ? (int) ix.getRowIndex() : 1;
                int colIx = _cbind ? 1 : (int) ix.getColumnIndex();
                MatrixBlock in2 = _pm.getBlock(rowIx, colIx);
                return new Tuple2<>(ix, in1.append(in2, new MatrixBlock(), _cbind));
            }
        }

        /**
         * @param binput broadcast of the second (appended) input
         * @param cbind  {@code true} for column append, {@code false} for row append
         * @param offset size of the first input along the append dimension
         * @param brlen  rows per block
         * @param bclen  columns per block
         */
        public MapSideAppendPartitionFunction(PartitionedBroadcast<MatrixBlock> binput, boolean cbind, long offset, int brlen, int bclen) {
            _pm = binput;
            _cbind = cbind;
            _lastBlockIndex = getNumBlocks(offset, cbind ? bclen : brlen);
        }

        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) throws Exception {
            return new MapAppendPartitionIterator(it);
        }
    }

    public MatrixAppendMSPInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, boolean z, String str, String str2) {
        super(operator, cPOperand, cPOperand2, cPOperand3, cPOperand4, z, str, str2);
    }

    /**
     * Appends the broadcast second input to the RDD of the first input and registers the result
     * (with lineage to both inputs) as this instruction's output.
     */
    @Override // org.apache.sysml.runtime.instructions.spark.SPInstruction, org.apache.sysml.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        checkBinaryAppendInputCharacteristics(sec, _cbind, false, false);
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
        int brlen = mc1.getRowsPerBlock();
        int bclen = mc1.getColsPerBlock();

        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
        PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(input2.getName());
        long offset = sec.getScalarInput(_offset.getName(), _offset.getValueType(), _offset.isLiteral()).getLongValue();

        // If the number of blocks along the append dimension is unchanged, the append can be done
        // in place per partition and the existing partitioning is preserved.
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (preservesPartitioning(mc1, mc2, _cbind)) {
            out = in1.mapPartitionsToPair(new MapSideAppendPartitionFunction(in2, _cbind, offset, brlen, bclen), true);
        } else {
            out = in1.flatMapToPair(new MapSideAppendFunction(in2, _cbind, offset, brlen, bclen));
        }

        updateBinaryAppendOutputMatrixCharacteristics(sec, _cbind);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
        sec.addLineageBroadcast(output.getName(), input2.getName());
    }

    /**
     * Returns whether appending {@code mcIn2} to {@code mcIn1} leaves the number of blocks along
     * the append dimension unchanged. The append dimension is columns for cbind and rows for rbind.
     */
    private static boolean preservesPartitioning(MatrixCharacteristics mcIn1, MatrixCharacteristics mcIn2, boolean cbind) {
        long numBlocksIn1 = cbind ? mcIn1.getNumColBlocks() : mcIn1.getNumRowBlocks();
        long numBlocksOut = cbind
            ? getNumBlocks(mcIn1.getCols() + mcIn2.getCols(), mcIn1.getColsPerBlock())
            : getNumBlocks(mcIn1.getRows() + mcIn2.getRows(), mcIn1.getRowsPerBlock());
        return numBlocksIn1 == numBlocksOut;
    }

    /**
     * Returns {@code ceil(len / blen)}, but at least 1. This is both the number of blocks covering
     * {@code len} cells and the 1-based index of the last such block.
     *
     * <p>The division must be done in floating point: integer (long / int) division would
     * truncate before {@code Math.ceil} is applied and under-count partial blocks.
     */
    private static long getNumBlocks(long len, int blen) {
        return Math.max((long) Math.ceil((double) len / blen), 1L);
    }
}
