1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.hadoop.hbase.mapreduce.hadoopbackport;
20
21 import java.io.IOException;
22 import java.lang.reflect.Constructor;
23 import java.util.ArrayList;
24 import java.util.Arrays;
25 import java.util.List;
26 import java.util.Random;
27
28 import org.apache.commons.logging.Log;
29 import org.apache.commons.logging.LogFactory;
30
31 import org.apache.hadoop.conf.Configuration;
32 import org.apache.hadoop.conf.Configured;
33 import org.apache.hadoop.fs.FileSystem;
34 import org.apache.hadoop.fs.Path;
35 import org.apache.hadoop.io.NullWritable;
36 import org.apache.hadoop.io.RawComparator;
37 import org.apache.hadoop.io.SequenceFile;
38 import org.apache.hadoop.io.WritableComparable;
39 import org.apache.hadoop.mapreduce.InputFormat;
40 import org.apache.hadoop.mapreduce.InputSplit;
41 import org.apache.hadoop.mapreduce.Job;
42 import org.apache.hadoop.mapreduce.RecordReader;
43 import org.apache.hadoop.mapreduce.TaskAttemptContext;
44 import org.apache.hadoop.mapreduce.TaskAttemptID;
45 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
46 import org.apache.hadoop.util.ReflectionUtils;
47 import org.apache.hadoop.util.Tool;
48 import org.apache.hadoop.util.ToolRunner;
49
50
51
52
53
54
55
56
57
58 public class InputSampler<K,V> extends Configured implements Tool {
59
60 private static final Log LOG = LogFactory.getLog(InputSampler.class);
61
  /**
   * Prints command-line usage for the sampler tool to stdout, including the
   * generic Hadoop tool options.
   *
   * @return -1, the conventional "print usage and exit with error" code
   *         returned directly by {@link #run(String[])}.
   */
  static int printUsage() {
    System.out.println("sampler -r <reduces>\n" +
      " [-inFormat <input format class>]\n" +
      " [-keyClass <map input & output key class>]\n" +
      " [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
      " // Sample from random splits at random (general)\n" +
      " -splitSample <numSamples> <maxsplits> | " +
      " // Sample from first records in splits (random data)\n"+
      " -splitInterval <double pcnt> <maxsplits>]" +
      " // Sample from splits at intervals (sorted data)");
    // Matches the fallback sampler constructed in run() when none is given.
    System.out.println("Default sampler: -splitRandom 0.1 10000 10");
    ToolRunner.printGenericCommandUsage(System.out);
    return -1;
  }
76
  /**
   * Creates an InputSampler tool backed by the given configuration.
   *
   * @param conf configuration used for job creation and sampling
   */
  public InputSampler(Configuration conf) {
    setConf(conf);
  }
80
81
82
83
84
  /**
   * Interface to sample using an {@link InputFormat}.
   * Implementations return an array of candidate partition-boundary keys
   * drawn from the job's input.
   */
  public interface Sampler<K,V> {

    /**
     * For a given job, collect and return a subset of the keys from the
     * input data.
     *
     * @param inf input format to read splits and records from
     * @param job job whose configuration and splits are sampled
     * @return sampled keys (array element type is Object[] at runtime;
     *         callers treat it as K[] via erasure)
     */
    K[] getSample(InputFormat<K,V> inf, Job job)
    throws IOException, InterruptedException;
  }
93
94
95
96
97
98 public static class SplitSampler<K,V> implements Sampler<K,V> {
99
100 private final int numSamples;
101 private final int maxSplitsSampled;
102
103
104
105
106
107
108
109 public SplitSampler(int numSamples) {
110 this(numSamples, Integer.MAX_VALUE);
111 }
112
113
114
115
116
117
118
119 public SplitSampler(int numSamples, int maxSplitsSampled) {
120 this.numSamples = numSamples;
121 this.maxSplitsSampled = maxSplitsSampled;
122 }
123
124
125
126
127 @SuppressWarnings("unchecked")
128 @Override
129 public K[] getSample(InputFormat<K,V> inf, Job job)
130 throws IOException, InterruptedException {
131 List<InputSplit> splits = inf.getSplits(job);
132 ArrayList<K> samples = new ArrayList<K>(numSamples);
133 int splitsToSample = Math.min(maxSplitsSampled, splits.size());
134 int samplesPerSplit = numSamples / splitsToSample;
135 long records = 0;
136 for (int i = 0; i < splitsToSample; ++i) {
137 TaskAttemptContext samplingContext = getTaskAttemptContext(job);
138 RecordReader<K,V> reader = inf.createRecordReader(
139 splits.get(i), samplingContext);
140 reader.initialize(splits.get(i), samplingContext);
141 while (reader.nextKeyValue()) {
142 samples.add(ReflectionUtils.copy(job.getConfiguration(),
143 reader.getCurrentKey(), null));
144 ++records;
145 if ((i+1) * samplesPerSplit <= records) {
146 break;
147 }
148 }
149 reader.close();
150 }
151 return (K[])samples.toArray();
152 }
153 }
154
155
156
157
158
159
160
161
162
163
164
165
166 public static TaskAttemptContext getTaskAttemptContext(final Job job)
167 throws IOException {
168 Constructor<?> c;
169 try {
170 if (TaskAttemptContext.class.isInterface()) {
171
172 Class<?> clazz = Class.forName("org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl");
173 c = clazz.getConstructor(Configuration.class, TaskAttemptID.class);
174 } else {
175
176 c = TaskAttemptContext.class.getConstructor(Configuration.class, TaskAttemptID.class);
177 }
178 } catch (Exception e) {
179 throw new IOException("Failed getting constructor", e);
180 }
181 try {
182 return (TaskAttemptContext)c.newInstance(job.getConfiguration(), new TaskAttemptID());
183 } catch (Exception e) {
184 throw new IOException("Failed creating instance", e);
185 }
186 }
187
188
189
190
191
192
193 public static class RandomSampler<K,V> implements Sampler<K,V> {
194 private double freq;
195 private final int numSamples;
196 private final int maxSplitsSampled;
197
198
199
200
201
202
203
204
205 public RandomSampler(double freq, int numSamples) {
206 this(freq, numSamples, Integer.MAX_VALUE);
207 }
208
209
210
211
212
213
214
215
216 public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
217 this.freq = freq;
218 this.numSamples = numSamples;
219 this.maxSplitsSampled = maxSplitsSampled;
220 }
221
222
223
224
225
226
227
228 @SuppressWarnings("unchecked")
229 @Override
230 public K[] getSample(InputFormat<K,V> inf, Job job)
231 throws IOException, InterruptedException {
232 List<InputSplit> splits = inf.getSplits(job);
233 ArrayList<K> samples = new ArrayList<K>(numSamples);
234 int splitsToSample = Math.min(maxSplitsSampled, splits.size());
235
236 Random r = new Random();
237 long seed = r.nextLong();
238 r.setSeed(seed);
239 LOG.debug("seed: " + seed);
240
241 for (int i = 0; i < splits.size(); ++i) {
242 InputSplit tmp = splits.get(i);
243 int j = r.nextInt(splits.size());
244 splits.set(i, splits.get(j));
245 splits.set(j, tmp);
246 }
247
248
249
250 for (int i = 0; i < splitsToSample ||
251 (i < splits.size() && samples.size() < numSamples); ++i) {
252 TaskAttemptContext samplingContext = getTaskAttemptContext(job);
253 RecordReader<K,V> reader = inf.createRecordReader(
254 splits.get(i), samplingContext);
255 reader.initialize(splits.get(i), samplingContext);
256 while (reader.nextKeyValue()) {
257 if (r.nextDouble() <= freq) {
258 if (samples.size() < numSamples) {
259 samples.add(ReflectionUtils.copy(job.getConfiguration(),
260 reader.getCurrentKey(), null));
261 } else {
262
263
264
265
266 int ind = r.nextInt(numSamples);
267 if (ind != numSamples) {
268 samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
269 reader.getCurrentKey(), null));
270 }
271 freq *= (numSamples - 1) / (double) numSamples;
272 }
273 }
274 }
275 reader.close();
276 }
277 return (K[])samples.toArray();
278 }
279 }
280
281
282
283
284
285 public static class IntervalSampler<K,V> implements Sampler<K,V> {
286 private final double freq;
287 private final int maxSplitsSampled;
288
289
290
291
292
293 public IntervalSampler(double freq) {
294 this(freq, Integer.MAX_VALUE);
295 }
296
297
298
299
300
301
302
303 public IntervalSampler(double freq, int maxSplitsSampled) {
304 this.freq = freq;
305 this.maxSplitsSampled = maxSplitsSampled;
306 }
307
308
309
310
311
312
313 @SuppressWarnings("unchecked")
314 @Override
315 public K[] getSample(InputFormat<K,V> inf, Job job)
316 throws IOException, InterruptedException {
317 List<InputSplit> splits = inf.getSplits(job);
318 ArrayList<K> samples = new ArrayList<K>();
319 int splitsToSample = Math.min(maxSplitsSampled, splits.size());
320 long records = 0;
321 long kept = 0;
322 for (int i = 0; i < splitsToSample; ++i) {
323 TaskAttemptContext samplingContext = getTaskAttemptContext(job);
324 RecordReader<K,V> reader = inf.createRecordReader(
325 splits.get(i), samplingContext);
326 reader.initialize(splits.get(i), samplingContext);
327 while (reader.nextKeyValue()) {
328 ++records;
329 if ((double) kept / records < freq) {
330 samples.add(ReflectionUtils.copy(job.getConfiguration(),
331 reader.getCurrentKey(), null));
332 ++kept;
333 }
334 }
335 reader.close();
336 }
337 return (K[])samples.toArray();
338 }
339 }
340
341
342
343
344
345
346
347 @SuppressWarnings("unchecked")
348 public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler)
349 throws IOException, ClassNotFoundException, InterruptedException {
350 Configuration conf = job.getConfiguration();
351 final InputFormat inf =
352 ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
353 int numPartitions = job.getNumReduceTasks();
354 K[] samples = sampler.getSample(inf, job);
355 LOG.info("Using " + samples.length + " samples");
356 RawComparator<K> comparator =
357 (RawComparator<K>) job.getSortComparator();
358 Arrays.sort(samples, comparator);
359 Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf));
360 FileSystem fs = dst.getFileSystem(conf);
361 if (fs.exists(dst)) {
362 fs.delete(dst, false);
363 }
364 SequenceFile.Writer writer = SequenceFile.createWriter(fs,
365 conf, dst, job.getMapOutputKeyClass(), NullWritable.class);
366 NullWritable nullValue = NullWritable.get();
367 float stepSize = samples.length / (float) numPartitions;
368 int last = -1;
369 for(int i = 1; i < numPartitions; ++i) {
370 int k = Math.round(stepSize * i);
371 while (last >= k && comparator.compare(samples[last], samples[k]) == 0) {
372 ++k;
373 }
374 writer.append(samples[k], nullValue);
375 last = k;
376 }
377 writer.close();
378 }
379
380
381
382
383
  /**
   * Driver for InputSampler from the command line.
   * Parses sampler options, configures a {@link Job}, and calls
   * {@link #writePartitionFile(Job, Sampler)}. The last non-option argument
   * is the partition-file output path; the rest are input paths.
   *
   * @return 0 on success, -1 (via {@link #printUsage()}) on bad arguments
   */
  @Override
  public int run(String[] args) throws Exception {
    Job job = new Job(getConf());
    ArrayList<String> otherArgs = new ArrayList<String>();
    Sampler<K,V> sampler = null;
    // NOTE: option values are consumed with args[++i] inside the conditions,
    // so each branch advances i past its own parameters.
    for(int i=0; i < args.length; ++i) {
      try {
        if ("-r".equals(args[i])) {
          job.setNumReduceTasks(Integer.parseInt(args[++i]));
        } else if ("-inFormat".equals(args[i])) {
          job.setInputFormatClass(
              Class.forName(args[++i]).asSubclass(InputFormat.class));
        } else if ("-keyClass".equals(args[i])) {
          job.setMapOutputKeyClass(
              Class.forName(args[++i]).asSubclass(WritableComparable.class));
        } else if ("-splitSample".equals(args[i])) {
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          // Non-positive maxSplits means "no limit".
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new SplitSampler<K,V>(numSamples, maxSplits);
        } else if ("-splitRandom".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits);
        } else if ("-splitInterval".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new IntervalSampler<K,V>(pcnt, maxSplits);
        } else {
          // Anything unrecognized is treated as an input/output path.
          otherArgs.add(args[i]);
        }
      } catch (NumberFormatException except) {
        System.out.println("ERROR: Integer expected instead of " + args[i]);
        return printUsage();
      } catch (ArrayIndexOutOfBoundsException except) {
        // An option at the end of args was missing its value(s).
        System.out.println("ERROR: Required parameter missing from " +
            args[i-1]);
        return printUsage();
      }
    }
    // Sampling only makes sense with at least two partitions.
    if (job.getNumReduceTasks() <= 1) {
      System.err.println("Sampler requires more than one reducer");
      return printUsage();
    }
    // Need at least one input path plus the partition-file output path.
    if (otherArgs.size() < 2) {
      System.out.println("ERROR: Wrong number of parameters: ");
      return printUsage();
    }
    if (null == sampler) {
      // Default documented by printUsage(): -splitRandom 0.1 10000 10
      sampler = new RandomSampler<K,V>(0.1, 10000, 10);
    }

    Path outf = new Path(otherArgs.remove(otherArgs.size() - 1));
    TotalOrderPartitioner.setPartitionFile(job.getConfiguration(), outf);
    for (String s: otherArgs) {
      FileInputFormat.addInputPath(job, new Path(s));
    }
    InputSampler.<K,V>writePartitionFile(job, sampler);

    return 0;
  }
448
449 public static void main(String[] args) throws Exception {
450 InputSampler<?,?> sampler = new InputSampler(new Configuration());
451 int res = ToolRunner.run(sampler, args);
452 System.exit(res);
453 }
454 }