package org.apache.sysml.udf.lib;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.codegen.CodegenUtils;
import org.apache.sysml.runtime.codegen.SpoofOperator;
import org.apache.sysml.runtime.compress.utils.IntArrayList;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;

/* loaded from: input_file:org/apache/sysml/udf/lib/RowClassMeet.class */
/**
 * Package function that takes two input matrices (function inputs 0 and 1,
 * called A and B below) and produces two output matrices.
 *
 * <p>For every row, the columns shared by A and B are grouped by the pair
 * (A value, B value), both truncated to {@code int}. Output 0 holds, per cell,
 * a 1-based id of the group the cell belongs to within its row; output 1 holds
 * the number of cells in that group. Cells that belong to no group stay 0.
 * If B has exactly one row, that row is paired with every row of A.
 */
public class RowClassMeet extends PackageFunction {
    private static final long serialVersionUID = 1;

    /** Output 0: per-cell group id (1-based within each row). */
    private Matrix CMat;

    /** Output 1: per-cell size of the group the cell belongs to. */
    private Matrix NMat;

    /**
     * Hash-map key made of the (A value, B value) pair of a cell.
     * Fields are final because instances are used as map keys.
     */
    private static class ClassLabel {
        public final int aVal;
        public final int bVal;

        public ClassLabel(int aVal, int bVal) {
            this.aVal = aVal;
            this.bVal = bVal;
        }

        @Override
        public int hashCode() {
            return UtilFunctions.intHashCode(this.aVal, this.bVal);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof ClassLabel)) {
                return false;
            }
            ClassLabel other = (ClassLabel) obj;
            return this.aVal == other.aVal && this.bVal == other.bVal;
        }
    }

    /**
     * @return the number of outputs of this function, which is always 2
     */
    @Override
    public int getNumFunctionOutputs() {
        return 2;
    }

    /**
     * Returns one of the two outputs computed by {@link #execute()}.
     *
     * @param i output position: 0 for the group-id matrix, 1 for the group-size matrix
     * @return the requested output
     * @throws RuntimeException if {@code i} is neither 0 nor 1
     */
    @Override
    public FunctionParameter getFunctionOutput(int i) {
        switch (i) {
            case 0:
                return this.CMat;
            case 1:
                return this.NMat;
            default:
                throw new RuntimeException("RowClassMeet produces only two outputs");
        }
    }

    /**
     * Reads both inputs, computes the per-row groups and stores the two result
     * matrices so they can be fetched through {@link #getFunctionOutput(int)}.
     *
     * @throws RuntimeException wrapping any {@link IOException} or
     *         {@link DMLRuntimeException} raised while reading inputs or writing outputs
     */
    @Override
    public void execute() {
        try {
            MatrixBlock a = ((Matrix) getFunctionInput(0)).getMatrixObject().acquireRead();
            MatrixBlock b = ((Matrix) getFunctionInput(1)).getMatrixObject().acquireRead();
            int numRows = Math.max(a.getNumRows(), b.getNumRows());
            int numCols = Math.max(a.getNumColumns(), b.getNumColumns());

            MatrixBlock c = new MatrixBlock(numRows, numCols, false).allocateBlock();
            MatrixBlock n = new MatrixBlock(numRows, numCols, false).allocateBlock();
            double[] cValues = c.getDenseBlockValues();
            double[] nValues = n.getDenseBlockValues();

            SpoofOperator.SideInput bSide = CodegenUtils.createSideInput(b);
            boolean bIsSingleRow = b.getNumRows() == 1;
            // Only columns present in both inputs can form a group.
            int sharedCols = Math.min(a.getNumColumns(), b.getNumColumns());

            Map<ClassLabel, IntArrayList> groups = new HashMap<>();

            // The row index and the dense offset into A advance exactly once per
            // iteration, including when an empty sparse row is skipped via continue.
            for (int row = 0, aOffset = 0; row < a.getNumRows(); row++, aOffset += a.getNumColumns()) {
                groups.clear();
                bSide.reset();
                int bRow = bIsSingleRow ? 0 : row;

                if (a.isInSparseFormat()) {
                    if (a.getSparseBlock() == null || a.getSparseBlock().isEmpty(row)) {
                        continue; // nothing to group in this row
                    }
                    int size = a.getSparseBlock().size(row);
                    int pos = a.getSparseBlock().pos(row);
                    int[] indexes = a.getSparseBlock().indexes(row);
                    double[] values = a.getSparseBlock().values(row);
                    // The loop stops at the first column index outside the shared range.
                    for (int k = pos; k < pos + size && indexes[k] < sharedCols; k++) {
                        int bVal = (int) bSide.getValue(bRow, indexes[k]);
                        // NOTE(review): unlike the dense path, the truncated A value
                        // is not checked against 0 here - confirm this is intended.
                        if (bVal != 0) {
                            appendColumn(groups, new ClassLabel((int) values[k], bVal), indexes[k]);
                        }
                    }
                } else {
                    double[] aValues = a.getDenseBlockValues();
                    if (aValues == null) {
                        break; // no dense values available: nothing more to process
                    }
                    for (int col = 0; col < sharedCols; col++) {
                        int aVal = (int) aValues[aOffset + col];
                        int bVal = (int) bSide.getValue(bRow, col);
                        if (aVal != 0 && bVal != 0) {
                            appendColumn(groups, new ClassLabel(aVal, bVal), col);
                        }
                    }
                }

                // Write the groups of the current row; ids follow map iteration order.
                int outOffset = row * numCols;
                int labelId = 1;
                for (Map.Entry<ClassLabel, IntArrayList> entry : groups.entrySet()) {
                    IntArrayList columns = entry.getValue();
                    int groupSize = columns.size();
                    int[] columnIndexes = columns.extractValues();
                    for (int j = 0; j < groupSize; j++) {
                        nValues[outOffset + columnIndexes[j]] = groupSize;
                        cValues[outOffset + columnIndexes[j]] = labelId;
                    }
                    labelId++;
                }
            }

            ((Matrix) getFunctionInput(0)).getMatrixObject().release();
            ((Matrix) getFunctionInput(1)).getMatrixObject().release();

            c.recomputeNonZeros();
            c.examSparsity();
            this.CMat = new Matrix(createOutputFilePathAndName("TMP"), numRows, numCols, Matrix.ValueType.Double);
            this.CMat.setMatrixDoubleArray(c, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);

            n.recomputeNonZeros();
            n.examSparsity();
            this.NMat = new Matrix(createOutputFilePathAndName("TMP"), numRows, numCols, Matrix.ValueType.Double);
            this.NMat.setMatrixDoubleArray(n, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
        } catch (IOException | DMLRuntimeException e) {
            throw new RuntimeException("Error while executing RowClassMeet", e);
        }
    }

    /** Appends a column index to the list kept for {@code key}, creating the list on first use. */
    private static void appendColumn(Map<ClassLabel, IntArrayList> groups, ClassLabel key, int column) {
        IntArrayList columns = groups.get(key);
        if (columns == null) {
            columns = new IntArrayList();
            groups.put(key, columns);
        }
        columns.appendValue(column);
    }
}
