package org.apache.sysml.api;

import java.io.IOException;
import java.util.List;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.GetMIMBFromRow;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.slf4j.Marker;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/api/MLMatrix.class */
/**
 * A {@link DataFrame} whose rows hold matrix blocks, paired with the
 * {@link MatrixCharacteristics} describing the matrix and the {@link MLContext}
 * used to evaluate operations on it.
 *
 * <p>Every operation is implemented by generating a small DML script, registering
 * this matrix (and, for binary operations, the other operand) as script inputs on
 * the shared {@code MLContext}, executing the script, and wrapping the
 * {@code "output"} variable in a new {@code MLMatrix}. Each operation calls
 * {@code ml.reset()} first, so inputs/outputs previously registered on that
 * context are discarded.
 *
 * <p>Methods whose names start with {@code $} (for example {@code $plus}) are the
 * JVM encodings of Scala symbolic operators; each has a plainly named Java alias
 * where one is provided (for example {@code add}).
 */
public class MLMatrix extends DataFrame {
    private static final long serialVersionUID = -7005940673916671165L;

    /**
     * Dimensions, block sizes and non-zero count of this matrix. Only the
     * {@code (DataFrame, MatrixCharacteristics, MLContext)} constructor sets it;
     * the other constructors leave it {@code null}, in which case the methods
     * that read it fail with a {@code NullPointerException}.
     */
    protected MatrixCharacteristics mc;

    /** Context on which the generated DML scripts are executed. */
    protected MLContext ml;

    // NOTE(review): the logger is registered under DMLScript's name, not MLMatrix's.
    // Left unchanged because log configuration may depend on it -- confirm intent.
    protected static final Log LOG = LogFactory.getLog(DMLScript.class.getName());

    /**
     * Trailing DML statement appended to the generated scripts: writes the
     * {@code output} variable in binary format using the block size that
     * {@link DMLTranslator#DMLBlockSize} had when this class was initialized.
     */
    static String writeStmt = "write(output, \"tmp\", format=\"binary\", rows_in_block=" + DMLTranslator.DMLBlockSize + ", cols_in_block=" + DMLTranslator.DMLBlockSize + ");";

    /**
     * Creates a matrix view over a logical plan. The matrix characteristics are
     * not set by this constructor.
     *
     * @param sQLContext  SQL context passed to the {@code DataFrame} constructor
     * @param logicalPlan logical plan passed to the {@code DataFrame} constructor
     * @param mLContext   context used to execute operations on this matrix
     */
    protected MLMatrix(SQLContext sQLContext, LogicalPlan logicalPlan, MLContext mLContext) {
        super(sQLContext, logicalPlan);
        this.ml = mLContext;
    }

    /**
     * Creates a matrix view over a query execution. The matrix characteristics
     * are not set by this constructor.
     *
     * @param sQLContext     SQL context passed to the {@code DataFrame} constructor
     * @param queryExecution query execution passed to the {@code DataFrame} constructor
     * @param mLContext      context used to execute operations on this matrix
     */
    protected MLMatrix(SQLContext sQLContext, SQLContext.QueryExecution queryExecution, MLContext mLContext) {
        super(sQLContext, queryExecution);
        this.ml = mLContext;
    }

    /**
     * Wraps an existing data frame, reusing its SQL context and logical plan.
     *
     * @param dataFrame             data frame holding the matrix blocks
     * @param matrixCharacteristics metadata describing the matrix
     * @param mLContext             context used to execute operations on this matrix
     * @throws DMLRuntimeException declared for callers; not raised by the visible body
     */
    public MLMatrix(DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, MLContext mLContext) throws DMLRuntimeException {
        super(dataFrame.sqlContext(), dataFrame.logicalPlan());
        this.mc = matrixCharacteristics;
        this.ml = mLContext;
    }

    /**
     * Builds an {@code MLMatrix} from a binary-block pair RDD by mapping every
     * block through {@link GetMLBlock} and creating a data frame with the
     * default binary-block schema.
     *
     * @param mLContext             context used to execute operations on the result
     * @param sQLContext            SQL context used to create the backing data frame
     * @param javaPairRDD           matrix blocks keyed by their block indexes
     * @param matrixCharacteristics metadata describing the matrix
     * @return the matrix backed by the newly created data frame
     * @throws DMLRuntimeException propagated from the {@code MLMatrix} constructor
     */
    public static MLMatrix createMLMatrix(MLContext mLContext, SQLContext sQLContext, JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, MatrixCharacteristics matrixCharacteristics) throws DMLRuntimeException {
        // Kept as a raw RDD, exactly as before: the row type is not imported here.
        RDD rows = javaPairRDD.map(new GetMLBlock()).rdd();
        StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
        return new MLMatrix(sQLContext.createDataFrame(rows.toJavaRDD(), schema), matrixCharacteristics, mLContext);
    }

