Is this implementation of takeWhileInclusive safe?

Posted 2020-05-08 07:08

Question:

I found the following implementation of an inclusive takeWhile (found here)

fun <T> Sequence<T>.takeWhileInclusive(pred: (T) -> Boolean): Sequence<T> {
    var shouldContinue = true
    return takeWhile {
        val result = shouldContinue
        shouldContinue = pred(it)
        result
    }
}

The problem is I'm not 100% convinced this is safe if used on a parallel sequence.

My concern is that we'd be relying on the shouldContinue variable to know when to stop, but we're not synchronizing its access.

Any insights?

Answer 1:

Here's what I've figured out so far.

Question clarification

The question is unclear. There's no such thing as a parallel sequence; I probably got them mixed up with Java's parallel streams. What I meant was a sequence that is consumed concurrently.

Sequences are synchronous

As @LouisWasserman pointed out in the comments, sequences are not designed for parallel execution. In particular, the SequenceBuilder is annotated with @RestrictsSuspension. Quoting from the Kotlin coroutines repo:

It means that no SequenceBuilder extension of lambda in its scope can invoke suspendContinuation or other general suspending function

Having said that, as @MarkoTopolnik commented, they can still be used in a parallel program just like any other object.
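To make the shared-state concern from the question concrete, here is a minimal sketch using only the standard library. It reproduces the implementation from the question unchanged and iterates the resulting sequence twice on a single thread. The variable name upToFirstFailure is just for illustration. The point is that shouldContinue is captured once per call to takeWhileInclusive, not once per iteration, so it behaves like any other unsynchronized shared field.

```kotlin
// The implementation from the question, reproduced unchanged.
fun <T> Sequence<T>.takeWhileInclusive(pred: (T) -> Boolean): Sequence<T> {
    var shouldContinue = true
    return takeWhile {
        val result = shouldContinue
        shouldContinue = pred(it)
        result
    }
}

fun main() {
    val upToFirstFailure = sequenceOf(1, 2, 3, 4, 5).takeWhileInclusive { it < 3 }

    // First pass: 1 and 2 satisfy the predicate; 3 is the first failure and is still included.
    println(upToFirstFailure.toList()) // [1, 2, 3]

    // Second pass over the same Sequence object: the first pass left the flag at false,
    // so takeWhile rejects the very first element.
    println(upToFirstFailure.toList()) // []
}
```

So even without any threads involved, the returned sequence is only reliable for a single iteration. Two threads iterating it at the same time would read and write the same flag.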

Sequences used in parallel

As an example, here's a first attempt at using sequences in parallel:

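// Written against the pre-1.0 kotlinx.coroutines API (assumed from the
// ForkJoinPool thread names in the output), where launch needs no explicit scope.
import kotlinx.coroutines.experimental.*

// Each processor pulls exactly one element from the shared, unsynchronized iterator.
fun launchProcessor(id: Int, iterator: Iterator<Int>) = launch {
    println("[${Thread.currentThread().name}] Processor #$id received ${iterator.next()}")
}

fun main(args: Array<String>) {
    val s = sequenceOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
    runBlocking {
        val iterator = s.iterator()
        // All ten coroutines race on the same iterator.
        repeat(10) { launchProcessor(it, iterator) }
    }
}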

This code prints:

[ForkJoinPool.commonPool-worker-2] Processor #1 received 1

[ForkJoinPool.commonPool-worker-1] Processor #0 received 0

[ForkJoinPool.commonPool-worker-3] Processor #2 received 2

[ForkJoinPool.commonPool-worker-2] Processor #3 received 3

[ForkJoinPool.commonPool-worker-1] Processor #4 received 3

[ForkJoinPool.commonPool-worker-3] Processor #5 received 3

[ForkJoinPool.commonPool-worker-1] Processor #7 received 5

[ForkJoinPool.commonPool-worker-2] Processor #6 received 4

[ForkJoinPool.commonPool-worker-1] Processor #9 received 7

[ForkJoinPool.commonPool-worker-3] Processor #8 received 6

This is of course not what we want: because the shared iterator is advanced without any synchronization, 3 is received three times, while 8 and 9 are never received.
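One way to avoid the duplicated elements while staying with plain sequences is to guard the shared iterator with a lock. The following is a minimal sketch under stated assumptions, not a recommendation: LockedSource and consumeInParallel are hypothetical names, it uses plain threads instead of coroutines so that it depends only on the standard library, and it assumes non-null elements so that null can signal exhaustion.

```kotlin
import java.util.concurrent.ConcurrentLinkedQueue
import kotlin.concurrent.thread

// Hypothetical helper (not part of the standard library): hands out elements under a lock.
class LockedSource<T : Any>(private val iterator: Iterator<T>) {
    // hasNext() and next() run as one atomic step, so no element is handed out twice.
    fun nextOrNull(): T? = synchronized(this) {
        if (iterator.hasNext()) iterator.next() else null
    }
}

// Drains the sequence with several worker threads and returns what they received, sorted.
fun consumeInParallel(values: Sequence<Int>, workers: Int): List<Int> {
    val source = LockedSource(values.iterator())
    val received = ConcurrentLinkedQueue<Int>()
    val threads = List(workers) {
        thread {
            while (true) {
                val value = source.nextOrNull() ?: break
                received.add(value)
            }
        }
    }
    threads.forEach { it.join() }
    return received.sorted()
}

fun main() {
    println(consumeInParallel(sequenceOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), workers = 4))
    // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
}
```

Because a sequence evaluates its lambdas inside hasNext() and next(), any state those lambdas touch (such as the shouldContinue flag) is also only accessed while the lock is held. The order in which workers receive elements is still nondeterministic, which is why the sketch sorts the result before printing.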

Enter channels

On the other hand, if we were to use channels, we could write something like this:

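// Pre-1.0 kotlinx.coroutines API (assumed), as in the previous snippet.
import kotlinx.coroutines.experimental.*
import kotlinx.coroutines.experimental.channels.*

fun produceNumbers() = produce {
    var x = 1 // start from 1
    while (true) {
        send(x++) // produce next
        delay(100) // wait 0.1s
    }
}

// Each processor receives from the shared channel; every element goes to exactly one receiver.
fun launchProcessor(id: Int, channel: ReceiveChannel<Int>) = launch {
    channel.consumeEach {
        println("[${Thread.currentThread().name}] Processor #$id received $it")
    }
}

fun main(args: Array<String>) = runBlocking<Unit> {
    val producer = produceNumbers()
    repeat(5) { launchProcessor(it, producer) }
    delay(1000)
    producer.cancel() // cancel producer coroutine and thus kill them all
}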

Then the output is:

[ForkJoinPool.commonPool-worker-2] Processor #0 received 1

[ForkJoinPool.commonPool-worker-2] Processor #0 received 2

[ForkJoinPool.commonPool-worker-1] Processor #1 received 3

[ForkJoinPool.commonPool-worker-2] Processor #2 received 4

[ForkJoinPool.commonPool-worker-1] Processor #3 received 5

[ForkJoinPool.commonPool-worker-2] Processor #4 received 6

[ForkJoinPool.commonPool-worker-2] Processor #0 received 7

[ForkJoinPool.commonPool-worker-1] Processor #1 received 8

[ForkJoinPool.commonPool-worker-1] Processor #2 received 9

[ForkJoinPool.commonPool-worker-2] Processor #3 received 10

Furthermore, we could implement the takeWhileInclusive method for channels like this:

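// Pre-1.0 kotlinx.coroutines API (assumed).
import kotlin.coroutines.experimental.CoroutineContext
import kotlinx.coroutines.experimental.Unconfined
import kotlinx.coroutines.experimental.channels.*

fun <E> ReceiveChannel<E>.takeWhileInclusive(
        context: CoroutineContext = Unconfined,
        predicate: suspend (E) -> Boolean
): ReceiveChannel<E> = produce(context) {
    // Local to this producer coroutine, so it is never shared between consumers.
    var shouldContinue = true
    consumeEach {
        val currentShouldContinue = shouldContinue
        shouldContinue = predicate(it)
        // Stops one element after the first failure; that extra element is received but not sent.
        if (!currentShouldContinue) return@produce
        send(it)
    }
}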

And it works as expected.
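Coming back to the original Sequence version, the same idea of keeping state local to one run can be sketched with the standard sequence { } builder (Kotlin 1.3+). The name takeWhileInclusiveFresh is hypothetical, chosen only to avoid clashing with the version from the question. This is an alternative under those assumptions, not the implementation the question found.

```kotlin
// Hypothetical name; every call to iterator() runs the builder block from the start,
// so no flag is shared between iterations.
fun <T> Sequence<T>.takeWhileInclusiveFresh(pred: (T) -> Boolean): Sequence<T> = sequence {
    for (item in this@takeWhileInclusiveFresh) {
        yield(item)            // emit first, so the first failing element is included
        if (!pred(item)) break // stop right after it, without pulling another element
    }
}

fun main() {
    val upToFirstFailure = sequenceOf(1, 2, 3, 4, 5).takeWhileInclusiveFresh { it < 3 }
    println(upToFirstFailure.toList()) // [1, 2, 3]
    println(upToFirstFailure.toList()) // [1, 2, 3]
}
```

This makes the resulting sequence safe to iterate more than once. A single iterator obtained from it is still not meant to be shared between threads without external synchronization.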