Select random element from a set, faster than linear

Posted 2019-02-21 10:12

Question:

I'd like to create this function, which selects a random element from a Set:

randElem :: (RandomGen g) => Set a -> g -> (a, g)

Simple listy implementations can be written. For example (code updated, verified working):

import Data.Set as Set
import System.Random (getStdGen, randomR, RandomGen)

randElem :: (RandomGen g) => Set a -> g -> (a, g)
randElem s g = (Set.toList s !! n, g')
    where (n, g') = randomR (0, Set.size s - 1) g

-- simple test drive
main = do g <- getStdGen
          print . fst $ randElem s g
    where s = Set.fromList [1,3,5,7,9]

But using !! incurs a linear lookup cost for large (randomly selected) n. Is there a faster way to select a random element in a Set? Ideally, repeated random selections should produce a uniform distribution over all options, meaning it does not prefer some elements over others.


Edit: some great ideas are popping up in the answers, so I just wanted to throw in a couple more clarifications on what exactly I'm looking for. I asked this question with Sets as the solution to this situation in mind. I'll prefer answers that both

  1. avoid using any outside-the-function bookkeeping beyond the Set's internals, and
  2. maintain good performance (better than O(n) on average) even though the function is only used once per unique set.

I also have this love of working code, so expect (at minimum) a +1 from me if your answer includes a working solution.

Answer 1:

Here's an idea: You could do interval bisection.

  1. size s is constant time. Use randomR to pick the index (the position in sorted order) of the element you are selecting.
  2. Call split with various values between the original findMin and findMax until you reach the element at the position you want. If you really fear that the set is made up of, say, reals and is extremely tightly clustered, you can recompute findMin and findMax each time to guarantee knocking off some elements on every step.

The worst-case performance would be O(n log n), only a log factor worse than your current solution, but under rather weak conditions (essentially, that the set is not entirely clustered around some accumulation point) the average performance should be about O((log n)^2), which grows very slowly. If it's a set of integers, you get O(log n * log m), where m is the initial range of the set, since each split halves the remaining value range and costs O(log n); it's only reals (or other data types whose order type has accumulation points) that could cause really nasty performance in an interval bisection.

PS. This produces a perfectly even distribution, as long as you watch for off-by-one errors so that the elements at the top and bottom can be selected.

Edit: added 'code'

Some inelegant, unchecked (pseudo?) code. No compiler on my current machine to smoke test, possibility of off-by-ones, and could probably be done with fewer ifs. One thing: check out how mid is generated; it'll need some tweaking depending on whether you are looking for something that works with sets of ints or reals (interval bisection is inherently topological, and oughtn't to work quite the same for sets with different topologies).

import qualified Data.Set as Set
import System.Random (randomR, RandomGen)

-- n-th smallest element (0-based), found by splitting at the midpoint of the value range.
-- Integer version: for reals, compute mid as (hi - lo) / 2 + lo instead.
getNth :: Integral a => Set.Set a -> Int -> a
getNth s n
  | n == 0              = Set.findMin s
  | n + 1 == Set.size s = Set.findMax s
  | n < nb              = getNth bott n
  | pres && n == nb     = mid           -- mid itself is the n-th element
  | pres                = getNth top (n - nb - 1)
  | otherwise           = getNth top (n - nb)
  where (lo, hi) = (Set.findMin s, Set.findMax s)
        mid = lo + (hi - lo) `div` 2
        (bott, pres, top) = Set.splitMember mid s
        nb = Set.size bott

randElem :: (RandomGen g, Integral a) => Set.Set a -> g -> (a, g)
randElem s g = (getNth s n, g')
  where (n, g') = randomR (0, Set.size s - 1) g


Answer 2:

Data.Map has an indexing function (elemAt), so use this:

import qualified Data.Map as M
import Data.Map(member, size, empty)
import System.Random

type Set a = M.Map a ()

insert :: (Ord a) => a -> Set a -> Set a
insert a = M.insert a ()

fromList :: Ord a => [a] -> Set a
fromList = M.fromList . flip zip (repeat ())

elemAt i = fst . M.elemAt i

randElem :: (RandomGen g) => Set a -> g -> (a, g)
randElem s g = (elemAt n s, g')
    where (n, g') = randomR (0, size s - 1) g

And you have something quite compatible with Data.Set (with respect to interface and performance) that also has a log(n) indexing function and the randElem function you requested.

Note that randElem is log(n) (and it's probably the fastest implementation you can get with this complexity), and all the other functions have the same complexity as in Data.Set. Let me know if you need any other specific functions from the Set API and I will add them.



Answer 3:

As far as I know, the proper solution would be to use an indexed set -- i.e. an IntMap. You just need to store the total number of elements added along with the map. Every time you add an element, you add it with a key one higher than previously. Deleting an element is fine -- just don't alter the total elements counter. If, on looking up a keyed element, that element no longer exists, then generate a new random number and try again. This works until the total number of deletions dominates the number of active elements in the set. If that's a problem, you can keep a separate set of deleted keys to draw from when inserting new elements.
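A minimal sketch of the counter-plus-IntMap idea, assuming a hypothetical IxSet type (all names here are illustrative, not from any library). To depend only on containers, the random draw is passed in as a function with the same shape as System.Random's randomR. Like the description above, it does not check for duplicate elements.

```haskell
import qualified Data.IntMap.Strict as IM

-- Hypothetical indexed container: every insertion gets the next integer key.
data IxSet a = IxSet
  { nextKey :: !Int         -- total number of insertions so far (never decreases)
  , live    :: IM.IntMap a  -- elements still present, keyed by insertion number
  }

emptyIx :: IxSet a
emptyIx = IxSet 0 IM.empty

-- Insert under a fresh key (no duplicate check in this sketch).
insertIx :: a -> IxSet a -> IxSet a
insertIx x (IxSet k m) = IxSet (k + 1) (IM.insert k x m)

-- Deleting leaves the counter alone, so deleted keys become "holes".
deleteIx :: Int -> IxSet a -> IxSet a
deleteIx k s = s { live = IM.delete k (live s) }

-- Rejection sampling: draw a key from [0, nextKey - 1]; if it is a hole, redraw.
-- The draw function is a parameter (e.g. System.Random.randomR).
sampleIx :: ((Int, Int) -> g -> (Int, g)) -> IxSet a -> g -> Maybe (a, g)
sampleIx draw s g0
  | IM.null (live s) = Nothing
  | otherwise        = Just (go g0)
  where
    go g = let (k, g') = draw (0, nextKey s - 1) g
           in case IM.lookup k (live s) of
                Just x  -> (x, g')
                Nothing -> go g'
```

Every live key is equally likely to be the first non-hole hit, so the result stays uniform; the expected number of redraws grows as deletions accumulate, which is the trade-off the answer mentions.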



Answer 4:

If you had access to the internals of Data.Set (see here for the definition), which is just a binary tree, you could recurse over the tree, at each node selecting one of the branches with probability according to their respective sizes. This is quite straightforward and gives you very good performance in terms of memory management and allocations, as you have no extra bookkeeping to do. OTOH, you have to invoke the RNG O(log n) times.
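A sketch of the size-weighted descent. Because Data.Set's constructors are internal, it uses a hypothetical Tree type that mirrors the same shape (each Bin caches its subtree size); the random draw is passed in as a parameter with randomR's shape.

```haskell
-- Hypothetical stand-in for Data.Set's internal tree: every Bin caches its size.
data Tree a = Tip | Bin !Int a (Tree a) (Tree a)

size :: Tree a -> Int
size Tip           = 0
size (Bin n _ _ _) = n

bin :: a -> Tree a -> Tree a -> Tree a
bin x l r = Bin (size l + size r + 1) x l r

-- Balanced tree from an ascending list (O(n log n); fine for a sketch).
fromAscList :: [a] -> Tree a
fromAscList [] = Tip
fromAscList xs = bin x (fromAscList l) (fromAscList r)
  where (l, x : r) = splitAt (length xs `div` 2) xs

-- Descend from the root: at a node of size n, draw k in [0, n - 1] and go left
-- if k < size l, stop at this node if k == size l, otherwise go right.
-- Each element is reached with probability 1/n; the RNG is called once per level.
randWalk :: ((Int, Int) -> g -> (Int, g)) -> Tree a -> g -> Maybe (a, g)
randWalk _ Tip _ = Nothing
randWalk draw (Bin n x l r) g
  | k < size l  = randWalk draw l g'
  | k == size l = Just (x, g')
  | otherwise   = randWalk draw r g'
  where (k, g') = draw (0, n - 1) g
```

Going left has probability size l / n and the walk is uniform inside l, so each left element gets (size l / n) * (1 / size l) = 1/n; the same argument covers the right subtree.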

A variant is to use Jonas’ suggestion: first take the size, select the index of the random element based on that, and then look it up with an elemAt function (which had yet to be added to Data.Set).



Answer 5:

If you don't need to modify your set, or need to modify it only infrequently, you can use an array as a lookup table with O(1) access time.

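import qualified Data.Vector as V
import qualified Data.Set as Set
import System.Random (RandomGen, randomR)

-- Invariant: the vector holds the set's elements in ascending order, no duplicates.
newtype RandSet a = RandSet (V.Vector a)

fromSet :: Set.Set a -> RandSet a
fromSet = RandSet . V.fromList . Set.toAscList

-- O(1) selection: draw an index and read the vector directly.
randElem :: RandomGen g => RandSet a -> g -> (a, g)
randElem (RandSet v) g
  | V.null v  = error "Cannot select from empty set"
  | otherwise =
      let (i, g') = randomR (0, V.length v - 1) g
      in (v V.! i, g')

-- Of course you have to rebuild the array on insertion/deletion, which is O(n)
insert :: Ord a => a -> RandSet a -> RandSet a
insert x (RandSet v) = fromSet . Set.insert x . Set.fromDistinctAscList . V.toList $ v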


Answer 6:

This problem can be finessed a bit if you don't mind completely consuming your RandomGen. With splittable generators, this is an A-OK thing to do. The basic idea is to make a lookup table for the set:

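import Data.Array (listArray, (!))
import Data.Set (Set, size, toList)
import System.Random (RandomGen, randomRs)

randomElems :: RandomGen g => Set a -> g -> [a]
randomElems set = map (table !) . randomRs bounds where
    bounds = (1, size set)
    table  = listArray bounds (toList set)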

This will have very good performance: it will cost you O(n+m) time, where n is the size of the set and m is the number of elements of the resulting list you evaluate. (Plus the time it takes to randomly choose m numbers in bounds, of course.)
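The O(n + m) split becomes easy to see (and to test without a generator) if the table-building step is separated from the stream of indices. A minimal sketch; mkTable and pickAll are illustrative names, and the indices that randomRs would supply are passed in explicitly.

```haskell
import Data.Array (Array, listArray, (!))
import qualified Data.Set as Set

-- Build the lookup table once in O(n): a 1-based array of the set's elements
-- in ascending order, using the same bounds (1, size set) as above.
mkTable :: Set.Set a -> Array Int a
mkTable s = listArray (1, Set.size s) (Set.toAscList s)

-- Each lookup is O(1). In randomElems the indices come from
-- randomRs (1, Set.size s) g; here they are passed in explicitly.
pickAll :: Array Int a -> [Int] -> [a]
pickAll table = map (table !)
```

For example, pickAll (mkTable (Set.fromList "dbca")) [1, 4, 2] looks up positions 1, 4 and 2 of the sorted elements "abcd".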



Answer 7:

Another way to achieve this might be to use Data.Sequence instead of Data.Set. This would allow you to add elements to the end in O(1) time and index elements in O(log n) time. If you also need to be able to do membership tests or deletions, you would have to use the more general fingertree package with something like FingerTree (Sum Int, Max a) a, where each element is measured as Sum 1 (a count) paired with Max of the element itself. To insert an element, use the Max a annotation to find the right place to insert; this takes O(log n) time (for some usage patterns it might be a bit less). A membership test works basically the same way, so it is also O(log n) (again, possibly a bit less for some usage patterns). To pick a random element, use the Sum annotation to do your indexing, which takes O(log n) time (the average case for uniformly random indices).
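A minimal sketch of the plain Data.Sequence variant (append plus positional lookup only; a Seq does not prevent duplicates or support membership tests, which is what the fingertree version is for). addElem and randElemSeq are hypothetical names, and the draw function stands in for randomR.

```haskell
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Seq

-- Append to the right end in O(1).
addElem :: a -> Seq a -> Seq a
addElem x s = s |> x

-- Uniform pick by position: Seq.index is O(log(min(i, n - i))).
-- The draw function is a parameter (e.g. System.Random.randomR).
randElemSeq :: ((Int, Int) -> g -> (Int, g)) -> Seq a -> g -> Maybe (a, g)
randElemSeq draw s g
  | Seq.null s = Nothing
  | otherwise  = Just (Seq.index s i, g')
  where (i, g') = draw (0, Seq.length s - 1) g
```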



Answer 8:

As of containers-0.5.2.0 the Data.Set module has an elemAt function, which retrieves values by their zero-based index in the sorted sequence of elements. So it is now trivial to write this function:

import           Control.Monad.Random
import           Data.Set (Set)
import qualified Data.Set as Set

randElem :: (MonadRandom m, Ord a) => Set a -> m (a, Set a)
randElem xs = do
  n <- getRandomR (0, Set.size xs - 1)
  return (Set.elemAt n xs, Set.deleteAt n xs)

Since both Set.elemAt and Set.deleteAt are O(log n), where n is the number of elements in the set, the entire operation is O(log n).
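For the pure signature from the question, Set.elemAt works without MonadRandom. A sketch, assuming containers >= 0.5.2.0; randElemWith is a hypothetical name, and the index draw is passed in so the block needs only containers.

```haskell
import Data.Set (Set)
import qualified Data.Set as Set

-- The question's randElem, built on Set.elemAt (O(log n)).
-- The index draw is a parameter, e.g. System.Random.randomR.
randElemWith :: ((Int, Int) -> g -> (Int, g)) -> Set a -> g -> (a, g)
randElemWith draw s g = (Set.elemAt n s, g')
  where (n, g') = draw (0, Set.size s - 1) g
```

With the random package, the question's function would then be randElem :: RandomGen g => Set a -> g -> (a, g) defined as randElem = randElemWith randomR.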