    /**
     * Writes this matrix by executing a DML {@code write} statement.
     *
     * <p>NOTE(review): both arguments are spliced into the script text unescaped,
     * so a double quote in either one changes the generated script.
     *
     * @param str  target path placed in the DML {@code write} call
     * @param str2 value of the {@code format} argument of the DML {@code write} call
     */
    public void write(String str, String str2) throws IOException, DMLException, ParseException {
        this.ml.reset();
        this.ml.registerInput("left", this);
        this.ml.executeScript("left = read(\"\"); output=left; write(output, \"" + str + "\", format=\"" + str2 + "\");");
    }

    /**
     * Evaluates the scalar builtin {@code nrow} or {@code ncol} on this matrix by
     * running a script that stores the value in a 1x1 output matrix.
     *
     * @param fn builtin name; only {@code "nrow"} and {@code "ncol"} are accepted
     * @return the single value of the 1x1 result
     * @throws DMLRuntimeException if {@code fn} is unsupported or the result does
     *                             not consist of exactly one block
     */
    private double getScalarBuiltinFunctionResult(String fn) throws IOException, DMLException, ParseException {
        if (!"nrow".equals(fn) && !"ncol".equals(fn)) {
            throw new DMLRuntimeException("The function " + fn + " is not yet supported in MLMatrix");
        }
        this.ml.reset();
        this.ml.registerInput("left", getRDDLazily(this), this.mc.getRows(), this.mc.getCols(), this.mc.getRowsPerBlock(), this.mc.getColsPerBlock(), this.mc.getNonZeros());
        this.ml.registerOutput("output");
        String script = "left = read(\"\");val = " + fn + "(left); output = matrix(val, rows=1, cols=1); " + writeStmt;
        List<Tuple2<MatrixIndexes, MatrixBlock>> blocks = this.ml.executeScript(script).getBinaryBlockedRDD("output").collect();
        if (blocks == null || blocks.size() != 1) {
            throw new DMLRuntimeException("Error while computing the function: " + fn);
        }
        return blocks.get(0)._2.getValue(0, 0);
    }

    /**
     * Returns the number of rows, taken from the matrix characteristics when
     * known and otherwise computed by running {@code nrow} as a script.
     */
    public long numRows() throws IOException, DMLException, ParseException {
        if (this.mc.rowsKnown()) {
            return this.mc.getRows();
        }
        return (long) getScalarBuiltinFunctionResult("nrow");
    }

    /**
     * Returns the number of columns, taken from the matrix characteristics when
     * known and otherwise computed by running {@code ncol} as a script.
     */
    public long numCols() throws IOException, DMLException, ParseException {
        if (this.mc.colsKnown()) {
            return this.mc.getCols();
        }
        return (long) getScalarBuiltinFunctionResult("ncol");
    }

    /** Returns the rows-per-block value recorded in the matrix characteristics. */
    public int rowsPerBlock() {
        return this.mc.getRowsPerBlock();
    }

    /** Returns the columns-per-block value recorded in the matrix characteristics. */
    public int colsPerBlock() {
        return this.mc.getColsPerBlock();
    }

    /** Builds the script {@code output = left <op> right} followed by the write statement. */
    private String getScript(String op) {
        return "left = read(\"\");right = read(\"\");output = left " + op + " right; " + writeStmt;
    }

    /**
     * Builds the script for a matrix/scalar operation.
     *
     * @param op           DML operator
     * @param scalar       scalar operand, rendered with Java's default double formatting
     * @param scalarOnLeft {@code true} for {@code scalar <op> left}, {@code false}
     *                     for {@code left <op> scalar}
     */
    private String getScalarBinaryScript(String op, double scalar, boolean scalarOnLeft) {
        if (scalarOnLeft) {
            return "left = read(\"\");output = " + scalar + " " + op + " left ;" + writeStmt;
        }
        return "left = read(\"\");output = left " + op + " " + scalar + ";" + writeStmt;
    }

