View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.core;
19  
20  import org.apache.commons.rng.RestorableUniformRandomProvider;
21  import org.apache.commons.rng.RandomProviderState;
22  
23  /**
24   * Base class with default implementation for common methods.
25   */
26  public abstract class BaseProvider
27      implements RestorableUniformRandomProvider {
28      /** {@inheritDoc} */
29      @Override
30      public int nextInt(int n) {
31          checkStrictlyPositive(n);
32  
33          if ((n & -n) == n) {
34              return (int) ((n * (long) (nextInt() >>> 1)) >> 31);
35          }
36          int bits;
37          int val;
38          do {
39              bits = nextInt() >>> 1;
40              val = bits % n;
41          } while (bits - val + (n - 1) < 0);
42  
43          return val;
44      }
45  
46      /** {@inheritDoc} */
47      @Override
48      public long nextLong(long n) {
49          checkStrictlyPositive(n);
50  
51          long bits;
52          long val;
53          do {
54              bits = nextLong() >>> 1;
55              val  = bits % n;
56          } while (bits - val + (n - 1) < 0);
57  
58          return val;
59      }
60  
61      /** {@inheritDoc} */
62      @Override
63      public RandomProviderState saveState() {
64          return new RandomProviderDefaultState(getStateInternal());
65      }
66  
67      /** {@inheritDoc} */
68      @Override
69      public void restoreState(RandomProviderState state) {
70          if (state instanceof RandomProviderDefaultState) {
71              setStateInternal(((RandomProviderDefaultState) state).getState());
72          } else {
73              throw new IllegalArgumentException("Foreign instance");
74          }
75      }
76  
77      /** {@inheritDoc} */
78      @Override
79      public String toString() {
80          return getClass().getName();
81      }
82  
83      /**
84       * Combine parent and subclass states.
85       * This method must be called by all subclasses in order to ensure
86       * that state can be restored in case some of it is stored higher
87       * up in the class hierarchy.
88       *
89       * I.e. the body of the overridden {@link #getStateInternal()},
90       * will end with a statement like the following:
91       * <pre>
92       *  <code>
93       *    return composeStateInternal(state,
94       *                                super.getStateInternal());
95       *  </code>
96       * </pre>
97       * where {@code state} is the state needed and defined by the class
98       * where the method is overridden.
99       *
100      * @param state State of the calling class.
101      * @param parentState State of the calling class' parent.
102      * @return the combined state.
103      * Bytes that belong to the local state will be stored at the
104      * beginning of the resulting array.
105      */
106     protected byte[] composeStateInternal(byte[] state,
107                                           byte[] parentState) {
108         final int len = parentState.length + state.length;
109         final byte[] c = new byte[len];
110         System.arraycopy(state, 0, c, 0, state.length);
111         System.arraycopy(parentState, 0, c, state.length, parentState.length);
112         return c;
113     }
114 
115     /**
116      * Splits the given {@code state} into a part to be consumed by the caller
117      * in order to restore its local state, while the reminder is passed to
118      * the parent class.
119      *
120      * I.e. the body of the overridden {@link #setStateInternal(byte[])},
121      * will contain statements like the following:
122      * <pre>
123      *  <code>
124      *    final byte[][] s = splitState(state, localStateLength);
125      *    // Use "s[0]" to recover the local state.
126      *    super.setStateInternal(s[1]);
127      *  </code>
128      * </pre>
129      * where {@code state} is the combined state of the calling class and of
130      * all its parents.
131      *
132      * @param state State.
133      * The local state must be stored at the beginning of the array.
134      * @param localStateLength Number of elements that will be consumed by the
135      * locally defined state.
136      * @return the local state (in slot 0) and the parent state (in slot 1).
137      * @throws IllegalStateException if {@code state.length < localStateLength}.
138      */
139     protected byte[][] splitStateInternal(byte[] state,
140                                           int localStateLength) {
141         checkStateSize(state, localStateLength);
142 
143         final byte[] local = new byte[localStateLength];
144         System.arraycopy(state, 0, local, 0, localStateLength);
145         final int parentLength = state.length - localStateLength;
146         final byte[] parent = new byte[parentLength];
147         System.arraycopy(state, localStateLength, parent, 0, parentLength);
148 
149         return new byte[][] { local, parent };
150     }
151 
152     /**
153      * Creates a snapshot of the RNG state.
154      *
155      * @return the internal state.
156      */
157     protected byte[] getStateInternal() {
158         // This class has no state (and is the top-level class that
159         // declares this method).
160         return new byte[0];
161     }
162 
163     /**
164      * Resets the RNG to the given {@code state}.
165      *
166      * @param state State (previously obtained by a call to
167      * {@link #getStateInternal()}).
168      * @throws IllegalStateException if the size of the given array is
169      * not consistent with the state defined by this class.
170      *
171      * @see #checkStateSize(byte[],int)
172      */
173     protected void setStateInternal(byte[] state) {
174         if (state.length != 0) {
175             // This class has no state.
176             throw new IllegalStateException("State not fully recovered by subclasses");
177         }
178     }
179 
180     /**
181      * Simple filling procedure.
182      * It will
183      * <ol>
184      *  <li>
185      *   fill the beginning of {@code state} by copying
186      *   {@code min(seed.length, state.length)} elements from
187      *   {@code seed},
188      *  </li>
189      *  <li>
190      *   set all remaining elements of {@code state} with non-zero
191      *   values (even if {@code seed.length < state.length}).
192      *  </li>
193      * </ol>
194      *
195      * @param state State. Must be allocated.
196      * @param seed Seed. Cannot be null.
197      */
198     protected void fillState(int[] state,
199                              int[] seed) {
200         final int stateSize = state.length;
201         final int seedSize = seed.length;
202         System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize));
203 
204         if (seedSize < stateSize) {
205             for (int i = seedSize; i < stateSize; i++) {
206                 state[i] = (int) (scrambleWell(state[i - seed.length], i) & 0xffffffffL);
207             }
208         }
209     }
210 
211     /**
212      * Simple filling procedure.
213      * It will
214      * <ol>
215      *  <li>
216      *   fill the beginning of {@code state} by copying
217      *   {@code min(seed.length, state.length)} elements from
218      *   {@code seed},
219      *  </li>
220      *  <li>
221      *   set all remaining elements of {@code state} with non-zero
222      *   values (even if {@code seed.length < state.length}).
223      *  </li>
224      * </ol>
225      *
226      * @param state State. Must be allocated.
227      * @param seed Seed. Cannot be null.
228      */
229     protected void fillState(long[] state,
230                              long[] seed) {
231         final int stateSize = state.length;
232         final int seedSize = seed.length;
233         System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize));
234 
235         if (seedSize < stateSize) {
236             for (int i = seedSize; i < stateSize; i++) {
237                 state[i] = scrambleWell(state[i - seed.length], i);
238             }
239         }
240     }
241 
242     /**
243      * Checks that the {@code state} has the {@code expected} size.
244      *
245      * @param state State.
246      * @param expected Expected length of {@code state} array.
247      * @throws IllegalStateException if {@code state.length < expected}.
248      * @deprecated Method is used internally and should be made private in
249      * some future release.
250      */
251     @Deprecated
252     protected void checkStateSize(byte[] state,
253                                   int expected) {
254         if (state.length < expected) {
255             throw new IllegalStateException("State size must be larger than " +
256                                             expected + " but was " + state.length);
257         }
258     }
259 
260     /**
261      * Checks whether {@code index} is in the range {@code [min, max]}.
262      *
263      * @param min Lower bound.
264      * @param max Upper bound.
265      * @param index Value that must lie within the {@code [min, max]} interval.
266      * @throws IndexOutOfBoundsException if {@code index} is not within the
267      * {@code [min, max]} interval.
268      */
269     protected void checkIndex(int min,
270                               int max,
271                               int index) {
272         if (index < min ||
273             index > max) {
274             throw new IndexOutOfBoundsException(index + " is out of interval [" +
275                                                 min + ", " +
276                                                 max + "]");
277         }
278     }
279 
280     /**
281      * Checks that the argument is strictly positive.
282      *
283      * @param n Number to check.
284      * @throws IllegalArgumentException if {@code n <= 0}.
285      */
286     private void checkStrictlyPositive(long n) {
287         if (n <= 0) {
288             throw new IllegalArgumentException("Must be strictly positive: " + n);
289         }
290     }
291 
292     /**
293      * Transformation used to scramble the initial state of
294      * a generator.
295      *
296      * @param n Seed element.
297      * @param mult Multiplier.
298      * @param shift Shift.
299      * @param add Offset.
300      * @return the transformed seed element.
301      */
302     private static long scramble(long n,
303                                  long mult,
304                                  int shift,
305                                  int add) {
306         // Code inspired from "AbstractWell" class.
307         return mult * (n ^ (n >> shift)) + add;
308     }
309 
310     /**
311      * Transformation used to scramble the initial state of
312      * a generator.
313      *
314      * @param n Seed element.
315      * @param add Offset.
316      * @return the transformed seed element.
317      * @see #scramble(long,long,int,int)
318      */
319     private static long scrambleWell(long n,
320                                      int add) {
321         // Code inspired from "AbstractWell" class.
322         return scramble(n, 1812433253L, 30, add);
323     }
324 }