Faster implementation for reduceByKey on Seq of pa

Posted 2019-07-13 02:58

Question:

The code below contains various single-threaded implementations of reduceByKeyXXX methods and a few helper methods to create input sets and measure execution times. (Feel free to run the main-method)

The main purpose of reduceByKey (as in Spark) is to reduce key-value-pairs with the same key. Example:

scala> val xs = Seq( "a" -> 2, "b" -> 3, "a" -> 5)
xs: Seq[(String, Int)] = List((a,2), (b,3), (a,5))

scala> ReduceByKeyComparison.reduceByKey(xs, (x:Int, y:Int) ⇒ x+y )
res8: Seq[(String, Int)] = ArrayBuffer((b,3), (a,7))

Code

import java.util.HashMap

object Util {
  def measure( body : => Unit ) : Long = {
    val now = System.currentTimeMillis
    body
    val nowAfter = System.currentTimeMillis
    nowAfter - now
  }

  def measureMultiple( body: => Unit, n: Int) : String = {
    val executionTimes = (1 to n).toList.map( x => {
      print(".")
      measure(body)
    } )

    val avg = executionTimes.sum / executionTimes.size
    executionTimes.mkString("", "ms, ", "ms") + s" Average: ${avg}ms."
  }
}

object RandomUtil {
  val AB = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  val r = new java.util.Random()

  def randomString( len: Int ) : String = {
    val sb = new StringBuilder( len )
    for( _ <- 0 until len ) {
      sb.append(AB.charAt(r.nextInt(AB.length)))
    }
    sb.toString
  }

  def generateSeq(n: Int) : Seq[(String, Int)] = {
    Seq.fill(n)( (randomString(2), r.nextInt(100)) )
  }
}

object ReduceByKeyComparison {

  def main(args: Array[String]) : Unit = {
    implicit def iterableToPairedIterable[K, V](x: Iterable[(K, V)]) = { new PairedIterable(x) }

    val runs = 10
    val problemSize = 2000000

    val ss = RandomUtil.generateSeq(problemSize)


    println("ReduceByKey :       " + Util.measureMultiple( reduceByKey(ss, (x:Int, y:Int) ⇒ x+y ), runs ))
    println("ReduceByKey2:       " + Util.measureMultiple( reduceByKey2(ss, (x:Int, y:Int) ⇒ x+y ), runs ))
    println("ReduceByKey3:       " + Util.measureMultiple( reduceByKey3(ss, (x:Int, y:Int) ⇒ x+y ), runs ))

    println("ReduceByKeyPaired:  " + Util.measureMultiple( ss.reduceByKey( (x:Int, y:Int) ⇒ x+y ), runs ))
    println("ReduceByKeyA:       " + Util.measureMultiple( reduceByKeyA( ss, (x:Int, y:Int) ⇒ x+y ), runs ))
  }

  // =============================================================================
  // Different implementations
  // =============================================================================

  def reduceByKey[A,B]( s: Seq[(A,B)], fnc: (B, B) ⇒ B) : Seq[(A,B)] = {
    val t = s.groupBy(x => x._1)
    val u = t.map { case (k,v) => (k,  v.map(_._2).reduce(fnc))}
    u.toSeq
  }

  def reduceByKey2[A,B]( s: Seq[(A,B)], fnc: (B, B) ⇒ B) : Seq[(A,B)] = {
    val r = s.foldLeft( Map[A,B]() ){ (m,a) ⇒
      val k = a._1
      val v = a._2
      m.get(k) match {
        case Some(pv) ⇒ m + ((k, fnc(pv, v)))
        case None ⇒ m + ((k, v))
      }
    }
    r.toSeq
  }

  def reduceByKey3[A,B]( s: Seq[(A,B)], fnc: (B, B) ⇒ B) : Seq[(A,B)] = {
    val m = scala.collection.mutable.Map[A,B]()
    s.foreach{ e ⇒
      val k = e._1
      val v = e._2
      m.get(k) match {
        case Some(pv) ⇒ m(k) = fnc(pv, v)
        case None ⇒ m(k) = v
      }
    }
    m.toSeq
  }

  /**
    * Method code from [[http://ideone.com/dyrkYM]]
    * All rights to Muhammad-Ali A'rabi according to [[https://issues.scala-lang.org/browse/SI-9064]]
    */
  def reduceByKeyA[A,B]( s: Seq[(A,B)], fnc: (B, B) ⇒ B): Map[A, B] = {
    s.groupBy(_._1).map(l => (l._1, l._2.map(_._2).reduce( fnc )))
  }

  /**
    * Method code from [[http://ideone.com/dyrkYM]]
    * All rights to Muhammad-Ali A'rabi according to [[https://issues.scala-lang.org/browse/SI-9064]]
    */
  class PairedIterable[K, V](x: Iterable[(K, V)]) {
    def reduceByKey(func: (V,V) => V) = {
      val map = new HashMap[K, V]
      x.foreach { pair =>
        val old = map.get(pair._1)
        map.put(pair._1, if (old == null) pair._2 else func(old, pair._2))
      }
      map
    }
  }
}

yielding the following results on my machine

..........ReduceByKey :       723ms, 782ms, 761ms, 617ms, 640ms, 707ms, 634ms, 611ms, 380ms, 458ms Average: 631ms.
..........ReduceByKey2:       580ms, 458ms, 452ms, 463ms, 462ms, 470ms, 463ms, 465ms, 458ms, 462ms Average: 473ms.
..........ReduceByKey3:       489ms, 466ms, 461ms, 468ms, 555ms, 474ms, 469ms, 457ms, 461ms, 468ms Average: 476ms.
..........ReduceByKeyPaired:  140ms, 124ms, 124ms, 120ms, 122ms, 124ms, 118ms, 126ms, 121ms, 119ms Average: 123ms.
..........ReduceByKeyA:       628ms, 694ms, 666ms, 656ms, 616ms, 660ms, 594ms, 659ms, 445ms, 399ms Average: 601ms.

and ReduceByKeyPaired currently being the fastest.

Question / Task

Is there a faster single-threaded (Scala) implementation?

Answer 1:

Rewriting the reduceByKey method of PairedIterable as a tail-recursive function gives around a 5-10% performance improvement. That is all I was able to get. I also tried increasing the initial capacity of the HashMap, but it did not make any significant difference.

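  import scala.annotation.tailrec // needed for @tailrec below

  class PairedIterable[K, V](x: Iterable[(K, V)]) {
    def reduceByKey(func: (V,V) => V) = {
      val map = new HashMap[K, V]()

      // Note: the patterns below only match a List (the default Seq implementation,
      // as used in the question); any other non-empty Iterable fails with a MatchError.
      @tailrec
      def reduce(it: Iterable[(K, V)]): HashMap[K, V] = {
        it match {
          case Nil => map
          case (k, v) :: tail =>
            // java.util.HashMap signals a missing key with null
            val old = map.get(k)
            map.put(k, if (old == null) v else func(old, v))
            reduce(tail)
        }
      }

      reduce(x)
    }
  }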
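
The initial-capacity experiment mentioned in this answer is not shown, so the following is only a sketch of what it might look like. The object name PresizedReduce and the parameter expectedKeys are hypothetical and not part of the original code. It also iterates with a plain while loop over an Iterator, so it accepts any Iterable rather than only a List.

```scala
import java.util.HashMap

object PresizedReduce {
  // `expectedKeys` is a hypothetical hint for the number of distinct keys.
  def reduceByKey[K, V](xs: Iterable[(K, V)], expectedKeys: Int)(func: (V, V) => V): HashMap[K, V] = {
    // Ask for enough capacity that expectedKeys entries stay below
    // java.util.HashMap's default load factor of 0.75.
    val capacity = (expectedKeys / 0.75).toInt + 1
    val map = new HashMap[K, V](capacity)
    val it = xs.iterator
    while (it.hasNext) {
      val (k, v) = it.next()
      // A missing key is signalled by null, as in the original PairedIterable.
      val old = map.get(k)
      map.put(k, if (old == null) v else func(old, v))
    }
    map
  }
}
```

In the question's benchmark, randomString(2) over 26 letters yields at most 676 distinct keys for 2,000,000 pairs, so the map stays small and is resized only a handful of times. That may be why presizing showed no significant change there.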

In general, the methods provided can be split into two categories:

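  • The first set of methods group the pairs first (groupBy) and then reduce each group. Note that Scala's groupBy is hash-based rather than sort-based, so it does not add an O(n*log[n]) sorting cost; the overhead comes from materializing an intermediate collection per key and traversing the values again to map and reduce them, which makes these methods ineffective for this scenario.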

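  • The second set loops linearly over all entries of the Iterable. These methods perform an extra get/put on a temporary map for each entry, but each of those operations takes expected constant time, so the total stays O(n). The variants built on Scala collections (reduceByKey2, reduceByKey3) additionally have to go through Option on every lookup, which makes them less efficient than the java.util.HashMap version.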
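
The two categories can be put side by side in a small sketch. The object name TwoCategories and its method names are made up for illustration; the bodies follow the same ideas as reduceByKey, reduceByKey3 and PairedIterable from the question.

```scala
import java.util.HashMap
import scala.collection.mutable

object TwoCategories {
  // Category 1: group first, then reduce each group.
  // An intermediate collection is built for every key.
  def grouped[K, V](xs: Seq[(K, V)])(f: (V, V) => V): Map[K, V] =
    xs.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).reduce(f) }

  // Category 2a: single pass over a Scala mutable map.
  // Every successful lookup wraps the value in a Some.
  def singlePassScala[K, V](xs: Seq[(K, V)])(f: (V, V) => V): mutable.Map[K, V] = {
    val m = mutable.Map.empty[K, V]
    xs.foreach { case (k, v) =>
      m.get(k) match {
        case Some(old) => m(k) = f(old, v)
        case None      => m(k) = v
      }
    }
    m
  }

  // Category 2b: single pass over java.util.HashMap.
  // A missing key is signalled by null, so no Option is involved.
  def singlePassJava[K, V](xs: Seq[(K, V)])(f: (V, V) => V): HashMap[K, V] = {
    val m = new HashMap[K, V]
    xs.foreach { case (k, v) =>
      val old = m.get(k)
      m.put(k, if (old == null) v else f(old, v))
    }
    m
  }
}
```

All three produce the same key-to-value mapping for the same input; they differ only in how much intermediate work is done per element.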