Is Arrays.stream(array_name).sum() slower than iteration?

Posted 2020-05-20 01:55

I was solving a LeetCode problem (https://oj.leetcode.com/problems/gas-station/) using Java 8.

My solution got TLE (Time Limit Exceeded) when I used Arrays.stream(integer_array).sum() to compute the sums, while the same solution was accepted when I computed the sums with plain iteration over the arrays. The best possible time complexity for this problem is O(n), and my solution is O(n), so I was surprised to get TLE just from using the Java 8 streams API.

import java.util.Arrays;

public class GasStation {
    public int canCompleteCircuit(int[] gas, int[] cost) {
        int start = 0, i = 0, runningCost = 0, totalGas = 0, totalCost = 0; 
        totalGas = Arrays.stream(gas).sum();
        totalCost = Arrays.stream(cost).sum();

        // for (int item : gas) totalGas += item;
        // for (int item : cost) totalCost += item;

        if (totalGas < totalCost)
            return -1;

        while (start > i || (start == 0 && i < gas.length)) {
            runningCost += gas[i];
            if (runningCost >= cost[i]) {
                runningCost -= cost[i++];
            } else {
                runningCost -= gas[i];
                if (--start < 0)
                    start = gas.length - 1;
                runningCost += (gas[start] - cost[start]);
            }
        }
        return start;
    }

    public static void main(String[] args) {
        GasStation sol = new GasStation();
        int[] gas = new int[] { 10, 5, 7, 14, 9 };
        int[] cost = new int[] { 8, 5, 14, 3, 1 };
        System.out.println(sol.canCompleteCircuit(gas, cost));

        gas = new int[] { 10 };
        cost = new int[] { 8 };
        System.out.println(sol.canCompleteCircuit(gas, cost));
    }
}

The solution gets accepted when I comment out the following two lines (computing the sums with streams):

totalGas = Arrays.stream(gas).sum();
totalCost = Arrays.stream(cost).sum();

and uncomment the following two lines (computing the sums by iteration):

//for (int item : gas) totalGas += item;
//for (int item : cost) totalCost += item;

Why is the Java 8 streams API slower than plain iteration over primitives for large inputs?

4 answers
聊天终结者
#2 · 2020-05-20 02:30

The first step in dealing with problems like this is to bring the code into a controlled environment. That means running it in the JVM you control (and can invoke) and running tests inside a good benchmark harness like JMH. Analyze, don't speculate.

Here's a benchmark I whipped up using JMH to do some analysis on this:

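import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.*;

@BenchmarkMode(Mode.AverageTime)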
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
public class ArraySum {
    static final long SEED = -897234L;

    @Param({"1000000"})
    int sz;

    int[] array;

    @Setup
    public void setup() {
        Random random = new Random(SEED);
        array = new int[sz];
        Arrays.setAll(array, i -> random.nextInt());
    }

    @Benchmark
    public int sumForLoop() {
        int sum = 0;
        for (int a : array)
            sum += a;
        return sum;
    }

    @Benchmark
    public int sumStream() {
        return Arrays.stream(array).sum();
    }
}

Basically this creates an array of a million ints and sums them twice: once using a for-loop and once using streams. Running the benchmark produces a bunch of output (elided for brevity and for dramatic effect) but the summary results are below:

Benchmark                 (sz)  Mode  Samples     Score  Score error  Units
ArraySum.sumForLoop    1000000  avgt        3   514.473      398.512  us/op
ArraySum.sumStream     1000000  avgt        3  7355.971     3170.697  us/op

Whoa! That Java 8 streams stuff is teh SUXX0R! It's 14 times slower than a for-loop, don't use it!!!1!

Well, no. First let's go over these results, and then look more closely to see if we can figure out what's going on.

The summary shows the two benchmark methods, with the "sz" parameter of a million. It's possible to vary this parameter, but it doesn't turn out to make a difference in this case. I also ran the benchmark methods only 3 times, as you can see from the "Samples" column. (There were also only 3 warmup iterations, not visible here.) The score is in microseconds per operation, and clearly the stream code is much, much slower than the for-loop code. But note also the score error: that's the amount of variability across the different runs. JMH helpfully prints out the standard deviation of the results (not shown here), but you can easily see that the score error is a significant fraction of the reported score. This reduces our confidence in the score.

Running more iterations should help. More warmup iterations will let the JIT do more work and settle down before running the benchmarks, and running more benchmark iterations will smooth out any errors from transient activity elsewhere on my system. So let's try 10 warmup iterations and 10 benchmark iterations:

Benchmark                 (sz)  Mode  Samples     Score  Score error  Units
ArraySum.sumForLoop    1000000  avgt       10   504.803       34.010  us/op
ArraySum.sumStream     1000000  avgt       10  7128.942      178.688  us/op

Performance is overall a little faster, and the measurement error is also quite a bit smaller, so running more iterations has had the desired effect. But the streams code is still considerably slower than the for-loop code. What's going on?

A large clue can be obtained by looking at the individual timings of the streams method:

# Warmup Iteration   1: 570.490 us/op
# Warmup Iteration   2: 491.765 us/op
# Warmup Iteration   3: 756.951 us/op
# Warmup Iteration   4: 7033.500 us/op
# Warmup Iteration   5: 7350.080 us/op
# Warmup Iteration   6: 7425.829 us/op
# Warmup Iteration   7: 7029.441 us/op
# Warmup Iteration   8: 7208.584 us/op
# Warmup Iteration   9: 7104.160 us/op
# Warmup Iteration  10: 7372.298 us/op

What happened? The first few iterations were reasonably fast, but then the 4th and subsequent iterations (and all the benchmark iterations that follow) were suddenly much slower.

I've seen this before. It was in this question and this answer elsewhere on SO. I recommend reading that answer; it explains how the JVM's inlining decisions in this case result in poorer performance.

A bit of background here: a for-loop compiles to a very simple increment-and-test loop, which is easily handled by the usual optimization techniques like loop peeling and unrolling. The streams code, while not very complex in this case, is actually quite complex compared to the for-loop code; there's a fair bit of setup, and each iteration requires at least one method call. Thus the JIT optimizations, particularly its inlining decisions, are critical to making the streams code go fast. And it's possible for them to go wrong.
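To make the "at least one method call per iteration" point concrete, here is a hypothetical sketch (the class and method names are made up, and this is not the JDK's actual stream implementation): a plain loop next to a hand-written reduction that routes every element through an IntBinaryOperator call. The JIT has to inline that call before the reduction can run as tightly as the plain loop.

```java
import java.util.Arrays;
import java.util.function.IntBinaryOperator;

class StreamVsLoopSketch {
    // Plain loop: the addition happens inline, with no calls in the loop body.
    static int loopSum(int[] a) {
        int sum = 0;
        for (int x : a) sum += x;
        return sum;
    }

    // Reduction-style loop: every element goes through an interface call
    // (IntBinaryOperator.applyAsInt) that must be inlined to be fast.
    static int reduceSum(int[] a, IntBinaryOperator op) {
        int acc = 0;
        for (int x : a) acc = op.applyAsInt(acc, x);
        return acc;
    }

    public static void main(String[] args) {
        int[] a = {3, 1, 4, 1, 5, 9, 2, 6};
        System.out.println(loopSum(a));                 // 31
        System.out.println(reduceSum(a, Integer::sum)); // 31
        System.out.println(Arrays.stream(a).sum());     // 31
    }
}
```

All three produce the same result; the difference is only in how much work the JIT must do to turn the second and third forms into the first.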

Another background point is that integer summation is about the simplest possible operation you can think of to do in a loop or stream. This will tend to make the fixed overhead of stream setup look relatively more expensive. It is also so simple that it can trigger pathologies in the inlining policy.

The suggestion from the other answer was to add the JVM option -XX:MaxInlineLevel=12 to increase the amount of code that can be inlined. Rerunning the benchmark with that option gives:

Benchmark                 (sz)  Mode  Samples    Score  Score error  Units
ArraySum.sumForLoop    1000000  avgt       10  502.379       27.859  us/op
ArraySum.sumStream     1000000  avgt       10  498.572       24.195  us/op

Ah, much nicer. Disabling tiered compilation using -XX:-TieredCompilation also avoided the pathological behavior. I also found that making the loop computation even a bit more expensive, e.g. summing the squares of the integers (that is, adding a single multiply), avoids the pathological behavior as well.
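The sum-of-squares variant isn't shown in the answer; presumably it looked something like this sketch (hypothetical class and method names), where the stream version adds a map step and the loop adds one multiply. Whether it actually sidesteps the slowdown on a given JVM is something to confirm with JMH; the code itself only shows the shape of the computation.

```java
import java.util.Arrays;

class SumOfSquaresSketch {
    // Loop version: one multiply and one add per element.
    static int squaresLoop(int[] a) {
        int sum = 0;
        for (int x : a) sum += x * x;
        return sum;
    }

    // Stream version: map each element to its square, then sum.
    static int squaresStream(int[] a) {
        return Arrays.stream(a).map(x -> x * x).sum();
    }

    public static void main(String[] args) {
        int[] a = {1, 2, 3, 4};
        System.out.println(squaresLoop(a));   // 30
        System.out.println(squaresStream(a)); // 30
    }
}
```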

Now, your question is about running in the LeetCode environment, which seems to run the code in a JVM you have no control over, so you can't change the inlining or compilation options. And you probably don't want to make your computation more complex just to avoid the pathology either. So in this case, you might as well stick to the good old for-loop. But don't be afraid to use streams, even for primitive arrays. They can perform quite well, aside from some narrow edge cases.

等我变得足够好
#3 · 2020-05-20 02:40

The normal iteration approach is going to be pretty much as fast as anything can be, but streams have a variety of overheads: even though the data comes directly from an array, there's probably going to be a primitive Spliterator involved and lots of other objects being created.
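As an illustration of the spliterator mentioned above, this sketch (hypothetical class name) sums an array through the Spliterator.OfInt returned by Arrays.spliterator(int[]), including a trySplit() call of the kind a parallel stream uses to divide work. It only shows the extra machinery sitting between the array and the sum; it is not a literal copy of how IntStream.sum() is implemented.

```java
import java.util.Arrays;
import java.util.Spliterator;

class SpliteratorSketch {
    // Pull every element through a primitive spliterator.
    static int sumViaSpliterator(int[] a) {
        Spliterator.OfInt it = Arrays.spliterator(a);
        int[] total = {0}; // one-element array so the lambda can update it
        it.forEachRemaining((int x) -> total[0] += x);
        return total[0];
    }

    // Split once, then sum both halves; trySplit() returns null when the
    // range is too small to split.
    static int sumInTwoHalves(int[] a) {
        Spliterator.OfInt second = Arrays.spliterator(a);
        Spliterator.OfInt first = second.trySplit();
        int[] total = {0};
        if (first != null) first.forEachRemaining((int x) -> total[0] += x);
        second.forEachRemaining((int x) -> total[0] += x);
        return total[0];
    }

    public static void main(String[] args) {
        int[] a = {1, 2, 3, 4, 5};
        System.out.println(sumViaSpliterator(a)); // 15
        System.out.println(sumInTwoHalves(a));    // 15
    }
}
```

The explicit `(int x)` parameter type selects the IntConsumer overload of forEachRemaining, so no boxing to Integer takes place.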

In general, you should expect the "normal approach" to usually be faster than streams unless you're both using parallelization and your data is very large.
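A minimal sketch of the parallel variant (hypothetical class name): the same sum with .parallel() added, which splits the array across the common ForkJoinPool. Because int addition wraps around modulo 2^32 and that wrapping addition is still associative, both versions return the same int; whether the parallel one is actually faster depends on the array size and the machine.

```java
import java.util.Arrays;

class ParallelSumSketch {
    static int sequentialSum(int[] a) {
        return Arrays.stream(a).sum();
    }

    // Only likely to pay off for very large arrays on a multi-core machine.
    static int parallelSum(int[] a) {
        return Arrays.stream(a).parallel().sum();
    }

    public static void main(String[] args) {
        int[] data = new int[1_000_000];
        Arrays.setAll(data, i -> i % 1000);
        System.out.println(sequentialSum(data) == parallelSum(data)); // true
    }
}
```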

Ridiculous、
#4 · 2020-05-20 02:46

My benchmark (see code below) shows that the streaming approach is about 10-15% slower than the iterative one. Interestingly, parallel stream results vary greatly on my 4-core (i7) MacBook Pro: while I have seen them a few times being about 30% faster than iterative, the most common result is almost three times slower than sequential.

Here is the benchmark code:

import java.util.*;
import java.util.function.*;

public class StreamingBenchmark {

    private static void benchmark(String name, LongSupplier f) {
        long start = System.currentTimeMillis(), sum = 0;
        for (int count = 0; count < 1000; count++) sum += f.getAsLong();
        System.out.println(String.format(
            "%10s in  %d millis. Sum = %d",
            name, System.currentTimeMillis() - start, sum
        ));
    }

    public static void main(String[] argv) {
        int[] data = new int[1000000];
        Random randy = new Random();
        for(int i = 0; i < data.length; i++) data[i] = randy.nextInt();

        benchmark("iterative", () -> { int s = 0; for(int n: data) s+=n; return s; });
        benchmark("stream", () -> Arrays.stream(data).sum());
        benchmark("parallel", () -> Arrays.stream(data).parallel().sum());

    }
}

Here is the output from a few runs:

 iterative in  350 millis. Sum = 564821058000
 stream in  394 millis. Sum = 564821058000
 parallel in  883 millis. Sum = 564821058000

 iterative in  340 millis. Sum = -295411382000
 stream in  376 millis. Sum = -295411382000
 parallel in  1031 millis. Sum = -295411382000

 iterative in  365 millis. Sum = 1205763898000
 stream in  379 millis. Sum = 1205763898000
 parallel in  1053 millis. Sum = 1205763898000

etc.

This got me curious, and I also tried running equivalent logic in Scala:

object Scarr {
    def main(argv: Array[String]) = {
        val randy = new java.util.Random
        val data = (1 to 1000000).map { _ => randy.nextInt }.toArray
        val start = System.currentTimeMillis
        var sum = 0L
        for ( _ <- 1 to 1000 ) sum += data.sum
        println(sum + " in " + (System.currentTimeMillis - start) + " millis.")

    }
}

This took 14 seconds! About 40 times(!) longer than the Java stream version. Ouch!

够拽才男人
#5 · 2020-05-20 02:46

The sum() method is equivalent to reduce(0, Integer::sum). For a large array, making all those method calls adds more overhead than a basic by-hand for-loop does. The bytecode for the for (int i : numbers) iteration is only very slightly more complicated than that generated by the by-hand for-loop. The stream operation is possibly faster in parallel-friendly environments (though maybe not for primitive methods), but there's no reason to expect a gain unless we know the environment is parallel-friendly. It may well not be, since LeetCode is probably designed to favor low-level code over abstraction: it measures efficiency rather than legibility.
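The IntStream.sum() Javadoc states this equivalence. The sketch below (hypothetical class name) checks it and also shows that both forms return an int that silently wraps around on overflow, just like the int-accumulating loops in the test class that follows, which is why the sums printed further down are negative.

```java
import java.util.Arrays;

class SumAsReduce {
    static int viaSum(int[] a) {
        return Arrays.stream(a).sum();
    }

    // The reduction that IntStream.sum() is documented to be equivalent to.
    static int viaReduce(int[] a) {
        return Arrays.stream(a).reduce(0, Integer::sum);
    }

    public static void main(String[] args) {
        int[] small = {3, 1, 4};
        System.out.println(viaSum(small));    // 8
        System.out.println(viaReduce(small)); // 8

        // Both overflow the same way: MAX_VALUE + 1 wraps to MIN_VALUE.
        int[] big = {Integer.MAX_VALUE, 1, 5};
        System.out.println(viaSum(big));      // -2147483643
        System.out.println(viaReduce(big));   // -2147483643
    }
}
```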

The sum operation done in any of the three ways (Arrays.stream(int[]).sum(), for (int i : ints) { total += i; }, and for (int i = 0; i < ints.length; i++) { total += ints[i]; }) should be relatively similar in efficiency. I used the following test class, which sums a hundred million integers in the range [0, 4096) a hundred times with each method and records the average times. All of them returned in very similar timeframes. It even attempts to limit parallel processing by occupying all but one of the available cores with busy loops, but I still found no particular difference:

import java.util.*;
import java.util.function.ToLongFunction;
import java.util.stream.IntStream;

public class SumTester {
    private static final int ARRAY_SIZE = 100_000_000;
    private static final int ITERATION_LIMIT = 100;
    private static final int INT_VALUE_LIMIT = 4096;

    public static void main(String[] args) {
        Random random = new Random();
        int[] numbers = new int[ARRAY_SIZE];
        IntStream.range(0, ARRAY_SIZE).forEach(i->numbers[i] = random.nextInt(INT_VALUE_LIMIT));

        Map<String, ToLongFunction<int[]>> inputs = new HashMap<String, ToLongFunction<int[]>>();

        NanoTimer initializer = NanoTimer.start();
        System.out.println("initialized NanoTimer in " + initializer.microEnd() + " microseconds");

        inputs.put("sumByStream", SumTester::sumByStream);
        inputs.put("sumByIteration", SumTester::sumByIteration);
        inputs.put("sumByForLoop", SumTester::sumByForLoop);

        System.out.println("Parallelables: ");
        averageTimeFor(ITERATION_LIMIT, inputs, Arrays.copyOf(numbers, numbers.length));

        int cores = Runtime.getRuntime().availableProcessors();
        List<CancelableThreadEater> threadEaters = new ArrayList<CancelableThreadEater>();
        if (cores > 1){
            threadEaters = occupyThreads(cores - 1);
        }
        // Only one core should be left to our class
        System.out.println("\nSingleCore (" + threadEaters.size() + " of " + cores + " cores occupied)");
        averageTimeFor(ITERATION_LIMIT, inputs, Arrays.copyOf(numbers, numbers.length));
        for (CancelableThreadEater cte : threadEaters){
            cte.end();
        }
        System.out.println("Complete!");
    }

    public static long sumByStream(int[] numbers){
        return Arrays.stream(numbers).sum();
    }

    public static long sumByIteration(int[] numbers){
        int total = 0;
        for (int i : numbers){
            total += i;
        }
        return total;
    }

    public static long sumByForLoop(int[] numbers){
        int total = 0;
        for (int i = 0; i < numbers.length; i++){
            total += numbers[i];
        }
        return total;       
    }

    public static void averageTimeFor(int iterations, Map<String, ToLongFunction<int[]>> testMap, int[] numbers){
        Map<String, Long> durationMap = new HashMap<String, Long>();
        Map<String, Long> sumMap = new HashMap<String, Long>();
        for (String methodName : testMap.keySet()){
            durationMap.put(methodName, 0L);
            sumMap.put(methodName, 0L);
        }
        for (int i = 0; i < iterations; i++){
            for (String methodName : testMap.keySet()){
                int[] newNumbers = Arrays.copyOf(numbers, ARRAY_SIZE);
                ToLongFunction<int[]> function = testMap.get(methodName);
                NanoTimer nt = NanoTimer.start();
                long sum = function.applyAsLong(newNumbers);
                long duration = nt.microEnd(); // microseconds, despite the "milliseconds" label printed below
                sumMap.put(methodName, sum);
                durationMap.put(methodName, durationMap.get(methodName) + duration);
            }
        }
        for (String methodName : testMap.keySet()){
            long duration = durationMap.get(methodName) / iterations;
            long sum = sumMap.get(methodName);
            System.out.println(methodName + ": result '" + sum + "', elapsed time: " + duration + " milliseconds on average over " + iterations + " iterations");
        }
    }

    private static List<CancelableThreadEater> occupyThreads(int numThreads){
        List<CancelableThreadEater> result = new ArrayList<CancelableThreadEater>();
        for (int i = 0; i < numThreads; i++){
            CancelableThreadEater cte = new CancelableThreadEater();
            result.add(cte);
            new Thread(cte).start();
        }
        return result;
    }

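    private static class CancelableThreadEater implements Runnable {
        // volatile guarantees the spinning thread sees the write made by end()
        private volatile boolean stop = false;

        public void run() {
            while (!stop) {
                // busy-wait to keep this core occupied
            }
        }

        public void end() {
            stop = true;
        }
    }

    // Hypothetical minimal NanoTimer (the class was not included in the post),
    // inferred from its usage: start() begins timing and microEnd() returns the
    // elapsed time in microseconds.
    private static class NanoTimer {
        private final long startNanos = System.nanoTime();

        static NanoTimer start() {
            return new NanoTimer();
        }

        long microEnd() {
            return (System.nanoTime() - startNanos) / 1_000;
        }
    }
}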

which printed:

initialized NanoTimer in 22 microseconds
Parallelables: 
sumByIteration: result '-1413860413', elapsed time: 35844 milliseconds on average over 100 iterations
sumByStream: result '-1413860413', elapsed time: 35414 milliseconds on average over 100 iterations
sumByForLoop: result '-1413860413', elapsed time: 35218 milliseconds on average over 100 iterations

SingleCore (3 of 4 cores occupied)
sumByIteration: result '-1413860413', elapsed time: 37010 milliseconds on average over 100 iterations
sumByStream: result '-1413860413', elapsed time: 38375 milliseconds on average over 100 iterations
sumByForLoop: result '-1413860413', elapsed time: 37990 milliseconds on average over 100 iterations
Complete!

That said, there's no real reason to compute sum() at all in this case. Your solution iterates through each array, for a total of three iterations, and the last one may be longer than a normal single pass. It's possible to compute the answer correctly with one full simultaneous iteration over both arrays plus one short-circuiting iteration. There may be an even more efficient approach, but I couldn't find one. My solution ended up being one of the fastest Java solutions on the chart: it ran in 223 ms, which put it in the middle of the pack of Python solutions.

I'll add my solution to the problem if you care to see it, but I hope the actual question is answered here.
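For comparison, here is a sketch of the standard greedy single-pass technique for this problem. It is not the answerer's solution (which was not posted); it goes one step further than the approach described above by tracking the overall balance and the running tank in the same loop, so no separate sum() pass is needed. The class name is hypothetical.

```java
class GasStationOnePass {
    static int canCompleteCircuit(int[] gas, int[] cost) {
        long total = 0; // net gas over the whole circuit; negative means no start works
        long tank = 0;  // net gas accumulated since the current candidate start
        int start = 0;
        for (int i = 0; i < gas.length; i++) {
            int diff = gas[i] - cost[i];
            total += diff;
            tank += diff;
            if (tank < 0) {
                // Station i + 1 is unreachable from start, and from every
                // station between start and i, so try starting after i.
                start = i + 1;
                tank = 0;
            }
        }
        return total >= 0 ? start : -1;
    }

    public static void main(String[] args) {
        System.out.println(canCompleteCircuit(new int[]{1, 2, 3, 4, 5}, new int[]{3, 4, 5, 1, 2})); // 3
        System.out.println(canCompleteCircuit(new int[]{2, 3, 4}, new int[]{3, 4, 3}));             // -1
    }
}
```

If the total is non-negative, the last reset point is a valid start; if every segment ended with a negative tank, the total is negative and the method returns -1.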
