View Javadoc

1   /**
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.hadoop.hbase.mapreduce.hadoopbackport;
20  
21  import java.io.IOException;
22  import java.util.ArrayList;
23  import java.util.Arrays;
24  import java.util.List;
25  import java.util.Random;
26  
27  import org.apache.commons.logging.Log;
28  import org.apache.commons.logging.LogFactory;
29  
30  import org.apache.hadoop.conf.Configuration;
31  import org.apache.hadoop.conf.Configured;
32  import org.apache.hadoop.fs.FileSystem;
33  import org.apache.hadoop.fs.Path;
34  import org.apache.hadoop.io.NullWritable;
35  import org.apache.hadoop.io.RawComparator;
36  import org.apache.hadoop.io.SequenceFile;
37  import org.apache.hadoop.io.WritableComparable;
38  import org.apache.hadoop.mapreduce.InputFormat;
39  import org.apache.hadoop.mapreduce.InputSplit;
40  import org.apache.hadoop.mapreduce.Job;
41  import org.apache.hadoop.mapreduce.RecordReader;
42  import org.apache.hadoop.mapreduce.TaskAttemptContext;
43  import org.apache.hadoop.mapreduce.TaskAttemptID;
44  import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
45  import org.apache.hadoop.util.ReflectionUtils;
46  import org.apache.hadoop.util.Tool;
47  import org.apache.hadoop.util.ToolRunner;
48  
49  /**
50   * Utility for collecting samples and writing a partition file for
51   * {@link TotalOrderPartitioner}.
52   *
53   * This is an identical copy of o.a.h.mapreduce.lib.partition.InputSampler
54   * from Hadoop trunk at r961542, with the exception of replacing
55   * TaskAttemptContextImpl with TaskAttemptContext.
56   */
57  public class InputSampler<K,V> extends Configured implements Tool  {
58  
59    private static final Log LOG = LogFactory.getLog(InputSampler.class);
60  
61    static int printUsage() {
62      System.out.println("sampler -r <reduces>\n" +
63        "      [-inFormat <input format class>]\n" +
64        "      [-keyClass <map input & output key class>]\n" +
65        "      [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
66        "             // Sample from random splits at random (general)\n" +
67        "       -splitSample <numSamples> <maxsplits> | " +
68        "             // Sample from first records in splits (random data)\n"+
69        "       -splitInterval <double pcnt> <maxsplits>]" +
70        "             // Sample from splits at intervals (sorted data)");
71      System.out.println("Default sampler: -splitRandom 0.1 10000 10");
72      ToolRunner.printGenericCommandUsage(System.out);
73      return -1;
74    }
75  
/**
 * Creates the sampler tool and installs the supplied configuration on it.
 *
 * @param conf configuration passed to {@code setConf}
 */
public InputSampler(Configuration conf) {
  super();
  setConf(conf);
}
79  
80    /**
81     * Interface to sample using an 
82     * {@link org.apache.hadoop.mapreduce.InputFormat}.
83     */
84    public interface Sampler<K,V> {
85      /**
86       * For a given job, collect and return a subset of the keys from the
87       * input data.
88       */
89      K[] getSample(InputFormat<K,V> inf, Job job) 
90      throws IOException, InterruptedException;
91    }
92  
93    /**
94     * Samples the first n records from s splits.
95     * Inexpensive way to sample random data.
96     */
97    public static class SplitSampler<K,V> implements Sampler<K,V> {
98  
99      private final int numSamples;
100     private final int maxSplitsSampled;
101 
102     /**
103      * Create a SplitSampler sampling <em>all</em> splits.
104      * Takes the first numSamples / numSplits records from each split.
105      * @param numSamples Total number of samples to obtain from all selected
106      *                   splits.
107      */
108     public SplitSampler(int numSamples) {
109       this(numSamples, Integer.MAX_VALUE);
110     }
111 
112     /**
113      * Create a new SplitSampler.
114      * @param numSamples Total number of samples to obtain from all selected
115      *                   splits.
116      * @param maxSplitsSampled The maximum number of splits to examine.
117      */
118     public SplitSampler(int numSamples, int maxSplitsSampled) {
119       this.numSamples = numSamples;
120       this.maxSplitsSampled = maxSplitsSampled;
121     }
122 
123     /**
124      * From each split sampled, take the first numSamples / numSplits records.
125      */
126     @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
127     public K[] getSample(InputFormat<K,V> inf, Job job) 
128         throws IOException, InterruptedException {
129       List<InputSplit> splits = inf.getSplits(job);
130       ArrayList<K> samples = new ArrayList<K>(numSamples);
131       int splitsToSample = Math.min(maxSplitsSampled, splits.size());
132       int samplesPerSplit = numSamples / splitsToSample;
133       long records = 0;
134       for (int i = 0; i < splitsToSample; ++i) {
135         TaskAttemptContext samplingContext = new TaskAttemptContext(
136             job.getConfiguration(), new TaskAttemptID());
137         RecordReader<K,V> reader = inf.createRecordReader(
138             splits.get(i), samplingContext);
139         reader.initialize(splits.get(i), samplingContext);
140         while (reader.nextKeyValue()) {
141           samples.add(ReflectionUtils.copy(job.getConfiguration(),
142                                            reader.getCurrentKey(), null));
143           ++records;
144           if ((i+1) * samplesPerSplit <= records) {
145             break;
146           }
147         }
148         reader.close();
149       }
150       return (K[])samples.toArray();
151     }
152   }
153 
154   /**
155    * Sample from random points in the input.
156    * General-purpose sampler. Takes numSamples / maxSplitsSampled inputs from
157    * each split.
158    */
159   public static class RandomSampler<K,V> implements Sampler<K,V> {
160     private double freq;
161     private final int numSamples;
162     private final int maxSplitsSampled;
163 
164     /**
165      * Create a new RandomSampler sampling <em>all</em> splits.
166      * This will read every split at the client, which is very expensive.
167      * @param freq Probability with which a key will be chosen.
168      * @param numSamples Total number of samples to obtain from all selected
169      *                   splits.
170      */
171     public RandomSampler(double freq, int numSamples) {
172       this(freq, numSamples, Integer.MAX_VALUE);
173     }
174 
175     /**
176      * Create a new RandomSampler.
177      * @param freq Probability with which a key will be chosen.
178      * @param numSamples Total number of samples to obtain from all selected
179      *                   splits.
180      * @param maxSplitsSampled The maximum number of splits to examine.
181      */
182     public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
183       this.freq = freq;
184       this.numSamples = numSamples;
185       this.maxSplitsSampled = maxSplitsSampled;
186     }
187 
188     /**
189      * Randomize the split order, then take the specified number of keys from
190      * each split sampled, where each key is selected with the specified
191      * probability and possibly replaced by a subsequently selected key when
192      * the quota of keys from that split is satisfied.
193      */
194     @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
195     public K[] getSample(InputFormat<K,V> inf, Job job) 
196         throws IOException, InterruptedException {
197       List<InputSplit> splits = inf.getSplits(job);
198       ArrayList<K> samples = new ArrayList<K>(numSamples);
199       int splitsToSample = Math.min(maxSplitsSampled, splits.size());
200 
201       Random r = new Random();
202       long seed = r.nextLong();
203       r.setSeed(seed);
204       LOG.debug("seed: " + seed);
205       // shuffle splits
206       for (int i = 0; i < splits.size(); ++i) {
207         InputSplit tmp = splits.get(i);
208         int j = r.nextInt(splits.size());
209         splits.set(i, splits.get(j));
210         splits.set(j, tmp);
211       }
212       // our target rate is in terms of the maximum number of sample splits,
213       // but we accept the possibility of sampling additional splits to hit
214       // the target sample keyset
215       for (int i = 0; i < splitsToSample ||
216                      (i < splits.size() && samples.size() < numSamples); ++i) {
217         TaskAttemptContext samplingContext = new TaskAttemptContext(
218             job.getConfiguration(), new TaskAttemptID());
219         RecordReader<K,V> reader = inf.createRecordReader(
220             splits.get(i), samplingContext);
221         reader.initialize(splits.get(i), samplingContext);
222         while (reader.nextKeyValue()) {
223           if (r.nextDouble() <= freq) {
224             if (samples.size() < numSamples) {
225               samples.add(ReflectionUtils.copy(job.getConfiguration(),
226                                                reader.getCurrentKey(), null));
227             } else {
228               // When exceeding the maximum number of samples, replace a
229               // random element with this one, then adjust the frequency
230               // to reflect the possibility of existing elements being
231               // pushed out
232               int ind = r.nextInt(numSamples);
233               if (ind != numSamples) {
234                 samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
235                                  reader.getCurrentKey(), null));
236               }
237               freq *= (numSamples - 1) / (double) numSamples;
238             }
239           }
240         }
241         reader.close();
242       }
243       return (K[])samples.toArray();
244     }
245   }
246 
247   /**
248    * Sample from s splits at regular intervals.
249    * Useful for sorted data.
250    */
251   public static class IntervalSampler<K,V> implements Sampler<K,V> {
252     private final double freq;
253     private final int maxSplitsSampled;
254 
255     /**
256      * Create a new IntervalSampler sampling <em>all</em> splits.
257      * @param freq The frequency with which records will be emitted.
258      */
259     public IntervalSampler(double freq) {
260       this(freq, Integer.MAX_VALUE);
261     }
262 
263     /**
264      * Create a new IntervalSampler.
265      * @param freq The frequency with which records will be emitted.
266      * @param maxSplitsSampled The maximum number of splits to examine.
267      * @see #getSample
268      */
269     public IntervalSampler(double freq, int maxSplitsSampled) {
270       this.freq = freq;
271       this.maxSplitsSampled = maxSplitsSampled;
272     }
273 
274     /**
275      * For each split sampled, emit when the ratio of the number of records
276      * retained to the total record count is less than the specified
277      * frequency.
278      */
279     @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
280     public K[] getSample(InputFormat<K,V> inf, Job job) 
281         throws IOException, InterruptedException {
282       List<InputSplit> splits = inf.getSplits(job);
283       ArrayList<K> samples = new ArrayList<K>();
284       int splitsToSample = Math.min(maxSplitsSampled, splits.size());
285       long records = 0;
286       long kept = 0;
287       for (int i = 0; i < splitsToSample; ++i) {
288         TaskAttemptContext samplingContext = new TaskAttemptContext(
289             job.getConfiguration(), new TaskAttemptID());
290         RecordReader<K,V> reader = inf.createRecordReader(
291             splits.get(i), samplingContext);
292         reader.initialize(splits.get(i), samplingContext);
293         while (reader.nextKeyValue()) {
294           ++records;
295           if ((double) kept / records < freq) {
296             samples.add(ReflectionUtils.copy(job.getConfiguration(),
297                                  reader.getCurrentKey(), null));
298             ++kept;
299           }
300         }
301         reader.close();
302       }
303       return (K[])samples.toArray();
304     }
305   }
306 
307   /**
308    * Write a partition file for the given job, using the Sampler provided.
309    * Queries the sampler for a sample keyset, sorts by the output key
310    * comparator, selects the keys for each rank, and writes to the destination
311    * returned from {@link TotalOrderPartitioner#getPartitionFile}.
312    */
313   @SuppressWarnings("unchecked") // getInputFormat, getOutputKeyComparator
314   public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler) 
315       throws IOException, ClassNotFoundException, InterruptedException {
316     Configuration conf = job.getConfiguration();
317     final InputFormat inf = 
318         ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
319     int numPartitions = job.getNumReduceTasks();
320     K[] samples = sampler.getSample(inf, job);
321     LOG.info("Using " + samples.length + " samples");
322     RawComparator<K> comparator =
323       (RawComparator<K>) job.getSortComparator();
324     Arrays.sort(samples, comparator);
325     Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf));
326     FileSystem fs = dst.getFileSystem(conf);
327     if (fs.exists(dst)) {
328       fs.delete(dst, false);
329     }
330     SequenceFile.Writer writer = SequenceFile.createWriter(fs, 
331       conf, dst, job.getMapOutputKeyClass(), NullWritable.class);
332     NullWritable nullValue = NullWritable.get();
333     float stepSize = samples.length / (float) numPartitions;
334     int last = -1;
335     for(int i = 1; i < numPartitions; ++i) {
336       int k = Math.round(stepSize * i);
337       while (last >= k && comparator.compare(samples[last], samples[k]) == 0) {
338         ++k;
339       }
340       writer.append(samples[k], nullValue);
341       last = k;
342     }
343     writer.close();
344   }
345 
346   /**
347    * Driver for InputSampler from the command line.
348    * Configures a JobConf instance and calls {@link #writePartitionFile}.
349    */
350   public int run(String[] args) throws Exception {
351     Job job = new Job(getConf());
352     ArrayList<String> otherArgs = new ArrayList<String>();
353     Sampler<K,V> sampler = null;
354     for(int i=0; i < args.length; ++i) {
355       try {
356         if ("-r".equals(args[i])) {
357           job.setNumReduceTasks(Integer.parseInt(args[++i]));
358         } else if ("-inFormat".equals(args[i])) {
359           job.setInputFormatClass(
360               Class.forName(args[++i]).asSubclass(InputFormat.class));
361         } else if ("-keyClass".equals(args[i])) {
362           job.setMapOutputKeyClass(
363               Class.forName(args[++i]).asSubclass(WritableComparable.class));
364         } else if ("-splitSample".equals(args[i])) {
365           int numSamples = Integer.parseInt(args[++i]);
366           int maxSplits = Integer.parseInt(args[++i]);
367           if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
368           sampler = new SplitSampler<K,V>(numSamples, maxSplits);
369         } else if ("-splitRandom".equals(args[i])) {
370           double pcnt = Double.parseDouble(args[++i]);
371           int numSamples = Integer.parseInt(args[++i]);
372           int maxSplits = Integer.parseInt(args[++i]);
373           if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
374           sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits);
375         } else if ("-splitInterval".equals(args[i])) {
376           double pcnt = Double.parseDouble(args[++i]);
377           int maxSplits = Integer.parseInt(args[++i]);
378           if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
379           sampler = new IntervalSampler<K,V>(pcnt, maxSplits);
380         } else {
381           otherArgs.add(args[i]);
382         }
383       } catch (NumberFormatException except) {
384         System.out.println("ERROR: Integer expected instead of " + args[i]);
385         return printUsage();
386       } catch (ArrayIndexOutOfBoundsException except) {
387         System.out.println("ERROR: Required parameter missing from " +
388             args[i-1]);
389         return printUsage();
390       }
391     }
392     if (job.getNumReduceTasks() <= 1) {
393       System.err.println("Sampler requires more than one reducer");
394       return printUsage();
395     }
396     if (otherArgs.size() < 2) {
397       System.out.println("ERROR: Wrong number of parameters: ");
398       return printUsage();
399     }
400     if (null == sampler) {
401       sampler = new RandomSampler<K,V>(0.1, 10000, 10);
402     }
403 
404     Path outf = new Path(otherArgs.remove(otherArgs.size() - 1));
405     TotalOrderPartitioner.setPartitionFile(getConf(), outf);
406     for (String s : otherArgs) {
407       FileInputFormat.addInputPath(job, new Path(s));
408     }
409     InputSampler.<K,V>writePartitionFile(job, sampler);
410 
411     return 0;
412   }
413 
414   public static void main(String[] args) throws Exception {
415     InputSampler<?,?> sampler = new InputSampler(new Configuration());
416     int res = ToolRunner.run(sampler, args);
417     System.exit(res);
418   }
419 }