    /**
     * Converts the rows of the given matrix back into a pair RDD of block
     * indexes and blocks via {@link GetMIMBFromRow}.
     *
     * @param mLMatrix matrix whose rows are converted
     * @return pair RDD of matrix indexes and matrix blocks
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDLazily(MLMatrix mLMatrix) {
        return mLMatrix.rdd().toJavaRDD().mapToPair(new GetMIMBFromRow());
    }

    /** Wraps the {@code "output"} variable of an executed script in a new {@code MLMatrix}. */
    private MLMatrix wrapOutput(MLOutput out) throws IOException, DMLException, ParseException {
        return createMLMatrix(this.ml, sqlContext(), out.getBinaryBlockedRDD("output"), out.getMatrixCharacteristics("output"));
    }

    /** Runs a script that reads only this matrix (as {@code left}) and wraps its output. */
    private MLMatrix runUnaryScript(String script) throws IOException, DMLException, ParseException {
        this.ml.reset();
        this.ml.registerInput("left", this);
        this.ml.registerOutput("output");
        return wrapOutput(this.ml.executeScript(script));
    }

    /**
     * Verifies that block sizes match and that the dimensions fit the operator:
     * inner dimensions for {@code %*%}, identical shapes for everything else.
     */
    private void checkCompatible(MLMatrix right, String op) throws DMLRuntimeException {
        MatrixCharacteristics lhs = this.mc;
        MatrixCharacteristics rhs = right.mc;
        if (lhs.getRowsPerBlock() != rhs.getRowsPerBlock() || lhs.getColsPerBlock() != rhs.getColsPerBlock()) {
            throw new DMLRuntimeException("Incompatible block sizes: brlen:" + lhs.getRowsPerBlock() + "!=" + rhs.getRowsPerBlock() + " || bclen:" + lhs.getColsPerBlock() + "!=" + rhs.getColsPerBlock());
        }
        // A dimension is only compared when both sides know it. Comparing an
        // unknown dimension against a known one would report a mismatch that
        // may not exist; in that case validation is left to script execution.
        if ("%*%".equals(op)) {
            boolean innerKnown = lhs.colsKnown() && rhs.rowsKnown();
            if (innerKnown && lhs.getCols() != rhs.getRows()) {
                throw new DMLRuntimeException("Dimensions mismatch:" + lhs.getCols() + "!=" + rhs.getRows());
            }
            return;
        }
        boolean rowsDiffer = lhs.rowsKnown() && rhs.rowsKnown() && lhs.getRows() != rhs.getRows();
        boolean colsDiffer = lhs.colsKnown() && rhs.colsKnown() && lhs.getCols() != rhs.getCols();
        if (rowsDiffer || colsDiffer) {
            throw new DMLRuntimeException("Dimensions mismatch:" + lhs.getRows() + "!=" + rhs.getRows() + " || " + lhs.getCols() + "!=" + rhs.getCols());
        }
    }

    /**
     * Executes {@code this <op> right} as a DML script.
     *
     * @param right right-hand operand
     * @param op    DML binary operator
     * @return the result matrix
     * @throws NullPointerException if {@code right} is {@code null}
     * @throws DMLRuntimeException  if block sizes or known dimensions are incompatible
     */
    private MLMatrix matrixBinaryOp(MLMatrix right, String op) throws IOException, DMLException, ParseException {
        if (right == null) {
            throw new NullPointerException("right-hand MLMatrix operand must not be null");
        }
        checkCompatible(right, op);
        this.ml.reset();
        this.ml.registerInput("left", this);
        this.ml.registerInput("right", right);
        this.ml.registerOutput("output");
        return wrapOutput(this.ml.executeScript(getScript(op)));
    }

    /**
     * Executes a matrix/scalar operation as a DML script.
     *
     * @param scalar       scalar operand
     * @param op           DML binary operator
     * @param scalarOnLeft whether the scalar is the left-hand operand
     * @return the result matrix
     * @throws NullPointerException if {@code scalar} is {@code null}
     */
    private MLMatrix scalarBinaryOp(Double scalar, String op, boolean scalarOnLeft) throws IOException, DMLException, ParseException {
        // Fail before touching the shared MLContext rather than after reset().
        if (scalar == null) {
            throw new NullPointerException("scalar operand must not be null");
        }
        return runUnaryScript(getScalarBinaryScript(op, scalar.doubleValue(), scalarOnLeft));
    }

