package org.apache.sysml.runtime.transform;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.ByteWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.matrix.CSVReblockMR;
import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/transform/GenTfMtdSPARK.class */
public class GenTfMtdSPARK {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/transform/GenTfMtdSPARK$GenTfMtdMap.class */
    public static class GenTfMtdMap implements Function2<Integer, Iterator<Tuple2<LongWritable, Text>>, Iterator<Tuple2<Integer, DistinctValue>>> {
        private static final long serialVersionUID = -5622745445470598215L;
        private TfUtils _agents;

        public GenTfMtdMap(boolean z, String str, String str2, String str3, long j, String str4) throws IllegalArgumentException, IOException, JSONException {
            this._agents = null;
            this._agents = new TfUtils(str4, z, str, TfUtils.parseNAStrings(str2), new JSONObject(str3), j, null, null, null);
        }

        public Iterator<Tuple2<Integer, DistinctValue>> call(Integer num, Iterator<Tuple2<LongWritable, Text>> it) throws Exception {
            boolean z = true;
            long j = -1;
            while (it.hasNext()) {
                Tuple2<LongWritable, Text> next = it.next();
                if (z) {
                    z = false;
                    j = ((LongWritable) next._1()).get();
                    if (num.intValue() == 0 && this._agents.hasHeader() && j == 0) {
                    }
                }
                this._agents.prepareTfMtd(((Text) next._2()).toString());
            }
            ArrayList<Pair<Integer, DistinctValue>> arrayList = new ArrayList<>();
            this._agents.getMVImputeAgent().mapOutputTransformationMetadata(num.intValue(), arrayList, this._agents);
            this._agents.getRecodeAgent().mapOutputTransformationMetadata(num.intValue(), arrayList, this._agents);
            this._agents.getBinAgent().mapOutputTransformationMetadata(num.intValue(), arrayList, this._agents);
            arrayList.add(new Pair<>(Integer.valueOf((int) (this._agents.getNumCols() + 1)), new DistinctValue(new CSVReblockMR.OffsetCount("Partition" + num, j, this._agents.getTotal()))));
            return GenTfMtdSPARK.toTuple2List(arrayList).iterator();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/transform/GenTfMtdSPARK$GenTfMtdReduce.class */
    public static class GenTfMtdReduce implements FlatMapFunction<Tuple2<Integer, Iterable<DistinctValue>>, Long> {
        private static final long serialVersionUID = -2733233671193035242L;
        private TfUtils _agents;

        public GenTfMtdReduce(boolean z, String str, String str2, String str3, String str4, String str5, String str6, long j) throws IOException, JSONException {
            this._agents = null;
            this._agents = new TfUtils(str3, z, str, TfUtils.parseNAStrings(str2), new JSONObject(str6), j, str4, str5, null);
        }

        public Iterable<Long> call(Tuple2<Integer, Iterable<DistinctValue>> tuple2) throws Exception {
            int intValue = ((Integer) tuple2._1()).intValue();
            Iterator<DistinctValue> it = ((Iterable) tuple2._2()).iterator();
            JobConf jobConf = new JobConf();
            FileSystem fileSystem = FileSystem.get(jobConf);
            ArrayList arrayList = new ArrayList();
            if (intValue < 0) {
                this._agents.getMVImputeAgent().mergeAndOutputTransformationMetadata(it, this._agents.getTfMtdDir(), intValue * (-1), fileSystem, this._agents);
                arrayList.add(0L);
            } else if (intValue == this._agents.getNumCols() + 1) {
                ArrayList arrayList2 = new ArrayList();
                while (it.hasNext()) {
                    arrayList2.add(new CSVReblockMR.OffsetCount(it.next().getOffsetCount()));
                }
                Collections.sort(arrayList2);
                SequenceFile.Writer writer = new SequenceFile.Writer(fileSystem, jobConf, new Path(this._agents.getOffsetFile() + "/part-00000"), ByteWritable.class, CSVReblockMR.OffsetCount.class);
                long j = 0;
                Iterator it2 = arrayList2.iterator();
                while (it2.hasNext()) {
                    CSVReblockMR.OffsetCount offsetCount = (CSVReblockMR.OffsetCount) it2.next();
                    long j2 = offsetCount.count;
                    offsetCount.count = j;
                    writer.append(new ByteWritable((byte) 0), offsetCount);
                    j += j2;
                }
                writer.close();
                arrayList2.clear();
                arrayList.add(Long.valueOf(j));
            } else {
                this._agents.getRecodeAgent().mergeAndOutputTransformationMetadata(it, this._agents.getTfMtdDir(), intValue, fileSystem, this._agents);
                arrayList.add(0L);
            }
            return arrayList;
        }
    }

    /**
     * Runs the metadata-generation pipeline: a per-partition map ({@link GenTfMtdMap}), a
     * group-by-key, a per-key reduce ({@link GenTfMtdReduce}) and a final summation of the
     * {@code Long} values the reduce step emits.
     *
     * @param sparkExecutionContext not used by this method
     * @param javaRDD                 input records as (key, line) pairs
     * @param str                     forwarded only to the reduce step (presumably the metadata output directory - TODO confirm)
     * @param str2                    transformation specification; parsed as JSON by both steps
     * @param str3                    forwarded only to the reduce step (presumably the partition-offsets file - TODO confirm)
     * @param cSVFileFormatProperties supplies the header flag, delimiter and NA strings
     * @param j                       forwarded to both steps (presumably the column count - TODO confirm)
     * @param str4                    forwarded to both steps (presumably the header line - TODO confirm)
     * @return the sum of all values emitted by the reduce step
     * @throws IOException   propagated from constructing the map/reduce functions
     * @throws JSONException if {@code str2} is not valid JSON
     */
    public static long runSparkJob(SparkExecutionContext sparkExecutionContext, JavaRDD<Tuple2<LongWritable, Text>> javaRDD, String str, String str2, String str3, CSVFileFormatProperties cSVFileFormatProperties, long j, String str4) throws IOException, ClassNotFoundException, InterruptedException, IllegalArgumentException, JSONException {
        // Stage 1: per-partition partial metadata (partitioning is preserved).
        GenTfMtdMap mapper = new GenTfMtdMap(cSVFileFormatProperties.hasHeader(), cSVFileFormatProperties.getDelim(), cSVFileFormatProperties.getNAStrings(), str2, j, str4);
        JavaRDD<Tuple2<Integer, DistinctValue>> partials = javaRDD.mapPartitionsWithIndex(mapper, true);

        // Stage 2: bring together all values that share a key.
        JavaPairRDD<Integer, Iterable<DistinctValue>> grouped = JavaPairRDD.fromJavaRDD(partials).groupByKey();

        // Stage 3: merge and write metadata per key; each key yields one Long.
        GenTfMtdReduce reducer = new GenTfMtdReduce(cSVFileFormatProperties.hasHeader(), cSVFileFormatProperties.getDelim(), cSVFileFormatProperties.getNAStrings(), str4, str, str3, str2, j);
        JavaRDD<Long> perKeyCounts = grouped.flatMap(reducer);

        // Stage 4: add up the per-key values.
        Long total = perKeyCounts.reduce(new Function2<Long, Long, Long>() {
            private static final long serialVersionUID = 1263336168859959795L;

            public Long call(Long left, Long right) throws Exception {
                return Long.valueOf(left.longValue() + right.longValue());
            }
        });
        return total.longValue();
    }

    /**
     * Converts a list of {@link Pair} entries into a list of Scala {@link Tuple2} entries,
     * preserving order.
     *
     * @param list the pairs to convert; must not be {@code null}
     * @return a new mutable list holding one tuple per input pair, built from the pair's key and value
     */
    public static List<Tuple2<Integer, DistinctValue>> toTuple2List(List<Pair<Integer, DistinctValue>> list) {
        // Fully parameterized (no raw types) and presized, since the output length is known up front.
        List<Tuple2<Integer, DistinctValue>> tuples = new ArrayList<Tuple2<Integer, DistinctValue>>(list.size());
        for (Pair<Integer, DistinctValue> pair : list) {
            tuples.add(new Tuple2<Integer, DistinctValue>(pair.getKey(), pair.getValue()));
        }
        return tuples;
    }
}
