View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.rng.core;
19  
20  import org.apache.commons.rng.RestorableUniformRandomProvider;
21  import org.apache.commons.rng.RandomProviderState;
22  
23  /**
24   * Base class with default implementation for common methods.
25   */
26  public abstract class BaseProvider
27      implements RestorableUniformRandomProvider {
28      /** Error message when an integer is not positive. */
29      private static final String NOT_POSITIVE = "Must be strictly positive: ";
30      /** 2^32. */
31      private static final long POW_32 = 1L << 32;
32  
33      /** {@inheritDoc} */
34      @Override
35      public int nextInt(int n) {
36          if (n <= 0) {
37              throw new IllegalArgumentException(NOT_POSITIVE + n);
38          }
39  
40          // Lemire (2019): Fast Random Integer Generation in an Interval
41          // https://arxiv.org/abs/1805.10941
42          long m = (nextInt() & 0xffffffffL) * n;
43          long l = m & 0xffffffffL;
44          if (l < n) {
45              // 2^32 % n
46              final long t = POW_32 % n;
47              while (l < t) {
48                  m = (nextInt() & 0xffffffffL) * n;
49                  l = m & 0xffffffffL;
50              }
51          }
52          return (int) (m >>> 32);
53      }
54  
55      /** {@inheritDoc} */
56      @Override
57      public long nextLong(long n) {
58          if (n <= 0) {
59              throw new IllegalArgumentException(NOT_POSITIVE + n);
60          }
61  
62          long bits;
63          long val;
64          do {
65              bits = nextLong() >>> 1;
66              val  = bits % n;
67          } while (bits - val + (n - 1) < 0);
68  
69          return val;
70      }
71  
72      /** {@inheritDoc} */
73      @Override
74      public RandomProviderState saveState() {
75          return new RandomProviderDefaultState(getStateInternal());
76      }
77  
78      /** {@inheritDoc} */
79      @Override
80      public void restoreState(RandomProviderState state) {
81          if (state instanceof RandomProviderDefaultState) {
82              setStateInternal(((RandomProviderDefaultState) state).getState());
83          } else {
84              throw new IllegalArgumentException("Foreign instance");
85          }
86      }
87  
88      /** {@inheritDoc} */
89      @Override
90      public String toString() {
91          return getClass().getName();
92      }
93  
94      /**
95       * Combine parent and subclass states.
96       * This method must be called by all subclasses in order to ensure
97       * that state can be restored in case some of it is stored higher
98       * up in the class hierarchy.
99       *
100      * I.e. the body of the overridden {@link #getStateInternal()},
101      * will end with a statement like the following:
102      * <pre>
103      *  <code>
104      *    return composeStateInternal(state,
105      *                                super.getStateInternal());
106      *  </code>
107      * </pre>
108      * where {@code state} is the state needed and defined by the class
109      * where the method is overridden.
110      *
111      * @param state State of the calling class.
112      * @param parentState State of the calling class' parent.
113      * @return the combined state.
114      * Bytes that belong to the local state will be stored at the
115      * beginning of the resulting array.
116      */
117     protected byte[] composeStateInternal(byte[] state,
118                                           byte[] parentState) {
119         final int len = parentState.length + state.length;
120         final byte[] c = new byte[len];
121         System.arraycopy(state, 0, c, 0, state.length);
122         System.arraycopy(parentState, 0, c, state.length, parentState.length);
123         return c;
124     }
125 
126     /**
127      * Splits the given {@code state} into a part to be consumed by the caller
128      * in order to restore its local state, while the reminder is passed to
129      * the parent class.
130      *
131      * I.e. the body of the overridden {@link #setStateInternal(byte[])},
132      * will contain statements like the following:
133      * <pre>
134      *  <code>
135      *    final byte[][] s = splitState(state, localStateLength);
136      *    // Use "s[0]" to recover the local state.
137      *    super.setStateInternal(s[1]);
138      *  </code>
139      * </pre>
140      * where {@code state} is the combined state of the calling class and of
141      * all its parents.
142      *
143      * @param state State.
144      * The local state must be stored at the beginning of the array.
145      * @param localStateLength Number of elements that will be consumed by the
146      * locally defined state.
147      * @return the local state (in slot 0) and the parent state (in slot 1).
148      * @throws IllegalStateException if {@code state.length < localStateLength}.
149      */
150     protected byte[][] splitStateInternal(byte[] state,
151                                           int localStateLength) {
152         checkStateSize(state, localStateLength);
153 
154         final byte[] local = new byte[localStateLength];
155         System.arraycopy(state, 0, local, 0, localStateLength);
156         final int parentLength = state.length - localStateLength;
157         final byte[] parent = new byte[parentLength];
158         System.arraycopy(state, localStateLength, parent, 0, parentLength);
159 
160         return new byte[][] {local, parent};
161     }
162 
163     /**
164      * Creates a snapshot of the RNG state.
165      *
166      * @return the internal state.
167      */
168     protected byte[] getStateInternal() {
169         // This class has no state (and is the top-level class that
170         // declares this method).
171         return new byte[0];
172     }
173 
174     /**
175      * Resets the RNG to the given {@code state}.
176      *
177      * @param state State (previously obtained by a call to
178      * {@link #getStateInternal()}).
179      * @throws IllegalStateException if the size of the given array is
180      * not consistent with the state defined by this class.
181      *
182      * @see #checkStateSize(byte[],int)
183      */
184     protected void setStateInternal(byte[] state) {
185         if (state.length != 0) {
186             // This class has no state.
187             throw new IllegalStateException("State not fully recovered by subclasses");
188         }
189     }
190 
191     /**
192      * Simple filling procedure.
193      * It will
194      * <ol>
195      *  <li>
196      *   fill the beginning of {@code state} by copying
197      *   {@code min(seed.length, state.length)} elements from
198      *   {@code seed},
199      *  </li>
200      *  <li>
201      *   set all remaining elements of {@code state} with non-zero
202      *   values (even if {@code seed.length < state.length}).
203      *  </li>
204      * </ol>
205      *
206      * @param state State. Must be allocated.
207      * @param seed Seed. Cannot be null.
208      */
209     protected void fillState(int[] state,
210                              int[] seed) {
211         final int stateSize = state.length;
212         final int seedSize = seed.length;
213         System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize));
214 
215         if (seedSize < stateSize) {
216             for (int i = seedSize; i < stateSize; i++) {
217                 state[i] = (int) (scrambleWell(state[i - seed.length], i) & 0xffffffffL);
218             }
219         }
220     }
221 
222     /**
223      * Simple filling procedure.
224      * It will
225      * <ol>
226      *  <li>
227      *   fill the beginning of {@code state} by copying
228      *   {@code min(seed.length, state.length)} elements from
229      *   {@code seed},
230      *  </li>
231      *  <li>
232      *   set all remaining elements of {@code state} with non-zero
233      *   values (even if {@code seed.length < state.length}).
234      *  </li>
235      * </ol>
236      *
237      * @param state State. Must be allocated.
238      * @param seed Seed. Cannot be null.
239      */
240     protected void fillState(long[] state,
241                              long[] seed) {
242         final int stateSize = state.length;
243         final int seedSize = seed.length;
244         System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize));
245 
246         if (seedSize < stateSize) {
247             for (int i = seedSize; i < stateSize; i++) {
248                 state[i] = scrambleWell(state[i - seed.length], i);
249             }
250         }
251     }
252 
253     /**
254      * Checks that the {@code state} has the {@code expected} size.
255      *
256      * @param state State.
257      * @param expected Expected length of {@code state} array.
258      * @throws IllegalStateException if {@code state.length < expected}.
259      * @deprecated Method is used internally and should be made private in
260      * some future release.
261      */
262     @Deprecated
263     protected void checkStateSize(byte[] state,
264                                   int expected) {
265         if (state.length < expected) {
266             throw new IllegalStateException("State size must be larger than " +
267                                             expected + " but was " + state.length);
268         }
269     }
270 
271     /**
272      * Checks whether {@code index} is in the range {@code [min, max]}.
273      *
274      * @param min Lower bound.
275      * @param max Upper bound.
276      * @param index Value that must lie within the {@code [min, max]} interval.
277      * @throws IndexOutOfBoundsException if {@code index} is not within the
278      * {@code [min, max]} interval.
279      */
280     protected void checkIndex(int min,
281                               int max,
282                               int index) {
283         if (index < min ||
284             index > max) {
285             throw new IndexOutOfBoundsException(index + " is out of interval [" +
286                                                 min + ", " +
287                                                 max + "]");
288         }
289     }
290 
291     /**
292      * Transformation used to scramble the initial state of
293      * a generator.
294      *
295      * @param n Seed element.
296      * @param mult Multiplier.
297      * @param shift Shift.
298      * @param add Offset.
299      * @return the transformed seed element.
300      */
301     private static long scramble(long n,
302                                  long mult,
303                                  int shift,
304                                  int add) {
305         // Code inspired from "AbstractWell" class.
306         return mult * (n ^ (n >> shift)) + add;
307     }
308 
309     /**
310      * Transformation used to scramble the initial state of
311      * a generator.
312      *
313      * @param n Seed element.
314      * @param add Offset.
315      * @return the transformed seed element.
316      * @see #scramble(long,long,int,int)
317      */
318     private static long scrambleWell(long n,
319                                      int add) {
320         // Code inspired from "AbstractWell" class.
321         return scramble(n, 1812433253L, 30, add);
322     }
323 }