1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.hadoop.hbase.mapreduce.hadoopbackport;
20
21 import java.io.IOException;
22 import java.lang.reflect.Constructor;
23 import java.util.ArrayList;
24 import java.util.Arrays;
25 import java.util.List;
26 import java.util.Random;
27
28 import org.apache.commons.logging.Log;
29 import org.apache.commons.logging.LogFactory;
30
31 import org.apache.hadoop.conf.Configuration;
32 import org.apache.hadoop.conf.Configured;
33 import org.apache.hadoop.fs.FileSystem;
34 import org.apache.hadoop.fs.Path;
35 import org.apache.hadoop.io.NullWritable;
36 import org.apache.hadoop.io.RawComparator;
37 import org.apache.hadoop.io.SequenceFile;
38 import org.apache.hadoop.io.WritableComparable;
39 import org.apache.hadoop.mapreduce.InputFormat;
40 import org.apache.hadoop.mapreduce.InputSplit;
41 import org.apache.hadoop.mapreduce.Job;
42 import org.apache.hadoop.mapreduce.RecordReader;
43 import org.apache.hadoop.mapreduce.TaskAttemptContext;
44 import org.apache.hadoop.mapreduce.TaskAttemptID;
45 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
46 import org.apache.hadoop.util.ReflectionUtils;
47 import org.apache.hadoop.util.Tool;
48 import org.apache.hadoop.util.ToolRunner;
49
50
51
52
53
54
55
56
57
58 public class InputSampler<K,V> extends Configured implements Tool {
59
60 private static final Log LOG = LogFactory.getLog(InputSampler.class);
61
62 static int printUsage() {
63 System.out.println("sampler -r <reduces>\n" +
64 " [-inFormat <input format class>]\n" +
65 " [-keyClass <map input & output key class>]\n" +
66 " [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
67 " // Sample from random splits at random (general)\n" +
68 " -splitSample <numSamples> <maxsplits> | " +
69 " // Sample from first records in splits (random data)\n"+
70 " -splitInterval <double pcnt> <maxsplits>]" +
71 " // Sample from splits at intervals (sorted data)");
72 System.out.println("Default sampler: -splitRandom 0.1 10000 10");
73 ToolRunner.printGenericCommandUsage(System.out);
74 return -1;
75 }
76
/**
 * Creates a sampler tool bound to the given configuration.
 *
 * @param conf configuration later returned by {@link #getConf()}
 */
public InputSampler(Configuration conf) {
  this.setConf(conf);
}
80
81
82
83
84
85 public interface Sampler<K,V> {
86
87
88
89
90 K[] getSample(InputFormat<K,V> inf, Job job)
91 throws IOException, InterruptedException;
92 }
93
94
95
96
97
98 public static class SplitSampler<K,V> implements Sampler<K,V> {
99
100 private final int numSamples;
101 private final int maxSplitsSampled;
102
103
104
105
106
107
108
109 public SplitSampler(int numSamples) {
110 this(numSamples, Integer.MAX_VALUE);
111 }
112
113
114
115
116
117
118
119 public SplitSampler(int numSamples, int maxSplitsSampled) {
120 this.numSamples = numSamples;
121 this.maxSplitsSampled = maxSplitsSampled;
122 }
123
124
125
126
127 @SuppressWarnings("unchecked")
128 public K[] getSample(InputFormat<K,V> inf, Job job)
129 throws IOException, InterruptedException {
130 List<InputSplit> splits = inf.getSplits(job);
131 ArrayList<K> samples = new ArrayList<K>(numSamples);
132 int splitsToSample = Math.min(maxSplitsSampled, splits.size());
133 int samplesPerSplit = numSamples / splitsToSample;
134 long records = 0;
135 for (int i = 0; i < splitsToSample; ++i) {
136 TaskAttemptContext samplingContext = getTaskAttemptContext(job);
137 RecordReader<K,V> reader = inf.createRecordReader(
138 splits.get(i), samplingContext);
139 reader.initialize(splits.get(i), samplingContext);
140 while (reader.nextKeyValue()) {
141 samples.add(ReflectionUtils.copy(job.getConfiguration(),
142 reader.getCurrentKey(), null));
143 ++records;
144 if ((i+1) * samplesPerSplit <= records) {
145 break;
146 }
147 }
148 reader.close();
149 }
150 return (K[])samples.toArray();
151 }
152 }
153
154
155
156
157
158
159
160
161
162
163
164
165 public static TaskAttemptContext getTaskAttemptContext(final Job job)
166 throws IOException {
167 Constructor<TaskAttemptContext> c;
168 try {
169 c = TaskAttemptContext.class.getConstructor(Configuration.class, TaskAttemptID.class);
170 } catch (Exception e) {
171 throw new IOException("Failed getting constructor", e);
172 }
173 try {
174 return c.newInstance(job.getConfiguration(), new TaskAttemptID());
175 } catch (Exception e) {
176 throw new IOException("Failed creating instance", e);
177 }
178 }
179
180
181
182
183
184
185 public static class RandomSampler<K,V> implements Sampler<K,V> {
186 private double freq;
187 private final int numSamples;
188 private final int maxSplitsSampled;
189
190
191
192
193
194
195
196
197 public RandomSampler(double freq, int numSamples) {
198 this(freq, numSamples, Integer.MAX_VALUE);
199 }
200
201
202
203
204
205
206
207
208 public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
209 this.freq = freq;
210 this.numSamples = numSamples;
211 this.maxSplitsSampled = maxSplitsSampled;
212 }
213
214
215
216
217
218
219
220 @SuppressWarnings("unchecked")
221 public K[] getSample(InputFormat<K,V> inf, Job job)
222 throws IOException, InterruptedException {
223 List<InputSplit> splits = inf.getSplits(job);
224 ArrayList<K> samples = new ArrayList<K>(numSamples);
225 int splitsToSample = Math.min(maxSplitsSampled, splits.size());
226
227 Random r = new Random();
228 long seed = r.nextLong();
229 r.setSeed(seed);
230 LOG.debug("seed: " + seed);
231
232 for (int i = 0; i < splits.size(); ++i) {
233 InputSplit tmp = splits.get(i);
234 int j = r.nextInt(splits.size());
235 splits.set(i, splits.get(j));
236 splits.set(j, tmp);
237 }
238
239
240
241 for (int i = 0; i < splitsToSample ||
242 (i < splits.size() && samples.size() < numSamples); ++i) {
243 TaskAttemptContext samplingContext = getTaskAttemptContext(job);
244 RecordReader<K,V> reader = inf.createRecordReader(
245 splits.get(i), samplingContext);
246 reader.initialize(splits.get(i), samplingContext);
247 while (reader.nextKeyValue()) {
248 if (r.nextDouble() <= freq) {
249 if (samples.size() < numSamples) {
250 samples.add(ReflectionUtils.copy(job.getConfiguration(),
251 reader.getCurrentKey(), null));
252 } else {
253
254
255
256
257 int ind = r.nextInt(numSamples);
258 if (ind != numSamples) {
259 samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
260 reader.getCurrentKey(), null));
261 }
262 freq *= (numSamples - 1) / (double) numSamples;
263 }
264 }
265 }
266 reader.close();
267 }
268 return (K[])samples.toArray();
269 }
270 }
271
272
273
274
275
276 public static class IntervalSampler<K,V> implements Sampler<K,V> {
277 private final double freq;
278 private final int maxSplitsSampled;
279
280
281
282
283
284 public IntervalSampler(double freq) {
285 this(freq, Integer.MAX_VALUE);
286 }
287
288
289
290
291
292
293
294 public IntervalSampler(double freq, int maxSplitsSampled) {
295 this.freq = freq;
296 this.maxSplitsSampled = maxSplitsSampled;
297 }
298
299
300
301
302
303
304 @SuppressWarnings("unchecked")
305 public K[] getSample(InputFormat<K,V> inf, Job job)
306 throws IOException, InterruptedException {
307 List<InputSplit> splits = inf.getSplits(job);
308 ArrayList<K> samples = new ArrayList<K>();
309 int splitsToSample = Math.min(maxSplitsSampled, splits.size());
310 long records = 0;
311 long kept = 0;
312 for (int i = 0; i < splitsToSample; ++i) {
313 TaskAttemptContext samplingContext = getTaskAttemptContext(job);
314 RecordReader<K,V> reader = inf.createRecordReader(
315 splits.get(i), samplingContext);
316 reader.initialize(splits.get(i), samplingContext);
317 while (reader.nextKeyValue()) {
318 ++records;
319 if ((double) kept / records < freq) {
320 samples.add(ReflectionUtils.copy(job.getConfiguration(),
321 reader.getCurrentKey(), null));
322 ++kept;
323 }
324 }
325 reader.close();
326 }
327 return (K[])samples.toArray();
328 }
329 }
330
331
332
333
334
335
336
337 @SuppressWarnings("unchecked")
338 public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler)
339 throws IOException, ClassNotFoundException, InterruptedException {
340 Configuration conf = job.getConfiguration();
341 final InputFormat inf =
342 ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
343 int numPartitions = job.getNumReduceTasks();
344 K[] samples = sampler.getSample(inf, job);
345 LOG.info("Using " + samples.length + " samples");
346 RawComparator<K> comparator =
347 (RawComparator<K>) job.getSortComparator();
348 Arrays.sort(samples, comparator);
349 Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf));
350 FileSystem fs = dst.getFileSystem(conf);
351 if (fs.exists(dst)) {
352 fs.delete(dst, false);
353 }
354 SequenceFile.Writer writer = SequenceFile.createWriter(fs,
355 conf, dst, job.getMapOutputKeyClass(), NullWritable.class);
356 NullWritable nullValue = NullWritable.get();
357 float stepSize = samples.length / (float) numPartitions;
358 int last = -1;
359 for(int i = 1; i < numPartitions; ++i) {
360 int k = Math.round(stepSize * i);
361 while (last >= k && comparator.compare(samples[last], samples[k]) == 0) {
362 ++k;
363 }
364 writer.append(samples[k], nullValue);
365 last = k;
366 }
367 writer.close();
368 }
369
370
371
372
373
374 public int run(String[] args) throws Exception {
375 Job job = new Job(getConf());
376 ArrayList<String> otherArgs = new ArrayList<String>();
377 Sampler<K,V> sampler = null;
378 for(int i=0; i < args.length; ++i) {
379 try {
380 if ("-r".equals(args[i])) {
381 job.setNumReduceTasks(Integer.parseInt(args[++i]));
382 } else if ("-inFormat".equals(args[i])) {
383 job.setInputFormatClass(
384 Class.forName(args[++i]).asSubclass(InputFormat.class));
385 } else if ("-keyClass".equals(args[i])) {
386 job.setMapOutputKeyClass(
387 Class.forName(args[++i]).asSubclass(WritableComparable.class));
388 } else if ("-splitSample".equals(args[i])) {
389 int numSamples = Integer.parseInt(args[++i]);
390 int maxSplits = Integer.parseInt(args[++i]);
391 if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
392 sampler = new SplitSampler<K,V>(numSamples, maxSplits);
393 } else if ("-splitRandom".equals(args[i])) {
394 double pcnt = Double.parseDouble(args[++i]);
395 int numSamples = Integer.parseInt(args[++i]);
396 int maxSplits = Integer.parseInt(args[++i]);
397 if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
398 sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits);
399 } else if ("-splitInterval".equals(args[i])) {
400 double pcnt = Double.parseDouble(args[++i]);
401 int maxSplits = Integer.parseInt(args[++i]);
402 if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
403 sampler = new IntervalSampler<K,V>(pcnt, maxSplits);
404 } else {
405 otherArgs.add(args[i]);
406 }
407 } catch (NumberFormatException except) {
408 System.out.println("ERROR: Integer expected instead of " + args[i]);
409 return printUsage();
410 } catch (ArrayIndexOutOfBoundsException except) {
411 System.out.println("ERROR: Required parameter missing from " +
412 args[i-1]);
413 return printUsage();
414 }
415 }
416 if (job.getNumReduceTasks() <= 1) {
417 System.err.println("Sampler requires more than one reducer");
418 return printUsage();
419 }
420 if (otherArgs.size() < 2) {
421 System.out.println("ERROR: Wrong number of parameters: ");
422 return printUsage();
423 }
424 if (null == sampler) {
425 sampler = new RandomSampler<K,V>(0.1, 10000, 10);
426 }
427
428 Path outf = new Path(otherArgs.remove(otherArgs.size() - 1));
429 TotalOrderPartitioner.setPartitionFile(getConf(), outf);
430 for (String s : otherArgs) {
431 FileInputFormat.addInputPath(job, new Path(s));
432 }
433 InputSampler.<K,V>writePartitionFile(job, sampler);
434
435 return 0;
436 }
437
438 public static void main(String[] args) throws Exception {
439 InputSampler<?,?> sampler = new InputSampler(new Configuration());
440 int res = ToolRunner.run(sampler, args);
441 System.exit(res);
442 }
443 }