    /** Element-wise {@code this > mLMatrix}. */
    public MLMatrix $greater(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, ">");
    }

    /** Element-wise {@code this < mLMatrix}. */
    public MLMatrix $less(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "<");
    }

    /** Element-wise {@code this >= mLMatrix}. */
    public MLMatrix $greater$eq(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, ">=");
    }

    /** Element-wise {@code this <= mLMatrix}. */
    public MLMatrix $less$eq(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "<=");
    }

    /** Element-wise {@code this == mLMatrix}. */
    public MLMatrix $eq$eq(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "==");
    }

    /** Element-wise {@code this != mLMatrix}. */
    public MLMatrix $bang$eq(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "!=");
    }

    /** Applies the DML {@code ^} operator with {@code mLMatrix} as the right operand. */
    public MLMatrix $up(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "^");
    }

    /** Alias of {@link #$up(MLMatrix)}: applies the DML {@code ^} operator. */
    public MLMatrix exp(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "^");
    }

    /** Computes {@code this + mLMatrix}. */
    public MLMatrix $plus(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "+");
    }

    /** Alias of {@link #$plus(MLMatrix)}. */
    public MLMatrix add(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "+");
    }

    /** Computes {@code this - mLMatrix}. */
    public MLMatrix $minus(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "-");
    }

    /** Alias of {@link #$minus(MLMatrix)}. */
    public MLMatrix minus(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "-");
    }

    /** Applies the DML {@code *} operator with {@code mLMatrix} as the right operand. */
    public MLMatrix $times(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "*");
    }

    /** Alias of {@link #$times(MLMatrix)}. */
    public MLMatrix elementWiseMultiply(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "*");
    }

    /** Applies the DML {@code /} operator with {@code mLMatrix} as the right operand. */
    public MLMatrix $div(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "/");
    }

    /** Alias of {@link #$div(MLMatrix)}. */
    public MLMatrix divide(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "/");
    }

    /** Applies the DML {@code %/%} operator with {@code mLMatrix} as the right operand. */
    public MLMatrix $percent$div$percent(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "%/%");
    }

    /** Alias of {@link #$percent$div$percent(MLMatrix)}. */
    public MLMatrix integerDivision(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "%/%");
    }

    /** Applies the DML {@code %%} operator with {@code mLMatrix} as the right operand. */
    public MLMatrix $percent$percent(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "%%");
    }

    /** Alias of {@link #$percent$percent(MLMatrix)}. */
    public MLMatrix modulus(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "%%");
    }

    /**
     * Applies the DML {@code %*%} operator; when both are known, the column count
     * of this matrix must equal the row count of {@code mLMatrix}.
     */
    public MLMatrix $percent$times$percent(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "%*%");
    }

    /** Alias of {@link #$percent$times$percent(MLMatrix)}. */
    public MLMatrix multiply(MLMatrix mLMatrix) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(mLMatrix, "%*%");
    }

    /** Returns the result of the DML builtin {@code t(...)} applied to this matrix. */
    public MLMatrix transpose() throws IOException, DMLException, ParseException {
        return runUnaryScript("left = read(\"\");output = t(left); " + writeStmt);
    }

    /** Computes {@code this + d}. */
    public MLMatrix $plus(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "+", false);
    }

    /** Alias of {@link #$plus(Double)}. */
    public MLMatrix add(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "+", false);
    }

    /** Computes {@code this - d}. */
    public MLMatrix $minus(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "-", false);
    }

    /** Alias of {@link #$minus(Double)}. */
    public MLMatrix minus(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "-", false);
    }

    /** Computes {@code this * d}. */
    public MLMatrix $times(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "*", false);
    }

    /** Alias of {@link #$times(Double)}. */
    public MLMatrix elementWiseMultiply(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "*", false);
    }

    /** Computes {@code this / d}. */
    public MLMatrix $div(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "/", false);
    }

    /** Alias of {@link #$div(Double)}. */
    public MLMatrix divide(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "/", false);
    }

    /** Computes {@code this > d}. */
    public MLMatrix $greater(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, ">", false);
    }

    /** Computes {@code this < d}. */
    public MLMatrix $less(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "<", false);
    }

    /** Computes {@code this >= d}. */
    public MLMatrix $greater$eq(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, ">=", false);
    }

    /** Computes {@code this <= d}. */
    public MLMatrix $less$eq(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "<=", false);
    }

    /** Computes {@code this == d}. */
    public MLMatrix $eq$eq(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "==", false);
    }

    /** Computes {@code this != d}. */
    public MLMatrix $bang$eq(Double d) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(d, "!=", false);
    }
}
