package org.apache.sysml.runtime.controlprogram.parfor;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.hadoop.io.Writable;
import org.apache.spark.Accumulator;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.parfor.Task;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler;
import org.apache.sysml.runtime.controlprogram.parfor.util.PairWritableBlock;
import org.apache.sysml.runtime.controlprogram.parfor.util.PairWritableCell;
import org.apache.sysml.runtime.instructions.cp.IntObject;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.util.LocalFileUtils;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSparkWorker.class */
/**
 * Spark function that fuses data partitioning with parfor execution.
 *
 * <p>For every incoming {@code (key, values)} group, the values are assembled into a single
 * {@link MatrixBlock}, which is installed as the in-memory partition of the parfor input
 * variable. One parfor iteration is then executed with the iteration variable bound to the
 * group key, and the exported result variables are emitted as {@code (workerID, value)} pairs.
 *
 * <p>Only {@code ROW_WISE} and {@code COLUMN_WISE} partition formats are accepted; any other
 * format is rejected in the constructor.
 */
public class RemoteDPParForSparkWorker extends ParWorker implements PairFlatMapFunction<Iterator<Tuple2<Long, Iterable<Writable>>>, Long, String> {
    private static final long serialVersionUID = 30223759283155139L;

    /** Serialized parfor body, parsed on the executor in {@link #configureWorker(long)}. */
    private String _prog;
    /** When false, caching is disabled at the end of worker configuration. */
    private boolean _caching;
    /** Name of the matrix variable that receives the collected in-memory partition. */
    private String _inputVar;
    /** Name of the parfor iteration variable, bound to the group key for each task. */
    private String _iterVar;
    /** Format of the incoming values: binary block, otherwise handled as binary cell. */
    private OutputInfo _oinfo;
    /** Rows of a single partition (forced to 1 for ROW_WISE). */
    private int _rlen;
    /** Columns of a single partition (forced to 1 for COLUMN_WISE). */
    private int _clen;
    /** Rows per block, used to turn block row indexes into row offsets. */
    private int _brlen;
    /** Columns per block, used to turn block column indexes into column offsets. */
    private int _bclen;
    /** When true, the partition is allocated as a sparse block with swapped dimensions. */
    private boolean _tSparseCol;
    private ParForProgramBlock.PDataPartitionFormat _dpf;
    /** Accumulator incremented by one per executed task. */
    private Accumulator<Integer> _aTasks;
    /** Accumulator incremented by the number of iterations executed per task. */
    private Accumulator<Integer> _aIters;

    /**
     * Creates the worker and derives the per-partition dimensions from the partition format.
     *
     * @param program serialized parfor body
     * @param inputVar name of the partitioned input matrix variable
     * @param iterVar name of the parfor iteration variable
     * @param cachingEnabled whether caching stays enabled on the worker
     * @param mc characteristics (dimensions and block sizes) of the full input matrix
     * @param tSparseCol whether partitions are collected as transposed sparse blocks
     * @param dpf partition format; must be ROW_WISE or COLUMN_WISE
     * @param oinfo format of the incoming partition values
     * @param aTasks accumulator for the number of executed tasks
     * @param aIters accumulator for the number of executed iterations
     * @throws RuntimeException if {@code dpf} is neither ROW_WISE nor COLUMN_WISE
     */
    public RemoteDPParForSparkWorker(String program, String inputVar, String iterVar, boolean cachingEnabled, MatrixCharacteristics mc, boolean tSparseCol, ParForProgramBlock.PDataPartitionFormat dpf, OutputInfo oinfo, Accumulator<Integer> aTasks, Accumulator<Integer> aIters) throws DMLRuntimeException, DMLUnsupportedOperationException {
        _prog = program;
        _caching = cachingEnabled;
        _inputVar = inputVar;
        _iterVar = iterVar;
        _oinfo = oinfo;
        _aTasks = aTasks;
        _aIters = aIters;
        _rlen = (int) mc.getRows();
        _clen = (int) mc.getCols();
        _brlen = mc.getRowsPerBlock();
        _bclen = mc.getColsPerBlock();
        _tSparseCol = tSparseCol;
        _dpf = dpf;

        // A partition is a single row or a single column of the input matrix.
        switch (_dpf) {
            case ROW_WISE:
                _rlen = 1;
                break;
            case COLUMN_WISE:
                _clen = 1;
                break;
            default:
                throw new RuntimeException("Partition format not yet supported in fused partition-execute: " + dpf);
        }
    }

