Filtering Scala's Parallel Collections with ea

Posted 2019-03-19 03:14

Given a very large instance of collection.parallel.mutable.ParHashMap (or any other parallel collection), how can one abort a filtering parallel scan once a given number of matches (say 50) has been found?

Attempting to accumulate intermediate matches in a thread-safe "external" data structure, or keeping an external AtomicInteger with the result count, seems to be 2 to 3 times slower on 4 cores than using a regular collection.mutable.HashMap and pegging a single core at 100%.

I am aware that find or exists on Par* collections do abort "on the inside". Is there a way to generalize this to find more than one result?

Here's the code, which still seems to be 2 to 3 times slower on the ParHashMap with ~79,000 entries. It also has a problem: it stuffs more than maxResults results into the results CHM (probably because a thread is preempted after incrementAndGet but before break, which allows other threads to add more elements). Update: the slowdown seems to be due to worker threads contending on counter.incrementAndGet(), which of course defeats the purpose of the whole parallel scan :-(

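import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import scala.util.control.Breaks._

// Key, Node and parHashMap are defined elsewhere in the application.
def find(filter: Node => Boolean, maxResults: Int): Iterable[Node] =
{
  val counter = new AtomicInteger(0)   // shared by all worker threads
  val results = new ConcurrentHashMap[Key, Node](maxResults)

  breakable
  {
    for ((key, node) <- parHashMap if filter(node))
    {
      results.put(key, node)                // stored before the limit is checked
      val total = counter.incrementAndGet() // every match contends on this one counter
      if (total > maxResults) break()
    }
  }

  results.values.toArray(new Array[Node](results.size))
}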
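One way to remove the overshoot, sketched under the assumption of Scala 2.12 (where parallel collections ship with the standard library; on 2.13+ they live in the separate scala-parallel-collections module): reserve a slot with `incrementAndGet()` before storing, and store only when the reserved number is within `maxResults`. Because `incrementAndGet()` hands out each number exactly once, at most `maxResults` entries can reach the map. `ReserveThenPut` and its generic signature are hypothetical stand-ins for the question's `Key`/`Node` code.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
import scala.collection.parallel.ParMap
import scala.util.control.Breaks._

object ReserveThenPut {
  def find[K, V](m: ParMap[K, V], p: V => Boolean, maxResults: Int): Vector[V] = {
    val counter = new AtomicInteger(0)
    val results = new ConcurrentHashMap[K, V](maxResults)
    breakable {
      // The predicate is tested inside foreach, in the same pass as the store.
      m.foreach { case (k, v) =>
        if (p(v)) {
          // Reserve a slot first; each reservation number is unique across threads.
          if (counter.incrementAndGet() > maxResults) break()
          // Only reservations 1..maxResults get here, so the map cannot overshoot.
          results.put(k, v)
        }
      }
    }
    results.values.asScala.toVector
  }
}
```

This bounds the result size, but every match still goes through the shared counter, so it does nothing about the contention described in the update.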

3 answers
家丑人穷心不美
Reply #2 · 2019-03-19 03:47

I would first do a parallel scan in which the result counter checked against maxResults is thread-local. This would find up to (maxResults * numberOfThreads) results.

Then I would do a single-threaded scan to reduce them to maxResults.
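A minimal sketch of this two-pass idea, assuming Scala 2.12 (where `.par` is built in; on 2.13+ it needs the scala-parallel-collections module and `import scala.collection.parallel.CollectionConverters._`). `ThreadLocalCap` and `find` are hypothetical names. A per-thread counter replaces the shared AtomicInteger, so workers never contend on it. Note that workers still visit every element; they merely stop evaluating the predicate once their own quota is full.

```scala
import scala.collection.parallel.ParIterable

object ThreadLocalCap {
  def find[A](xs: ParIterable[A], p: A => Boolean, maxResults: Int): Vector[A] = {
    // One counter per worker thread: no shared state to contend on.
    val perThread = new ThreadLocal[Int] {
      override def initialValue(): Int = 0
    }
    // Parallel pass: each thread keeps at most maxResults of its own matches,
    // so at most maxResults * numberOfThreads candidates survive.
    val candidates = xs.filter { x =>
      val seen = perThread.get
      if (seen >= maxResults) false // this thread's quota is full: skip p entirely
      else if (p(x)) { perThread.set(seen + 1); true }
      else false
    }
    // Sequential pass: trim the candidates down to maxResults.
    candidates.seq.take(maxResults).toVector
  }
}
```

If at least maxResults matches exist, some thread either fills its quota or no thread does (and then every match is kept), so the candidate set always holds at least min(maxResults, total matches) elements.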

孤傲高冷的网名
Reply #3 · 2019-03-19 03:57

I performed an interesting investigation of your case.

Investigation reasoning

I suspected the problem was the mutability of the input map, and I will try to explain why: a HashMap implementation organizes its data into buckets, as described on Wikipedia.

Wikipedia HashMap

The first thread-safe collections in Java, the synchronized collections, synchronized every method around the underlying implementation, which resulted in poor performance. Further work led to the more performant concurrent collections, such as ConcurrentHashMap, whose approach is smarter: why not lock only part of the table instead of the whole map? (Java 5 to 7 stripe the locks over segments; Java 8 and later lock individual bins.)
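To illustrate the idea, here is a minimal lock-striping sketch in Scala. `StripedMap` is a hypothetical class written for this example, not ConcurrentHashMap's actual implementation: each key hashes to one of a fixed number of stripes, and only that stripe's lock is taken, so threads working on different stripes do not block each other.

```scala
import scala.collection.mutable

// Hypothetical illustration of lock striping, not ConcurrentHashMap's real code.
final class StripedMap[K, V](stripes: Int = 16) {
  // One plain HashMap per stripe, each guarded by its own lock object.
  private val locks = Array.fill(stripes)(new Object)
  private val tables = Array.fill(stripes)(mutable.HashMap.empty[K, V])

  // Non-negative hash of the key, folded onto a stripe index.
  private def stripeOf(key: K): Int = (key.## & Int.MaxValue) % stripes

  def put(key: K, value: V): Unit = {
    val i = stripeOf(key)
    locks(i).synchronized { tables(i).update(key, value) }
  }

  def get(key: K): Option[V] = {
    val i = stripeOf(key)
    locks(i).synchronized { tables(i).get(key) }
  }

  // Locks one stripe at a time, so this is not an atomic snapshot while writers run.
  def size: Int = (0 until stripes).map(i => locks(i).synchronized { tables(i).size }).sum
}
```

With 16 stripes, two threads only wait for each other when their keys land in the same stripe; a fully synchronized map would serialize every access.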

My intuition was that the performance problem occurs because:

  • when you run your filter in parallel, some threads will try to access the same bucket at once and hit the same lock, because your map is mutable.
  • you keep a counter to see how many results you have, while you could simply check the size of your result. If you have a thread-safe way to build a collection, you don't need a thread-safe counter as well.

Investigation result

I developed a test case and found out I was wrong. The problem is the concurrent nature of the output map: the contention occurs when you put elements into it, not when you iterate over the input. Additionally, since you only want the values, you don't need the keys, the hashing, or any of the other map features. It would be interesting to see how your version performs if you remove the AtomicInteger and use only the result map to check whether you have collected enough elements.
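A sketch of that experiment, assuming Scala 2.12 parallel collections: the AtomicInteger is gone and the result map's own `size()` decides when to stop. `SizeCheckFind` is a hypothetical name. Because the size check and the `put` are not one atomic step, a few threads can pass the check together, so the map may end up slightly larger than maxResults. The predicate is tested inside `foreach` rather than in a `for ... if` generator, because in Scala 2.12's parallel collections `withFilter` delegates to the strict `filter`, which scans the whole map before the loop body runs.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
import scala.collection.parallel.ParMap
import scala.util.control.Breaks._

object SizeCheckFind {
  def find[K, V](m: ParMap[K, V], p: V => Boolean, maxResults: Int): Vector[V] = {
    val results = new ConcurrentHashMap[K, V](maxResults)
    breakable {
      // Matching and stopping happen in a single pass over the map.
      m.foreach { case (k, v) =>
        if (p(v)) {
          // No separate counter: the result map's size is the stopping condition.
          // Check-then-put is not atomic, so a few extra entries can slip in.
          if (results.size >= maxResults) break()
          results.put(k, v)
        }
      }
    }
    results.values.asScala.toVector
  }
}
```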

Please note that the following code is written for Scala 2.9.2. I explain in another post why I need two different functions for the parallel and the non-parallel version: "Calling map on a parallel collection via a reference to an ancestor type".

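// Imports assumed by this listing (the aliases match the names used below).
import scala.collection.immutable.VectorBuilder
import scala.collection.parallel.ParMap
import scala.collection.parallel.immutable.{ParHashMap => ImmutableParHashMap}
import scala.collection.parallel.mutable.{ParHashMap => MutableParHashMap}

object MapPerformance {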

  val size = 100000
  val items = Seq.tabulate(size)( x => (x,x*2))


  val concurrentParallelMap = ImmutableParHashMap(items:_*)

  val concurrentMutableParallelMap = MutableParHashMap(items:_*)

  val unparallelMap = Map(items:_*)


  class ThreadSafeIndexedSeqBuilder[T](maxSize:Int) {
    val underlyingBuilder = new VectorBuilder[T]()
    var counter = 0
    def sizeHint(hint:Int): Unit = underlyingBuilder.sizeHint(hint)
    def +=(item:T):Boolean ={
      synchronized{
        if(counter>=maxSize)
          false
        else{
          underlyingBuilder+=item
          counter+=1
          true
        }
      }
    }
    def result():Vector[T] = underlyingBuilder.result()

  }

  def find(map:ParMap[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
  {

    // we already know the maximum size
    val resultsBuilder = new ThreadSafeIndexedSeqBuilder[Int](maxResults)
    resultsBuilder.sizeHint(maxResults)

    import util.control.Breaks._
    breakable
    {
      for ((key, node) <- map if filter(node))
      {
        val newItemAdded = resultsBuilder+=node
        if (!newItemAdded)
          break()

      }
    }
    resultsBuilder.result().seq

  }

  def findUnParallel(map:Map[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
  {

    // we already know the maximum size
    val resultsBuilder = Array.newBuilder[Int]
    resultsBuilder.sizeHint(maxResults)

    var counter = 0
    for {
      (key, node) <- map if filter(node)
      if counter < maxResults
    } {
      resultsBuilder += node
      counter += 1
    }

    resultsBuilder.result()

  }

  def measureTime[K](f: => K):(Long,K) = {
    val startMutable = System.currentTimeMillis()
    val result = f
    val endMutable = System.currentTimeMillis()
    (endMutable-startMutable,result)
  }

  def main(args:Array[String]) = {
    val maxResultSetting=10
    (1 to 10).foreach{
      tryNumber =>
        println("Try number " +tryNumber)
        val (mutableTime, mutableResult) = measureTime(find(concurrentMutableParallelMap,_%2==0,maxResultSetting))
        val (immutableTime, immutableResult) = measureTime(find(concurrentParallelMap,_%2==0,maxResultSetting))
        val (unparallelTime, unparallelResult) = measureTime(findUnParallel(unparallelMap,_%2==0,maxResultSetting))
        assert(mutableResult.size==maxResultSetting)
        assert(immutableResult.size==maxResultSetting)
        assert(unparallelResult.size==maxResultSetting)
        println(" The mutable version has taken " + mutableTime + " milliseconds")
        println(" The immutable version has taken " + immutableTime + " milliseconds")
        println(" The unparallel version has taken " + unparallelTime + " milliseconds")
    }
  }

}

With this code, the parallel versions (with both the mutable and the immutable input map) are consistently about 3.5 times faster than the non-parallel one on my machine.

三岁会撩人
Reply #4 · 2019-03-19 04:03

You could try to get an iterator and then create a lazy list (a Stream) in which you filter with your predicate and take the number of elements you want. Because a Stream is non-strict, this 'taking' of elements is not evaluated right away. Afterwards you can force the execution, adding ".par" to the whole thing to achieve parallelization.

Example code:

A parallelized map with random values (simulating your parallel hash map):

scala> myMap
res14: scala.collection.parallel.immutable.ParMap[Int,Int] = ParMap(66978401 -> -1331298976, 256964068 -> 126442706, 1698061835 -> 1622679396, -1556333580 -> -1737927220, 791194343 -> -591951714, -1907806173 -> 365922424, 1970481797 -> 162004380, -475841243 -> -445098544, -33856724 -> -1418863050, 1851826878 -> 64176692, 1797820893 -> 405915272, -1838192182 -> 1152824098, 1028423518 -> -2124589278, -670924872 -> 1056679706, 1530917115 -> 1265988738, -808655189 -> -1742792788, 873935965 -> 733748120, -1026980400 -> -163182914, 576661388 -> 900607992, -1950678599 -> -731236098)

Get an iterator, create a Stream from it, and filter it. In this case my predicate only accepts even values (of the value member of each map entry). I want 10 even elements, so I take 10 elements, which only get evaluated when I force them:

scala> val mapIterator = myMap.toIterator
mapIterator: Iterator[(Int, Int)] = HashTrieIterator(20)


scala> val r = Stream.continually(mapIterator.next()).filter(_._2 % 2 == 0).take(10)
r: scala.collection.immutable.Stream[(Int, Int)] = Stream((66978401,-1331298976), ?)

Finally, I force the evaluation, which fetches only 10 elements as planned:

scala> r.force
res16: scala.collection.immutable.Stream[(Int, Int)] = Stream((66978401,-1331298976), (256964068,126442706), (1698061835,1622679396), (-1556333580,-1737927220), (791194343,-591951714), (-1907806173,365922424), (1970481797,162004380), (-475841243,-445098544), (-33856724,-1418863050), (1851826878,64176692))

This way you only get the number of elements you want (without needing to process the remaining elements) and you parallelize the process without locks, atomics or breaks.

Please compare this to your solutions to see if it is any good.
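The same lazy filter-and-take can also be written directly on the iterator. In this minimal sketch, `LazyTake.firstMatches` is a hypothetical helper. Unlike `Stream.continually(mapIterator.next())`, which throws NoSuchElementException when the iterator runs out before enough matches are found, `Iterator.filter(...).take(n)` simply returns fewer elements. The sketch pulls elements on the calling thread only; it demonstrates the laziness, not parallel execution.

```scala
object LazyTake {
  // Pulls from `it` only until n matches have been produced and never calls
  // next() on an exhausted iterator, so fewer than n matches is not an error.
  def firstMatches[A](it: Iterator[A], p: A => Boolean, n: Int): Vector[A] =
    it.filter(p).take(n).toVector

  def main(args: Array[String]): Unit = {
    var pulled = 0
    // An infinite source that counts how many elements were actually pulled.
    val source = Iterator.from(1).map { i => pulled += 1; i }
    val evens = firstMatches(source, (x: Int) => x % 2 == 0, 10)
    println(evens)  // Vector(2, 4, 6, 8, 10, 12, 14, 16, 18, 20)
    println(pulled) // 20
  }
}
```

Only the first 20 elements of the infinite source are ever touched, which is the "no need to process the remaining elements" property described above.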
