This is a note on the implementation of Vose's Method for constructing the alias tables used in non-uniform discrete probability sampling, as presented by Keith Schwarz (as of 2011-12-29). Mr. Schwarz's write-up is otherwise very informative, but it contains an error in its presentation of the Method. Step 5 of the algorithm fails if the Small list is nonempty but the Large list is empty. This can happen in either of two cases:
- the initial probabilities sum to less than 1, or
- updating the Large probability in step 5.5 suffers from cancellation, so that the computed p_g is an approximation smaller than its true value.
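Both failure modes can be reproduced with a small sketch of the pairing loop as I read the write-up. It assumes that step 3 scales each probability by n only, that indices with a scaled value below 1 go to Small and all others to Large, and that step 5.5 updates p_g := p_g - (1 - p_l). The class NaiveStep5 and its method strandsSmall are hypothetical names. The method returns true when Small is still nonempty but Large has run out.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical sketch of the pairing loop as presented in the write-up
// (assumed: scale by n only in step 3, update p_g := p_g - (1 - p_l) in 5.5).
// Returns true if Small is still nonempty when Large has run out.
class NaiveStep5 {
  static boolean strandsSmall(double[] p) {
    int n = p.length;
    double[] q = new double[n];
    Deque<Integer> small = new ArrayDeque<>();
    Deque<Integer> large = new ArrayDeque<>();
    for (int i = 0; i < n; i++) {
      q[i] = p[i] * n;                    // step 3: no division by the sum
      if (q[i] < 1) small.push(i); else large.push(i);
    }
    while (!small.isEmpty()) {
      if (large.isEmpty()) return true;   // step 5 cannot pair this entry
      int l = small.pop();
      int g = large.pop();
      q[g] = q[g] - (1 - q[l]);           // step 5.5 as originally written
      if (q[g] < 1) small.push(g); else large.push(g);
    }
    return false;
  }

  public static void main(String[] args) {
    System.out.println(strandsSmall(new double[] {0.2, 0.3}));   // true: sum < 1
    System.out.println(strandsSmall(new double[] {0.7, 0.3}));   // true: cancellation
    System.out.println(strandsSmall(new double[] {0.25, 0.75})); // false: exact values
  }
}
```

With [0.2, 0.3], both scaled values (0.4 and 0.6) are below 1, so Large is empty from the start. With [0.7, 0.3], the update leaves p_g at the largest double below 1 instead of exactly 1. The index therefore moves to Small with no Large entry left to pair it with.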
The first case can be mitigated by scaling by n / ∑ p_j in step 3, rather than by n alone. The second case can be mitigated by replacing step 5.5 with "Set p_g := (p_g + p_l) - 1." In any case, the algorithm as shown comes with no proof of correctness, whereas Vose's presentation does. In the interest of correctness and completeness, here is an OCaml implementation of Vose's Method:
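```ocaml
type alias = { n : int; prob : float array; alias : int array; }

let alias pa =
  (* Partition the indices into Small (scaled p <= 1) and Large (p > 1). *)
  let rec split pa n j (small, large as part) =
    if j = n then part else
    if pa.(j) > 1.
    then split pa n (succ j) (small, j :: large)
    else split pa n (succ j) (j :: small, large)
  in
  (* Pair one Small with one Large entry at a time; leftovers get prob 1. *)
  let rec init r pa part = match part with
  | j :: small, k :: large ->
    r.prob.(j) <- pa.(j);
    r.alias.(j) <- k;
    pa.(k) <- (pa.(k) +. pa.(j)) -. 1.;  (* (p_g + p_l) - 1, see above *)
    if pa.(k) > 1.
    then init r pa (small, k :: large)
    else init r pa (k :: small, large)
  | j :: small, [] ->
    r.prob.(j) <- 1.;
    init r pa (small, [])
  | [], k :: large ->
    r.prob.(k) <- 1.;
    init r pa ([], large)
  | [], [] -> r
  in
  let n = Array.length pa in
  if n = 0 then invalid_arg "alias" else
  let sum = Array.fold_left (fun s p ->
    if p < 0. then invalid_arg "alias" else s +. p) 0. pa in
  if not (sum > 0.) then invalid_arg "alias" else
  (* Scale by n / sum, i.e. divide by the mean. *)
  let sc = float n /. sum in
  let sa = Array.map (( *. ) sc) pa in
  let r = { n; prob = Array.make n 0.; alias = Array.make n (-1); } in
  init r sa (split sa n 0 ([], []))

let choose r =
  (* One uniform draw: the integer part picks the column,
     the fractional part decides between it and its alias. *)
  let p, e = modf (Random.float (float r.n)) in
  (* Random.float's bound is inclusive; clamp so that j < n. *)
  let j = min (truncate e) (r.n - 1) in
  (* Strict <, so that an index with zero weight is never returned. *)
  if p < r.prob.(j) then j else r.alias.(j)
```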
Since this implementation is recursive, it might not be immediately obvious how to translate it into more mainstream languages. A perhaps more faithful rendition, in rather low-level Java, is the following:
```java
import java.util.Random;

public class Vose {
  private final Random random;
  private final int limit;
  private final double[] prob;
  private final int[] alias;

  public Vose(final double[] pa) {
    this(pa, new Random());
  }

  public Vose(final double[] pa, final Random random) {
    final int limit = pa.length;
    if (limit == 0) throw new IllegalArgumentException("Vose");
    double sum = 0;
    for (int j = 0; j != limit; j++) {
      if (pa[j] < 0) throw new IllegalArgumentException("Vose");
      sum += pa[j];
    }
    // Reject all-zero (or NaN) input, which cannot be normalized.
    if (!(sum > 0)) throw new IllegalArgumentException("Vose");
    // Scale by n / sum, i.e. divide by the mean.
    final double scale = (double) limit / sum;
    final double[] sa = new double[limit];
    for (int j = 0; j != limit; j++) sa[j] = pa[j] * scale;
    this.random = random;
    this.limit = limit;
    this.prob = new double[limit];
    this.alias = new int[limit];
    init(sa);
  }

  public int getLimit() {
    return limit;
  }

  private void init(final double[] sa) {
    // Small and Large worklists kept as explicit stacks.
    final int[] small = new int[limit];
    final int[] large = new int[limit];
    int ns = 0;
    int nl = 0;
    for (int j = 0; j != limit; j++)
      if (sa[j] > 1) large[nl++] = j;
      else small[ns++] = j;
    while (ns != 0 && nl != 0) {
      final int j = small[--ns];
      final int k = large[--nl];
      prob[j] = sa[j];
      alias[j] = k;
      sa[k] = (sa[k] + sa[j]) - 1; // (p_g + p_l) - 1, not p_g - (1 - p_l)
      if (sa[k] > 1) large[nl++] = k;
      else small[ns++] = k;
    }
    // Whatever is left on either list has probability 1 (up to rounding).
    while (ns != 0) prob[small[--ns]] = 1;
    while (nl != 0) prob[large[--nl]] = 1;
  }

  public int choose() {
    // One draw in [0, limit): the integer part picks the column,
    // the fractional part decides between the column and its alias.
    final double u = limit * random.nextDouble();
    final int j = (int) Math.floor(u);
    final double p = u - (double) j;
    // Strict <, so that an index with zero weight is never returned.
    return p < prob[j] ? j : alias[j];
  }
}
```
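One way to sanity-check a table built by either implementation is to recompute the distribution it encodes. Each column j is chosen with probability 1/n and yields j with probability prob[j], and alias[j] otherwise. The hypothetical helper AliasCheck.implied below does exactly that. The table in main is the one the Java constructor produces for [0.7, 0.3], as traced by hand: prob = [1.0, 0.6], alias[1] = 0, and alias[0] is unused because prob[0] = 1.

```java
import java.util.Arrays;

// Hypothetical checker: recomputes the distribution an alias table encodes.
// Column j is chosen with probability 1/n; it yields j with probability
// prob[j] and alias[j] with probability 1 - prob[j].
class AliasCheck {
  static double[] implied(double[] prob, int[] alias) {
    int n = prob.length;
    double[] dist = new double[n];
    for (int j = 0; j < n; j++) {
      dist[j] += prob[j] / n;
      // Columns with prob 1 never consult their alias entry.
      if (prob[j] < 1) dist[alias[j]] += (1 - prob[j]) / n;
    }
    return dist;
  }

  public static void main(String[] args) {
    // Table for [0.7, 0.3]: index 1 (scaled 0.6) is aliased to index 0,
    // whose updated value (1.4 + 0.6) - 1 = 1.0 leaves it with prob 1.
    double[] prob = {1.0, 2 * 0.3};
    int[] alias = {0, 0};
    System.out.println(Arrays.toString(implied(prob, alias)));
  }
}
```

The recovered values should match the input weights up to rounding. A mismatch would point to a construction bug rather than to sampling noise.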
In full generality, it is sufficient that the "probabilities" be non-negative and not all zero. Dividing the values by their mean normalizes them so that the conditions of the Method are satisfied.
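Here is a minimal sketch of that normalization for arbitrary non-negative weights. The hypothetical helper Normalize.byMean multiplies each weight by n / sum, which is the same as dividing by the mean. It also rejects inputs the Method cannot handle: negative or NaN weights, or weights that are all zero.

```java
import java.util.Arrays;

// Hypothetical helper: rescales non-negative weights by n / sum (that is,
// divides each by their mean) so that the results average 1 and sum to n.
class Normalize {
  static double[] byMean(double[] w) {
    int n = w.length;
    if (n == 0) throw new IllegalArgumentException("no weights");
    double sum = 0;
    for (double x : w) {
      // !(x >= 0) also catches NaN, which a plain x < 0 test would let through.
      if (!(x >= 0)) throw new IllegalArgumentException("negative or NaN weight");
      sum += x;
    }
    if (!(sum > 0)) throw new IllegalArgumentException("all weights are zero");
    double scale = n / sum;   // equals 1 / mean
    double[] s = new double[n];
    for (int i = 0; i < n; i++) s[i] = w[i] * scale;
    return s;
  }

  public static void main(String[] args) {
    // Weights 1, 2, 5 have mean 8/3, so the scale is 3/8.
    System.out.println(Arrays.toString(byMean(new double[] {1, 2, 5})));
    // prints [0.375, 0.75, 1.875]
  }
}
```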
References
- Michael D. Vose. A Linear Algorithm for Generating Random Numbers with a Given Distribution. IEEE Transactions on Software Engineering, vol. 17, no. 9, September 1991.
3 comments:
Hello Matias - this is Keith Schwarz. Thanks for spotting these errors! In my writeup I had been ignoring issues of floating-point errors, but these are extremely valid points. I'll be sure to update the article with additional information about implementing the algorithm in practice.
Out of curiosity, can you elaborate on the reason why it's preferable to use (p_g + p_l) - 1 rather than p_g - (1 - p_l)?
Thanks!
@Keith,
in your language of choice, follow the execution of your algorithm on [0.7, 0.3]. As I wrote on reddit, even though neither number can be represented exactly in floating point, the sum 0.7 + 0.3 is exactly 1.0, provided the language uses correct decimal-to-binary conversion. After scaling by n = 2, p_g = 1.4 and p_l = 0.6. When computing p_g - (1 - p_l), the errors combine, making the result less than 1. When computing (p_g + p_l) - 1, the errors cancel, preserving precision.
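The arithmetic in that reply can be checked directly. The hypothetical class UpdateOrder below evaluates both forms of step 5.5 on the scaled values for [0.7, 0.3]: p_g = 1.4 and p_l = 0.6 after multiplying by n = 2, which is exact. For this input, p_g - (1 - p_l) lands on the largest double below 1, which is 1 - 2^-53. By contrast, the exact sum p_g + p_l is 2 - 2^-53. That value lies halfway between two doubles, and round-half-to-even picks 2.0, so (p_g + p_l) - 1 is exactly 1.0. This is a property of these particular values rather than a general guarantee.

```java
// Compares the two orders of evaluating step 5.5 on p_g = 2 * 0.7, p_l = 2 * 0.3.
class UpdateOrder {
  // p_g - (1 - p_l): both representation errors push the result down.
  static double subtractFirst(double pg, double pl) {
    return pg - (1 - pl);
  }

  // (p_g + p_l) - 1: the rounded sum absorbs the errors.
  static double addFirst(double pg, double pl) {
    return (pg + pl) - 1;
  }

  public static void main(String[] args) {
    double pg = 2 * 0.7;  // step 3 scaling by n = 2; multiplying by 2 is exact
    double pl = 2 * 0.3;
    System.out.println(0.7 + 0.3 == 1.0);                // true
    System.out.println(subtractFirst(pg, pl) < 1.0);     // true
    System.out.println(addFirst(pg, pl) == 1.0);         // true
    // The shortfall is exactly 2^-53, half an ulp of 1.0.
    System.out.println(1.0 - subtractFirst(pg, pl) == Math.ulp(1.0) / 2); // true
  }
}
```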
Thanks to both of you for your work on this. I was really inspired. I wrote an object-oriented version of the algorithm in Smalltalk here: http://on.fb.me/zkzq0I. I'd appreciate any comments you have on the style.