    /**
     * Processes all groups handed to this task: collects each group into a matrix partition,
     * runs one parfor iteration on it and gathers the exported result variables.
     *
     * @param it iterator over (iteration value, partition values) groups
     * @return one (workerID, exported result) pair per exported variable and processed group
     * @throws Exception if worker configuration, partition collection or execution fails
     */
    public Iterable<Tuple2<Long, String>> call(Iterator<Tuple2<Long, Iterable<Writable>>> it) throws Exception {
        ArrayList<Tuple2<Long, String>> results = new ArrayList<Tuple2<Long, String>>();

        // The Spark task attempt id doubles as the worker id.
        configureWorker(TaskContext.get().taskAttemptId());

        while (it.hasNext()) {
            Tuple2<Long, Iterable<Writable>> group = it.next();

            // Materialize the group's values as the in-memory partition of the input variable.
            MatrixBlock partition;
            if (_oinfo.equals(OutputInfo.BinaryBlockOutputInfo)) {
                partition = collectBinaryBlock(group._2());
            } else {
                partition = collectBinaryCellInput(group._2());
            }
            MatrixObject input = (MatrixObject) _ec.getVariable(_inputVar);
            input.setInMemoryPartition(partition);

            // Single-iteration task with the iteration variable bound to the group key.
            Task task = new Task(Task.TaskType.SET);
            task.addIteration(new IntObject(_iterVar, group._1().longValue()));

            long itersBefore = getExecutedIterations();
            super.executeTask(task);

            _aTasks.add(1);
            _aIters.add((int) (getExecutedIterations() - itersBefore));

            Iterator<String> exported = RemoteParForUtils.exportResultVariables(_workerID, _ec.getVariables(), _resultVars).iterator();
            while (exported.hasNext()) {
                results.add(new Tuple2<Long, String>(Long.valueOf(_workerID), exported.next()));
            }
        }
        return results;
    }

    /**
     * Parses the parfor body and initializes worker state and caching for the given worker id.
     *
     * @param workerID id assigned to this worker
     */
    private void configureWorker(long workerID) throws DMLRuntimeException, DMLUnsupportedOperationException, IOException {
        _workerID = workerID;

        ParForBody body = ProgramConverter.parseParForBody(_prog, (int) _workerID);
        _childBlocks = body.getChildBlocks();
        _ec = body.getEc();
        _resultVars = body.getResultVarNames();
        _numTasks = 0L;
        _numIters = 0L;

        // Set up caching with a fresh working directory if it is not active yet.
        if (!CacheableData.isCachingActive()) {
            String uuid = IDHandler.createDistributedUniqueID();
            LocalFileUtils.createWorkingDirectoryWithUUID(uuid);
            CacheableData.initCaching(uuid);
        }

        // Append the worker id to the eviction file prefix unless it already contains a "_".
        if (!CacheableData.cacheEvictionLocalFilePrefix.contains("_")) {
            CacheableData.cacheEvictionLocalFilePrefix += "_" + _workerID;
        }

        super.pinResultVariables();

        if (!_caching) {
            CacheableData.disableCaching();
        }
    }

