Why is the code written in scala 6 times slower th

Posted 2020-06-22 04:55

Question:

I am not sure whether I made a mistake in my Scala code.
The problem is:

    The four adjacent digits in the 1000-digit number that have the greatest product are 9 × 9 × 8 × 9 = 5832.

    73167176531330624919225119674426574742355349194934 96983520312774506326239578318016984801869478851843 85861560789112949495459501737958331952853208805511 12540698747158523863050715693290963295227443043557 66896648950445244523161731856403098711121722383113 62229893423380308135336276614282806444486645238749 30358907296290491560440772390713810515859307960866 70172427121883998797908792274921901699720888093776 65727333001053367881220235421809751254540594752243 52584907711670556013604839586446706324415722155397 53697817977846174064955149290862569321978468622482 83972241375657056057490261407972968652414535100474 82166370484403199890008895243450658541227588666881 16427171479924442928230863465674813919123162824586 17866458359124566529476545682848912883142607690042 24219022671055626321111109370544217506941658960408 07198403850962455444362981230987879927244284909188 84580156166097919133875499200524063689912560717606 05886116467109405077541002256983155200055935729725 71636269561882670428252483600823257530420752963450

    Find the thirteen adjacent digits in the 1000-digit number that have the greatest product. What is the value of this product?  

The blogger (http://www.ituring.com.cn/article/111574) says the code he wrote in Haskell takes only 6 ms:

    import Data.List
    import Data.Char

    main = do
      a <- readFile "008.txt"
      print . maximum . map (product . take 13) . tails $ map digitToInt $ filter isDigit a  

So I tried it in Scala:

    import scala.io.Source

    object Main {

      def main(args: Array[String]): Unit = {
        val begin: Long = System.currentTimeMillis()
        val content = Source.fromFile("file/text").filter(_.isDigit).map(_.toInt - '0').toList
        val lists =
          for (i <- 0 to content.size - 13)
            yield content.drop(i).take(13)
        println(lists.maxBy(_.reduce(_ * _)))
        val end: Long = System.currentTimeMillis()
        println(end - begin)
      }
    }

But it takes about 120 ms on average. I thought the problem was I/O, but I found that reading the file takes only about 10 ms (I tried FileChannel instead of Source, but it did not save much time). It is the map and flatMap (for) operations that take most of the time.

Then I tried Java to see whether the JVM was the reason. Unsurprisingly, the Java version runs much faster, taking only about 20 ms:

    public static void main(String[] args) throws IOException {
        long begin = System.currentTimeMillis();

        byte[] bytes = Files.readAllBytes(Paths.get("file/text"));
        List<Integer> list=new ArrayList<>();
        for(int i=0;i<bytes.length;i++){
            if(bytes[i]-'0'>=0&&bytes[i]-'0'<=9) list.add(bytes[i]-'0');
        }

        int max=-1;
        List<Integer> maxList=new ArrayList<>();
        List<Integer> temp=new ArrayList<>();

        for(int i=0;i<=list.size()-13;i++){
            int value=1;
            for(int j=i;j<i+13;j++){
                temp.add(list.get(j));
                value*=list.get(j);
            }
            if(value > max) {
                max = value;
                maxList.clear();
                maxList.addAll(temp);
            }
            temp.clear();
        }
        System.out.println(maxList);
        long end = System.currentTimeMillis();
        System.out.println(end - begin);
    }  

My question is: why does the Scala version run so slowly?

Answer 1:

As @etherous mentioned: you use mutable state in the Java version, whereas your Scala version is completely immutable and also written less efficiently. The two are simply different programs.

You can try to avoid maxBy and keep track of the best result so far in a single pass. This version should be closer to your Java version:

val content = Source.fromFile("file/text").filter(_.isDigit).map(_.toLong - '0').toList

val result = (0 to content.size - 13).foldLeft((List.empty[Long], -1L)) {
  case (current @ (_, curMax), next) =>
    val temp = content.drop(next).take(13)
    val tempVal = temp.reduce(_ * _)
    if (tempVal > curMax) (temp, tempVal) else current
}

result is a tuple here, containing the list of the thirteen digits as _1 and its product as _2, since it seems you wanted both.
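
Part of the cost in these versions comes from calling `drop(...).take(13)` on a `List`: `drop(i)` has to walk past `i` cells before it reaches the window, and that adds up across all windows. A minimal sketch of the same fold over an indexed `Array[Long]`, where each window is read by index; the names `IndexedWindows` and `maxWindow` and the short sample string are illustrative assumptions, not part of the original code:

```scala
object IndexedWindows {
  // Hypothetical helper: returns (start index, product) of the window with
  // the largest product, or (-1, -1) if the input is shorter than `width`.
  def maxWindow(digits: Array[Long], width: Int): (Int, Long) =
    (0 to digits.length - width).foldLeft((-1, -1L)) {
      case (best @ (_, bestProd), start) =>
        var prod = 1L
        var j = start
        while (j < start + width) { prod *= digits(j); j += 1 }
        if (prod > bestProd) (start, prod) else best
    }

  def main(args: Array[String]): Unit = {
    // Short sample input (an assumption) so the sketch runs without a file.
    val digits = "7316717653133062491".filter(_.isDigit).map(c => (c - '0').toLong).toArray
    val (start, prod) = maxWindow(digits, 4)
    println(digits.slice(start, start + 4).mkString(" * ") + " = " + prod)
  }
}
```

Returning the start index instead of the window itself avoids allocating a new list per window; the winning digits are sliced out once at the end.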

Bonus

Now that I think about it, there is a method called sliding that deals with exactly this problem. I guess it runs about as slowly as your Scala code, but at least it is short :).

content.sliding(13).maxBy(_.reduce(_*_))
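
For reference, `sliding(n)` yields every run of `n` adjacent elements as its own collection, which is the same set of windows the `drop`/`take` loop builds by hand. A small sketch on a made-up five-element list; the values and the names `SlidingDemo` and `bestWindow` are assumptions for illustration. It uses `Long` elements, as in the snippet above, so that a product of thirteen digits cannot overflow:

```scala
object SlidingDemo {
  // Hypothetical helper: the window of `width` adjacent elements
  // whose product is largest.
  def bestWindow(digits: List[Long], width: Int): List[Long] =
    digits.sliding(width).maxBy(_.product)

  def main(args: Array[String]): Unit = {
    val digits = List(3L, 9L, 0L, 7L, 8L)
    // Windows of width 2: List(3, 9), List(9, 0), List(0, 7), List(7, 8)
    digits.sliding(2).foreach(println)
    val best = bestWindow(digits, 2)
    println(s"$best -> ${best.product}")
  }
}
```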


Answer 2:

The Scala version runs slowly because you're passing around a lot of functions and creating many intermediate objects. The Haskell version is fast because the language is built around these idioms, whereas Scala is layered on top of the JVM. You can get equivalent performance in Scala if you write it as you would Java (20 ms for me, the same as your Java version):

  import scala.collection.mutable.ArrayBuffer

  val begin = System.currentTimeMillis()

  val buf = new ArrayBuffer[Int]()

  val test = "73167176531330624919225119674426574742355349194934 96983520312774506326239578318016984801869478851843 85861560789112949495459501737958331952853208805511 12540698747158523863050715693290963295227443043557 66896648950445244523161731856403098711121722383113 62229893423380308135336276614282806444486645238749 30358907296290491560440772390713810515859307960866 70172427121883998797908792274921901699720888093776 65727333001053367881220235421809751254540594752243 52584907711670556013604839586446706324415722155397 53697817977846174064955149290862569321978468622482 83972241375657056057490261407972968652414535100474 82166370484403199890008895243450658541227588666881 16427171479924442928230863465674813919123162824586 17866458359124566529476545682848912883142607690042 24219022671055626321111109370544217506941658960408 07198403850962455444362981230987879927244284909188 84580156166097919133875499200524063689912560717606 05886116467109405077541002256983155200055935729725 71636269561882670428252483600823257530420752963450"

  val d = test.toCharArray

  var i = 0
  while (i != d.length) {
    val c = d(i) - '0'
    if (c >= 0 && c <= 9) {
      buf += c
    }
    i += 1
  }


  i = 0
  var max = 0
  var maxStart = 0
  while (i <= buf.length - 13) { // <= so the final window is also checked
   val test = buf(i) * buf(i + 1) * buf(i + 2) * buf(i + 3) * buf(i +4) * buf(i + 5) * buf(i + 6) * buf(i + 7) * buf(i + 8)* buf(i + 9)* buf(i + 10)* buf(i + 11)* buf(i + 12)
   if (test > max){
      max = test
      maxStart = i
    }
    i += 1
  }


  System.out.println(buf.slice(maxStart, maxStart + 13))
  val end = System.currentTimeMillis()
  println(end - begin)

Which prints out:

ArrayBuffer(9, 7, 8, 1, 7, 9, 7, 7, 8, 4, 6, 1, 7)
20
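
One caveat about this output: the product of thirteen digits can reach 9^13 = 2,541,865,828,329, which is larger than `Int.MaxValue` (2,147,483,647), so the `Int` multiplication above can wrap around silently. The window printed here has a product of 2,091,059,712, while the window reported in the third answer, which is computed with `Long`, has a product of 23,514,624,000. A minimal sketch of the difference using that window's digits; the object and method names are illustrative assumptions:

```scala
object OverflowDemo {
  // Digits of the window reported by the Long-based answer.
  val window: List[Int] = List(5, 5, 7, 6, 6, 8, 9, 6, 6, 4, 8, 9, 5)

  // 32-bit product: wraps around once the value exceeds Int.MaxValue.
  def intProduct(xs: List[Int]): Int = xs.product

  // 64-bit product: large enough for thirteen digits (at most 9^13).
  def longProduct(xs: List[Int]): Long = xs.map(_.toLong).product

  def main(args: Array[String]): Unit = {
    println(longProduct(window))                    // true product
    println(longProduct(window) > Int.MaxValue)     // too large for an Int
    println(intProduct(window) == longProduct(window)) // wrapped value differs
  }
}
```

Declaring `max` and the running product as `Long` in the loop above would be enough to avoid the wrap-around.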

UPDATE

It looks like the JVM can optimize this quite a bit: the time goes down to 1 ms for me after about 50 iterations:

import scala.io.Source

object Main extends App {

  (1 until 50).foreach(i => {
    val begin = System.currentTimeMillis()

    val max = Source.fromFile("file/text")
      .toList
      .filter(_.isDigit)
      .map(_.toInt - '0')
      .sliding(13)
      .maxBy(_.reduce(_ * _))

    println(max)
    val end = System.currentTimeMillis()
    println(end - begin)
  })

}

Prints:

List(9, 7, 8, 1, 7, 9, 7, 7, 8, 4, 6, 1, 7)
144
...
List(9, 7, 8, 1, 7, 9, 7, 7, 8, 4, 6, 1, 7)
1
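
The drop from 144 ms to 1 ms is consistent with JVM warm-up: the first pass pays for class loading and runs interpreted code, while later passes run JIT-compiled code. A rough sketch of a timing helper that separates warm-up runs from measured runs; `WarmupTiming`, `timeMillis`, the run counts and the generated in-memory digits are assumptions for illustration, and a benchmark harness such as JMH would be the more rigorous choice:

```scala
object WarmupTiming {
  // Hypothetical helper: evaluates `body` `warmups` times without measuring,
  // then returns the average wall-clock time in ms over `runs` evaluations.
  def timeMillis[A](warmups: Int, runs: Int)(body: => A): Double = {
    var i = 0
    while (i < warmups) { body; i += 1 }
    val start = System.nanoTime()
    i = 0
    while (i < runs) { body; i += 1 }
    (System.nanoTime() - start) / 1e6 / runs
  }

  def main(args: Array[String]): Unit = {
    // Generated stand-in for the file contents so the sketch is self-contained.
    val digits = List.tabulate(1000)(i => ((i * 7 + 3) % 10).toLong)
    val cold = timeMillis(warmups = 0, runs = 1)(digits.sliding(13).maxBy(_.product))
    val warm = timeMillis(warmups = 50, runs = 10)(digits.sliding(13).maxBy(_.product))
    println(f"cold: $cold%.3f ms, warm: $warm%.3f ms")
  }
}
```

The absolute numbers depend on the machine and JVM version, so only the gap between the cold and warm figures is meaningful.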


Answer 3:

One more variant, just for the fun of it. It does not convert the String to a List of Longs in one go. It takes 18 ms in the console.

def findMax(startTime: Long, maxStretch: List[Long], currentStretch: List[Long], longLine: String, currentProd: Long): (List[Long], Long, Long) =
  longLine match {
    // Input exhausted: return the best window, its product, and the elapsed time in ms.
    case "" => (maxStretch, maxStretch.product, System.currentTimeMillis - startTime)
    case _ =>
      val nextElement = longLine(0)
      // Slide the window: drop the oldest digit and append the next one.
      val nextStretch = currentStretch.tail ++ List(nextElement.toLong - '0'.toLong)
      val nextProd = nextStretch.product
      val (nextMaxStretch, nextMaxProd) =
        if (nextProd > currentProd) (nextStretch, nextProd) else (maxStretch, currentProd)
      findMax(startTime, nextMaxStretch, nextStretch, longLine.substring(1), nextMaxProd)
  }

val strContent = <long string>.replaceAll(" ","")
val start = strContent.take(13).map(_.toLong-'0'.toLong).toList

scala> findMax(System.currentTimeMillis, start, start, strContent.drop(13), start.product)
res47: (List[Long], Long, Long) = (List(5, 5, 7, 6, 6, 8, 9, 6, 6, 4, 8, 9, 5),23514624000,18)