1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.hadoop.hbase.mapreduce.hadoopbackport;
19
20 import java.io.IOException;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23 import java.util.List;
24
25 import org.junit.Test;
26 import org.junit.experimental.categories.Category;
27
28 import static org.junit.Assert.*;
29
30 import org.apache.hadoop.hbase.SmallTests;
31 import org.apache.hadoop.io.IntWritable;
32 import org.apache.hadoop.io.NullWritable;
33 import org.apache.hadoop.mapreduce.InputFormat;
34 import org.apache.hadoop.mapreduce.InputSplit;
35 import org.apache.hadoop.mapreduce.Job;
36 import org.apache.hadoop.mapreduce.JobContext;
37 import org.apache.hadoop.mapreduce.RecordReader;
38 import org.apache.hadoop.mapreduce.TaskAttemptContext;
39
40
41
42
43 @Category(SmallTests.class)
44 public class TestInputSampler {
45
46 static class SequentialSplit extends InputSplit {
47 private int i;
48 SequentialSplit(int i) {
49 this.i = i;
50 }
51 @Override
52 public long getLength() { return 0; }
53 @Override
54 public String[] getLocations() { return new String[0]; }
55 public int getInit() { return i; }
56 }
57
58 static class TestInputSamplerIF
59 extends InputFormat<IntWritable,NullWritable> {
60
61 final int maxDepth;
62 final ArrayList<InputSplit> splits = new ArrayList<InputSplit>();
63
64 TestInputSamplerIF(int maxDepth, int numSplits, int... splitInit) {
65 this.maxDepth = maxDepth;
66 assert splitInit.length == numSplits;
67 for (int i = 0; i < numSplits; ++i) {
68 splits.add(new SequentialSplit(splitInit[i]));
69 }
70 }
71
72 @Override
73 public List<InputSplit> getSplits(JobContext context)
74 throws IOException, InterruptedException {
75 return splits;
76 }
77
78 @Override
79 public RecordReader<IntWritable,NullWritable> createRecordReader(
80 final InputSplit split, TaskAttemptContext context)
81 throws IOException, InterruptedException {
82 return new RecordReader<IntWritable,NullWritable>() {
83 private int maxVal;
84 private final IntWritable i = new IntWritable();
85 @Override
86 public void initialize(InputSplit split, TaskAttemptContext context)
87 throws IOException, InterruptedException {
88 i.set(((SequentialSplit)split).getInit() - 1);
89 maxVal = i.get() + maxDepth + 1;
90 }
91 @Override
92 public boolean nextKeyValue() {
93 i.set(i.get() + 1);
94 return i.get() < maxVal;
95 }
96 @Override
97 public IntWritable getCurrentKey() { return i; }
98 @Override
99 public NullWritable getCurrentValue() { return NullWritable.get(); }
100 @Override
101 public float getProgress() { return 1.0f; }
102 @Override
103 public void close() { }
104 };
105 }
106
107 }
108
109
110
111
112
113 @Test
114 @SuppressWarnings("unchecked")
115 public void testSplitSampler() throws Exception {
116 final int TOT_SPLITS = 15;
117 final int NUM_SPLITS = 5;
118 final int STEP_SAMPLE = 5;
119 final int NUM_SAMPLES = NUM_SPLITS * STEP_SAMPLE;
120 InputSampler.Sampler<IntWritable,NullWritable> sampler =
121 new InputSampler.SplitSampler<IntWritable,NullWritable>(
122 NUM_SAMPLES, NUM_SPLITS);
123 int inits[] = new int[TOT_SPLITS];
124 for (int i = 0; i < TOT_SPLITS; ++i) {
125 inits[i] = i * STEP_SAMPLE;
126 }
127 Job ignored = new Job();
128 Object[] samples = sampler.getSample(
129 new TestInputSamplerIF(100000, TOT_SPLITS, inits), ignored);
130 assertEquals(NUM_SAMPLES, samples.length);
131 Arrays.sort(samples, new IntWritable.Comparator());
132 for (int i = 0; i < NUM_SAMPLES; ++i) {
133 assertEquals(i, ((IntWritable)samples[i]).get());
134 }
135 }
136
137
138
139
140
141 @Test
142 @SuppressWarnings("unchecked")
143 public void testIntervalSampler() throws Exception {
144 final int TOT_SPLITS = 16;
145 final int PER_SPLIT_SAMPLE = 4;
146 final int NUM_SAMPLES = TOT_SPLITS * PER_SPLIT_SAMPLE;
147 final double FREQ = 1.0 / TOT_SPLITS;
148 InputSampler.Sampler<IntWritable,NullWritable> sampler =
149 new InputSampler.IntervalSampler<IntWritable,NullWritable>(
150 FREQ, NUM_SAMPLES);
151 int inits[] = new int[TOT_SPLITS];
152 for (int i = 0; i < TOT_SPLITS; ++i) {
153 inits[i] = i;
154 }
155 Job ignored = new Job();
156 Object[] samples = sampler.getSample(new TestInputSamplerIF(
157 NUM_SAMPLES, TOT_SPLITS, inits), ignored);
158 assertEquals(NUM_SAMPLES, samples.length);
159 Arrays.sort(samples, new IntWritable.Comparator());
160 for (int i = 0; i < NUM_SAMPLES; ++i) {
161 assertEquals(i, ((IntWritable)samples[i]).get());
162 }
163 }
164
165 }