1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import org.junit.Assert;
20 import org.junit.Test;
21 import org.apache.commons.rng.simple.RandomSource;
22 import org.apache.commons.rng.UniformRandomProvider;
23 import org.apache.commons.rng.core.source64.SplitMix64;
24
25
26
27
28 public class UnitSphereSamplerTest {
29 @Test(expected = IllegalArgumentException.class)
30 public void testPrecondition() {
31 new UnitSphereSampler(0, null);
32 }
33
34
35
36
37 @Test
38 public void testDistribution2D() {
39 UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S, 17399225432L);
40 UnitSphereSampler generator = new UnitSphereSampler(2, rng);
41
42
43 final int[] angleBuckets = new int[100];
44 final int steps = 1000000;
45 for (int i = 0; i < steps; ++i) {
46 final double[] v = generator.nextVector();
47 Assert.assertEquals(2, v.length);
48 Assert.assertEquals(1, length(v), 1e-10);
49
50
51
52 final double angle = Math.acos(v[0]);
53 final int bucket = (int) (angle / Math.PI * angleBuckets.length);
54 ++angleBuckets[bucket];
55 }
56
57
58 final int expectedBucketSize = steps / angleBuckets.length;
59 for (int bucket : angleBuckets) {
60 Assert.assertTrue("Bucket count " + bucket + " vs expected " + expectedBucketSize,
61 Math.abs(expectedBucketSize - bucket) < 350);
62 }
63 }
64
65
66 @Test(expected = StackOverflowError.class)
67 public void testBadProvider1() {
68 final UniformRandomProvider bad = new UniformRandomProvider() {
69
70 public long nextLong(long n) { return 0; }
71 public long nextLong() { return 0; }
72 public int nextInt(int n) { return 0; }
73 public int nextInt() { return 0; }
74 public float nextFloat() { return 0; }
75 public double nextDouble() { return 0;}
76 public void nextBytes(byte[] bytes, int start, int len) {}
77 public void nextBytes(byte[] bytes) {}
78 public boolean nextBoolean() { return false; }
79
80 };
81
82 new UnitSphereSampler(1, bad).nextVector();
83 }
84
85
86 @Test
87 public void testBadProvider1ThenGoodProvider() {
88
89
90 final UniformRandomProvider bad = new SplitMix64(0L) {
91 private int count;
92
93 public long nextLong() { return (count++ == 0) ? 0 : super.nextLong(); }
94 public double nextDouble() { return (count++ == 0) ? 0 : super.nextDouble(); }
95
96 };
97
98 final double[] vector = new UnitSphereSampler(1, bad).nextVector();
99 Assert.assertEquals(1, vector.length);
100 }
101
102
103
104
105
106 @Test
107 public void testNextNormSquaredAfterZeroIsValid() {
108
109
110 final double normSq = Math.nextAfter(0, 1);
111
112 final double f = 1 / Math.sqrt(normSq);
113
114 Assert.assertTrue(f > 0 && f <= Double.MAX_VALUE);
115 }
116
117
118
119
120 @Test
121 public void testSharedStateSampler() {
122 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
123 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
124 final int n = 3;
125 final UnitSphereSampler sampler1 =
126 new UnitSphereSampler(n, rng1);
127 final UnitSphereSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
128 RandomAssert.assertProduceSameSequence(
129 new RandomAssert.Sampler<double[]>() {
130 @Override
131 public double[] sample() {
132 return sampler1.nextVector();
133 }
134 },
135 new RandomAssert.Sampler<double[]>() {
136 @Override
137 public double[] sample() {
138 return sampler2.nextVector();
139 }
140 });
141 }
142
143
144
145
146 private static double length(double[] vector) {
147 double total = 0;
148 for (double d : vector) {
149 total += d * d;
150 }
151 return Math.sqrt(total);
152 }
153 }