    /**
     * Assembles a partition from binary-block values by placing each block at the offset
     * derived from its 1-based block indexes and the configured block sizes.
     *
     * @param valueList {@link PairWritableBlock} values of one partition
     * @return the collected partition
     * @throws IOException wrapping any {@link DMLRuntimeException} raised while copying
     */
    private MatrixBlock collectBinaryBlock(Iterable<Writable> valueList) throws IOException {
        try {
            MatrixBlock partition = newPartitionBlock();
            for (Writable value : valueList) {
                PairWritableBlock pair = (PairWritableBlock) value;
                int rowOffset = ((int) (pair.indexes.getRowIndex() - 1)) * _brlen;
                int colOffset = ((int) (pair.indexes.getColumnIndex() - 1)) * _bclen;
                MatrixBlock block = pair.block;
                if (partition.isInSparseFormat()) {
                    partition.appendToSparse(block, rowOffset, colOffset);
                } else {
                    partition.copy(rowOffset, (rowOffset + block.getNumRows()) - 1, colOffset, (colOffset + block.getNumColumns()) - 1, block, false);
                }
            }
            cleanupCollectedMatrixPartition(partition, partition.isInSparseFormat());
            return partition;
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
    }

    /**
     * Assembles a partition from binary-cell values.
     *
     * <p>Each group is traversed with a single iterator. Requesting a new iterator for every
     * {@code hasNext()}/{@code next()} call would restart at the first element of a
     * re-iterable collection and never terminate on non-empty input.
     *
     * <p>Cells with a negative index in the partitioned dimension are skipped
     * (presumably placeholders for empty partitions; confirm against the producer).
     *
     * @param valueList {@link PairWritableCell} values of one partition
     * @return the collected partition
     * @throws IOException if the partition format is unsupported or cleanup fails
     */
    private MatrixBlock collectBinaryCellInput(Iterable<Writable> valueList) throws IOException {
        MatrixBlock partition = newPartitionBlock();
        switch (_dpf) {
            case ROW_WISE:
                for (Writable value : valueList) {
                    PairWritableCell cell = (PairWritableCell) value;
                    if (cell.indexes.getColumnIndex() >= 0) {
                        partition.quickSetValue(0, ((int) cell.indexes.getColumnIndex()) - 1, cell.cell.getValue());
                    }
                }
                break;
            case COLUMN_WISE:
                for (Writable value : valueList) {
                    PairWritableCell cell = (PairWritableCell) value;
                    if (cell.indexes.getRowIndex() >= 0) {
                        if (_tSparseCol) {
                            // Transposed sparse layout: the column is stored as one sparse row.
                            partition.appendValue(0, ((int) cell.indexes.getRowIndex()) - 1, cell.cell.getValue());
                        } else {
                            partition.quickSetValue(((int) cell.indexes.getRowIndex()) - 1, 0, cell.cell.getValue());
                        }
                    }
                }
                break;
            default:
                throw new IOException("Partition format not yet supported in fused partition-execute: " + _dpf);
        }
        cleanupCollectedMatrixPartition(partition, _tSparseCol);
        return partition;
    }

    /**
     * Allocates an empty partition block: sparse with swapped dimensions when
     * {@code _tSparseCol} is set, dense with the configured dimensions otherwise.
     */
    private MatrixBlock newPartitionBlock() {
        if (_tSparseCol) {
            return new MatrixBlock(_clen, _rlen, true);
        }
        return new MatrixBlock(_rlen, _clen, false);
    }

    /**
     * Finalizes a collected partition: optionally sorts sparse rows, recomputes the non-zero
     * count of dense blocks and re-examines the sparsity of the block.
     *
     * @param partition the collected partition
     * @param sort whether sparse rows should be sorted
     * @throws IOException wrapping any exception raised by the sparsity examination
     */
    private void cleanupCollectedMatrixPartition(MatrixBlock partition, boolean sort) throws IOException {
        if (partition.isInSparseFormat() && sort) {
            partition.sortSparseRows();
        }
        if (!partition.isInSparseFormat()) {
            partition.recomputeNonZeros();
        }
        try {
            partition.examSparsity();
        } catch (Exception e) {
            throw new IOException(e);
        }
    }
}
