瀏覽代碼

[core] Revise OnDemandChannel, improve state abstraction:

[core] Start producer coroutine immediately on `expectMore` and yield

Improve docs for OnDemandChannel

Rename factory function `OnDemandReceiveChannel` to  `OnDemandChannel` to better cover its meaning

Create deferred lazily in Producing state

Rename ProducerState to ChannelState

fix atomicfu bug receiveOrNull

Add docs (WIP, to be rebased)

[core] OnDemandChannel: Catch Throwable in `receiveOrNull` to prevent possible failures
Him188 2 年之前
父節點
當前提交
87fbaa4fb2

+ 11 - 9
mirai-core-utils/src/commonMain/kotlin/channels/ProducerState.kt → mirai-core-utils/src/commonMain/kotlin/channels/ChannelState.kt

@@ -12,13 +12,13 @@ package net.mamoe.mirai.utils.channels
 import kotlinx.coroutines.CompletableDeferred
 import kotlinx.coroutines.Deferred
 import kotlinx.coroutines.ExperimentalCoroutinesApi
-import net.mamoe.mirai.utils.sync.Latch
+import kotlinx.coroutines.Job
 import kotlin.coroutines.CoroutineContext
 
 /**
  * Producer states.
  */
-internal sealed interface ProducerState<T, V> {
+internal sealed interface ChannelState<T, V> {
     /*
      * 可变更状态的函数: [emit], [receiveOrNull], [expectMore], [finish], [finishExceptionally]
      * 
@@ -93,11 +93,11 @@ internal sealed interface ProducerState<T, V> {
      */
     abstract override fun toString(): String
 
-    class JustInitialized<T, V> : ProducerState<T, V> {
+    class JustInitialized<T, V> : ChannelState<T, V> {
         override fun toString(): String = "JustInitialized"
     }
 
-    sealed interface HasProducer<T, V> : ProducerState<T, V> {
+    sealed interface HasProducer<T, V> : ChannelState<T, V> {
         val producer: OnDemandSendChannel<T, V>
     }
 
@@ -116,8 +116,10 @@ internal sealed interface ProducerState<T, V> {
 
     class Producing<T, V>(
         override val producer: OnDemandSendChannel<T, V>,
-        val deferred: CompletableDeferred<V>,
+        parentJob: Job,
     ) : HasProducer<T, V> {
+        val deferred: CompletableDeferred<V> by lazy { CompletableDeferred<V>(parentJob) }
+        
         override fun toString(): String = "Producing(deferred.completed=${deferred.isCompleted})"
     }
 
@@ -126,7 +128,7 @@ internal sealed interface ProducerState<T, V> {
         val value: Deferred<V>,
         parentCoroutineContext: CoroutineContext,
     ) : HasProducer<T, V> {
-        val producerLatch: Latch<T> = Latch(parentCoroutineContext)
+        val producerLatch: CompletableDeferred<T> = CompletableDeferred(parentCoroutineContext[Job])
 
         override fun toString(): String {
             @OptIn(ExperimentalCoroutinesApi::class)
@@ -138,15 +140,15 @@ internal sealed interface ProducerState<T, V> {
 
     class Consumed<T, V>(
         override val producer: OnDemandSendChannel<T, V>,
-        val producerLatch: Latch<T>
+        val producerLatch: CompletableDeferred<T>
     ) : HasProducer<T, V> {
         override fun toString(): String = "Consumed($producerLatch)"
     }
 
     class Finished<T, V>(
-        private val previousState: ProducerState<T, V>,
+        private val previousState: ChannelState<T, V>,
         val exception: Throwable?,
-    ) : ProducerState<T, V> {
+    ) : ChannelState<T, V> {
         val isSuccess: Boolean get() = exception == null
 
         fun createAlreadyFinishedException(cause: Throwable?): IllegalProducerStateException {

+ 2 - 2
mirai-core-utils/src/commonMain/kotlin/channels/IllegalProducerStateException.kt

@@ -10,9 +10,9 @@
 package net.mamoe.mirai.utils.channels
 
 public class IllegalProducerStateException internal constructor(
-    private val state: ProducerState<*, *>,
+    private val state: ChannelState<*, *>,
     message: String? = state.toString(),
     cause: Throwable? = null,
 ) : IllegalStateException(message, cause) {
-    public val lastStateWasSucceed: Boolean get() = (state is ProducerState.Finished) && state.isSuccess
+    public val lastStateWasSucceed: Boolean get() = (state is ChannelState.Finished) && state.isSuccess
 }

+ 44 - 40
mirai-core-utils/src/commonMain/kotlin/channels/OnDemandChannelImpl.kt

@@ -12,10 +12,7 @@ package net.mamoe.mirai.utils.channels
 import kotlinx.atomicfu.AtomicRef
 import kotlinx.atomicfu.atomic
 import kotlinx.atomicfu.loop
-import kotlinx.coroutines.CompletableDeferred
-import kotlinx.coroutines.cancel
-import kotlinx.coroutines.job
-import kotlinx.coroutines.launch
+import kotlinx.coroutines.*
 import net.mamoe.mirai.utils.UtilsLogger
 import net.mamoe.mirai.utils.childScope
 import net.mamoe.mirai.utils.debug
@@ -30,14 +27,17 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
 ) : OnDemandReceiveChannel<T, V> {
     private val coroutineScope = parentCoroutineContext.childScope("CoroutineOnDemandReceiveChannel")
 
-    private val state: AtomicRef<ProducerState<T, V>> = atomic(ProducerState.JustInitialized())
+    private val state: AtomicRef<ChannelState<T, V>> = atomic(ChannelState.JustInitialized())
 
 
     inner class Producer(
         private val initialTicket: T,
     ) : OnDemandSendChannel<T, V> {
         init {
-            coroutineScope.launch {
+            // `UNDISPATCHED` with `yield()`: start the coroutine immediately in current thread,
+            // attaching Job to the coroutineScope, then `yield` the thread back, to complete `launch`.
+            coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
+                yield()
                 try {
                     producerCoroutine(initialTicket)
                 } catch (_: CancellationException) {
@@ -51,21 +51,21 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
         override suspend fun emit(value: V): T {
             state.loop { state ->
                 when (state) {
-                    is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
-                    is ProducerState.Producing -> {
+                    is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
+                    is ChannelState.Producing -> {
                         val deferred = state.deferred
-                        val consumingState = ProducerState.Consuming(
+                        val consumingState = ChannelState.Consuming(
                             state.producer,
                             state.deferred,
                             coroutineScope.coroutineContext
                         )
                         if (compareAndSetState(state, consumingState)) {
                             deferred.complete(value) // produce a value
-                            return consumingState.producerLatch.acquire() // wait for producer to consume the previous value.
+                            return consumingState.producerLatch.await() // wait for producer to consume the previous value.
                         }
                     }
 
-                    is ProducerState.ProducerReady -> {
+                    is ChannelState.ProducerReady -> {
                         setStateProducing(state)
                     }
 
@@ -81,9 +81,9 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
         override fun finish() {
             state.loop { state ->
                 when (state) {
-                    is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
+                    is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
                     else -> {
-                        if (compareAndSetState(state, ProducerState.Finished(state, null))) {
+                        if (compareAndSetState(state, ChannelState.Finished(state, null))) {
                             return
                         }
                     }
@@ -92,20 +92,16 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
         }
     }
 
-    private fun setStateProducing(state: ProducerState.ProducerReady<T, V>) {
-        val deferred = CompletableDeferred<V>(coroutineScope.coroutineContext.job)
-        if (!compareAndSetState(state, ProducerState.Producing(state.producer, deferred))) {
-            deferred.cancel() // avoid leak
-        }
-        // loop again
+    private fun setStateProducing(state: ChannelState.ProducerReady<T, V>) {
+        compareAndSetState(state, ChannelState.Producing(state.producer, coroutineScope.coroutineContext.job))
     }
 
     private fun finishImpl(exception: Throwable?) {
         state.loop { state ->
             when (state) {
-                is ProducerState.Finished -> {} // ignore 
+                is ChannelState.Finished -> {} // ignore 
                 else -> {
-                    if (compareAndSetState(state, ProducerState.Finished(state, exception))) {
+                    if (compareAndSetState(state, ChannelState.Finished(state, exception))) {
                         val cancellationException = kotlinx.coroutines.CancellationException("Finished", exception)
                         coroutineScope.cancel(cancellationException)
                         return
@@ -115,24 +111,31 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
         }
     }
 
-    private fun compareAndSetState(state: ProducerState<T, V>, newState: ProducerState<T, V>): Boolean {
+    private fun compareAndSetState(state: ChannelState<T, V>, newState: ChannelState<T, V>): Boolean {
         return this.state.compareAndSet(state, newState).also {
             logger.debug { "CAS: $state -> $newState: $it" }
         }
     }
 
     override suspend fun receiveOrNull(): V? {
-        state.loop { state ->
-            when (state) {
-                is ProducerState.Consuming -> {
-                    // value is ready, switch state to Consumed
+        // don't use `.loop`:
+        // java.lang.VerifyError: Bad type on operand stack	
+        //     net/mamoe/mirai/utils/channels/CoroutineOnDemandReceiveChannel.receiveOrNull(Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @103: getfield	
+
+        while (true) {
+            when (val state = state.value) {
+                is ChannelState.Consuming -> {
+                    // value is ready, now we consume the value
+
+                    if (compareAndSetState(state, ChannelState.Consumed(state.producer, state.producerLatch))) {
+                        // value is consumed, no contention, safe to retrieve 
 
-                    if (compareAndSetState(state, ProducerState.Consumed(state.producer, state.producerLatch))) {
                         return try {
                             // This actually won't suspend, since the value is already completed
-                            // Just to be error-tolerating
+                            // Just to be error-tolerating and re-throwing exceptions.
                             state.value.await()
-                        } catch (e: Exception) {
+                        } catch (e: Throwable) {
+                            // Producer failed to produce the previous value with exception
                             throw ProducerFailureException(cause = e)
                         }
                     }
@@ -140,7 +143,7 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
 
                 // note: actually, this case should be the first case (for code consistency) in `when`, 
                 // but atomicfu 1.8.10 fails on this.
-                is ProducerState.Producing<T, V> -> {
+                is ChannelState.Producing<T, V> -> {
                     // still producing value
 
                     state.deferred.await() // just wait for value, but does not return it.
@@ -151,7 +154,7 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
                     // Here we will loop again, to atomically switch to Consumed state.
                 }
 
-                is ProducerState.Finished -> {
+                is ChannelState.Finished -> {
                     state.exception?.let { err ->
                         throw ProducerFailureException(cause = err)
                     }
@@ -166,32 +169,33 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
     override fun expectMore(ticket: T): Boolean {
         state.loop { state ->
             when (state) {
-                is ProducerState.JustInitialized -> {
-                    val ready = ProducerState.ProducerReady { Producer(ticket) }
+                is ChannelState.JustInitialized -> {
+                    // start producer atomically
+                    val ready = ChannelState.ProducerReady { Producer(ticket) }
                     if (compareAndSetState(state, ready)) {
                         ready.startProducerIfNotYet()
                     }
                     // loop again
                 }
 
-                is ProducerState.ProducerReady -> {
+                is ChannelState.ProducerReady -> {
                     setStateProducing(state)
                 }
 
-                is ProducerState.Producing -> return true // ok
+                is ChannelState.Producing -> return true // ok
 
-                is ProducerState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
+                is ChannelState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
 
-                is ProducerState.Consumed -> {
-                    if (compareAndSetState(state, ProducerState.ProducerReady { state.producer })) {
+                is ChannelState.Consumed -> {
+                    if (compareAndSetState(state, ChannelState.ProducerReady { state.producer })) {
                         // wake up producer async.
-                        state.producerLatch.resumeWith(Result.success(ticket))
+                        state.producerLatch.complete(ticket)
                         // loop again to switch state atomically to Producing. 
                         // Do not do switch state directly here — async producer may race with you! 
                     }
                 }
 
-                is ProducerState.Finished -> return false
+                is ChannelState.Finished -> return false
             }
         }
     }

+ 38 - 15
mirai-core-utils/src/commonMain/kotlin/channels/OnDemandSendChannel.kt

@@ -9,8 +9,10 @@
 
 package net.mamoe.mirai.utils.channels
 
-import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.Deferred
+import kotlinx.coroutines.Job
 import kotlinx.coroutines.channels.ReceiveChannel
+import kotlinx.coroutines.channels.SendChannel
 import net.mamoe.mirai.utils.UtilsLogger
 import kotlin.coroutines.Continuation
 import kotlin.coroutines.CoroutineContext
@@ -18,43 +20,60 @@ import kotlin.coroutines.EmptyCoroutineContext
 import kotlin.coroutines.cancellation.CancellationException
 
 /**
- * 按需供给的 [Channel].
+ * 按需供给的 [SendChannel].
  */
 public interface OnDemandSendChannel<T, V> {
     /**
-     * 挂起协程, 直到 [OnDemandReceiveChannel] 期望接收一个 [V], 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
+     * 挂起协程, 直到 [OnDemandReceiveChannel] [期望接收][OnDemandReceiveChannel.receiveOrNull]一个 [V], 届时将 [value] 传递给 [OnDemandReceiveChannel.receiveOrNull], 成为其返回值.
      *
-     * 若在调用 [emit] 时已经有 [OnDemandReceiveChannel] 正在等待, 则该 [OnDemandReceiveChannel] 协程会立即[恢复][Continuation.resumeWith].
+     * 若在调用 [emit] 时已经有 [OnDemandReceiveChannel.receiveOrNull] 正在等待, 则该协程会立即[恢复][Continuation.resumeWith], [emit] 不会挂起.
      *
      * 若 [OnDemandReceiveChannel] 已经[完结][OnDemandReceiveChannel.finish], [OnDemandSendChannel.emit] 会抛出 [IllegalProducerStateException].
+     *
+     * @see OnDemandReceiveChannel.receiveOrNull
      */
     public suspend fun emit(value: V): T
 
     /**
-     * 标记此 [OnDemandSendChannel] 在生产 [V] 的过程中出现错误.
+     * 标记此 [OnDemandSendChannel] 在生产 [V] 的过程中出现异常.
+     *
+     * 这也会终止此 [OnDemandSendChannel], 若 [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] 时, 它将得到一个 [ProducerFailureException].
      *
-     * 这也会终止此 [OnDemandSendChannel], 随后 [OnDemandReceiveChannel.receiveOrNull] 将会抛出 [ProducerFailureException].
+     * 在 [finishExceptionally] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] 或 [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
      */
     public fun finishExceptionally(exception: Throwable)
 
     /**
      * 标记此 [OnDemandSendChannel] 已经没有更多 [V] 可生产.
      *
-     * 随后 [OnDemandReceiveChannel.receiveOrNull] 将会抛出 [IllegalStateException].
+     * 这会终止此 [OnDemandSendChannel], 若 [OnDemandReceiveChannel] 正在期待一个值, 则当它调用 [OnDemandReceiveChannel.receiveOrNull] 时, 它将得到一个 [ProducerFailureException].
+     *
+     * 在 [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] 或 [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
      */
     public fun finish()
 }
 
+
 /**
- * 按需消费者.
+ * 线程安全的按需接收通道.
  *
  * 与 [ReceiveChannel] 不同, [OnDemandReceiveChannel] 只有在调用 [expectMore] 后才会让[生产者][OnDemandSendChannel] 开始生产下一个 [V].
  */
 public interface OnDemandReceiveChannel<T, V> {
     /**
-     * 挂起协程并等待从 [OnDemandSendChannel] [接收][OnDemandSendChannel.emit]一个 [V].
+     * 尝试从 [OnDemandSendChannel] [接收][OnDemandSendChannel.emit]一个 [V].
+     * 当且仅当在 [OnDemandSendChannel] 已经[正常结束][OnDemandSendChannel.finish] 时返回 `null`.
      *
-     * 当此函数被多个线程 (协程) 同时调用时, 只有一个线程挂起并获得 [V], 其他线程将会
+     * 若目前已有 [V], 此函数立即返回该 [V], 不会挂起.
+     * 否则, 此函数将会挂起直到 [OnDemandSendChannel.emit].
+     *
+     * 当此函数被多个协程 (线程) 同时调用时, 只有一个协程会获得 [V], 其他协程将会挂起.
+     *
+     * 若在等待过程中 [OnDemandSendChannel] [异常结束][OnDemandSendChannel.finishExceptionally],
+     * 本函数会立即恢复并抛出 [ProducerFailureException], 其 `cause` 为令 [OnDemandSendChannel] 的异常.
+     *
+     * 此挂起函数可被取消.
+     * 如果在此函数挂起时当前协程的 [Job] 被取消或完结, 此函数会立即恢复并抛出 [CancellationException]. 此行为与 [Deferred.await] 相同.
      *
      * @throws ProducerFailureException 当 [OnDemandSendChannel.finishExceptionally] 时抛出.
      * @throws CancellationException 当协程被取消时抛出
@@ -64,7 +83,9 @@ public interface OnDemandReceiveChannel<T, V> {
     public suspend fun receiveOrNull(): V?
 
     /**
-     * 期待 [OnDemandSendChannel] 再生产一个 [V]. 期望生产后必须在之后调用 [receiveOrNull] 或 [finish] 来消耗生产的 [V].
+     * 期待 [OnDemandSendChannel] 再生产一个 [V].
+     * 期望生产后必须在之后调用 [receiveOrNull] 或 [finish] 来消耗生产的 [V].
+     * 不可连续重复调用 [expectMore].
      *
      * 在成功发起期待后返回 `true`; 在 [OnDemandSendChannel] 已经[完结][OnDemandSendChannel.finish] 时返回 `false`.
      *
@@ -73,16 +94,18 @@ public interface OnDemandReceiveChannel<T, V> {
     public fun expectMore(ticket: T): Boolean
 
     /**
-     * 标记此 [OnDemandReceiveChannel] 已经完结.
+     * 标记此 [OnDemandSendChannel] 已经不再需要更多的值.
+     *
+     * 如果 [OnDemandSendChannel] 仍在运行 (无论是挂起中还是正在计算下一个值), 都会正常地[取消][Job.cancel] [OnDemandSendChannel].
      *
-     * 如果 [OnDemandSendChannel] 仍在运行, 将会 (正常地) 取消 [OnDemandSendChannel].
      *
-     * 随后 [OnDemandSendChannel.emit] 将会抛出 [IllegalStateException].
+     * 在 [finish] 之后若尝试调用 [OnDemandSendChannel.emit], [OnDemandReceiveChannel.receiveOrNull] 或 [OnDemandReceiveChannel.expectMore] 都会导致 [IllegalStateException].
      */
     public fun finish()
 }
 
-public fun <T, V> OnDemandReceiveChannel(
+@Suppress("FunctionName")
+public fun <T, V> OnDemandChannel(
     parentCoroutineContext: CoroutineContext = EmptyCoroutineContext,
     logger: UtilsLogger = UtilsLogger.noop(),
     producerCoroutine: suspend OnDemandSendChannel<T, V>.(initialTicket: T) -> Unit,

+ 76 - 0
mirai-core-utils/src/commonTest/kotlin/net/mamoe/mirai/utils/channels/OnDemandChannelTest.kt

@@ -0,0 +1,76 @@
+/*
+ * Copyright 2019-2023 Mamoe Technologies and contributors.
+ *
+ * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
+ * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
+ *
+ * https://github.com/mamoe/mirai/blob/dev/LICENSE
+ */
+
+package net.mamoe.mirai.utils.channels
+
+import kotlinx.coroutines.*
+import kotlin.test.*
+
+
+class OnDemandChannelTest {
+    ///////////////////////////////////////////////////////////////////////////
+    // CoroutineScope lifecycle
+    ///////////////////////////////////////////////////////////////////////////
+
+    @Test
+    fun attachScopeJob() {
+        val job = SupervisorJob()
+        val channel = OnDemandChannel<Int, Int>(job) {
+            fail()
+        }
+        assertEquals(1, job.children.toList().size)
+        channel.finish()
+    }
+
+    @Test
+    fun finishAfterInstantiation() {
+        val supervisor = SupervisorJob()
+        val channel = OnDemandChannel<Int, Int>(supervisor) {
+            fail("ran")
+        }
+        assertEquals(1, supervisor.children.toList().size)
+        val job = supervisor.children.single()
+        assertEquals(true, job.isActive)
+
+        channel.finish()
+
+        assertEquals(0, supervisor.children.toList().size)
+        assertEquals(false, job.isActive)
+    }
+
+    ///////////////////////////////////////////////////////////////////////////
+    // Producer Coroutine — Lazy Initialization
+    ///////////////////////////////////////////////////////////////////////////
+
+    @Test
+    fun `producer coroutine won't start until expectMore`() {
+        val channel = OnDemandChannel<Int, Int> {
+            fail()
+        }
+        channel.finish()
+    }
+
+    @Test
+    fun `producer coroutine starts iff expectMore`() = runBlocking(Dispatchers.Default.limitedParallelism(1)) {
+        var started = false
+        val channel = OnDemandChannel<Int, Int>(currentCoroutineContext()) {
+            // (1)
+            assertEquals(false, started)
+            started = true
+            yield() // goto (2)
+            fail()
+        }
+        assertFalse { started }
+        channel.expectMore(1) // launches the job, but it won't execute due to single parallelism
+        yield() // goto (1)
+        // (2)
+        assertTrue { started }
+        channel.finish()
+    }
+}

+ 2 - 1
mirai-core/src/commonMain/kotlin/network/auth/AuthControl.kt

@@ -16,6 +16,7 @@ import net.mamoe.mirai.internal.utils.asUtilsLogger
 import net.mamoe.mirai.internal.utils.subLogger
 import net.mamoe.mirai.utils.ExceptionCollector
 import net.mamoe.mirai.utils.MiraiLogger
+import net.mamoe.mirai.utils.channels.OnDemandChannel
 import net.mamoe.mirai.utils.channels.OnDemandReceiveChannel
 import net.mamoe.mirai.utils.channels.ProducerFailureException
 import net.mamoe.mirai.utils.debug
@@ -39,7 +40,7 @@ internal class AuthControl(
     internal val exceptionCollector = ExceptionCollector()
 
     private val userDecisions: OnDemandReceiveChannel<Throwable?, SsoProcessorImpl.AuthMethod> =
-        OnDemandReceiveChannel(
+        OnDemandChannel(
             parentCoroutineContext,
             logger.subLogger("AuthControl/UserDecisions").asUtilsLogger()
         ) { _ ->