|
|
@@ -12,10 +12,7 @@ package net.mamoe.mirai.utils.channels
|
|
|
import kotlinx.atomicfu.AtomicRef
|
|
|
import kotlinx.atomicfu.atomic
|
|
|
import kotlinx.atomicfu.loop
|
|
|
-import kotlinx.coroutines.CompletableDeferred
|
|
|
-import kotlinx.coroutines.cancel
|
|
|
-import kotlinx.coroutines.job
|
|
|
-import kotlinx.coroutines.launch
|
|
|
+import kotlinx.coroutines.*
|
|
|
import net.mamoe.mirai.utils.UtilsLogger
|
|
|
import net.mamoe.mirai.utils.childScope
|
|
|
import net.mamoe.mirai.utils.debug
|
|
|
@@ -30,14 +27,17 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
) : OnDemandReceiveChannel<T, V> {
|
|
|
private val coroutineScope = parentCoroutineContext.childScope("CoroutineOnDemandReceiveChannel")
|
|
|
|
|
|
- private val state: AtomicRef<ProducerState<T, V>> = atomic(ProducerState.JustInitialized())
|
|
|
+ private val state: AtomicRef<ChannelState<T, V>> = atomic(ChannelState.JustInitialized())
|
|
|
|
|
|
|
|
|
inner class Producer(
|
|
|
private val initialTicket: T,
|
|
|
) : OnDemandSendChannel<T, V> {
|
|
|
init {
|
|
|
- coroutineScope.launch {
|
|
|
+ // `UNDISPATCHED` with `yield()`: start the coroutine immediately in current thread,
|
|
|
+ // attaching Job to the coroutineScope, then `yield` the thread back, to complete `launch`.
|
|
|
+ coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
|
|
|
+ yield()
|
|
|
try {
|
|
|
producerCoroutine(initialTicket)
|
|
|
} catch (_: CancellationException) {
|
|
|
@@ -51,21 +51,21 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
override suspend fun emit(value: V): T {
|
|
|
state.loop { state ->
|
|
|
when (state) {
|
|
|
- is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
|
|
|
- is ProducerState.Producing -> {
|
|
|
+ is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
|
|
|
+ is ChannelState.Producing -> {
|
|
|
val deferred = state.deferred
|
|
|
- val consumingState = ProducerState.Consuming(
|
|
|
+ val consumingState = ChannelState.Consuming(
|
|
|
state.producer,
|
|
|
state.deferred,
|
|
|
coroutineScope.coroutineContext
|
|
|
)
|
|
|
if (compareAndSetState(state, consumingState)) {
|
|
|
deferred.complete(value) // produce a value
|
|
|
- return consumingState.producerLatch.acquire() // wait for producer to consume the previous value.
|
|
|
+ return consumingState.producerLatch.await() // wait for producer to consume the previous value.
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- is ProducerState.ProducerReady -> {
|
|
|
+ is ChannelState.ProducerReady -> {
|
|
|
setStateProducing(state)
|
|
|
}
|
|
|
|
|
|
@@ -81,9 +81,9 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
override fun finish() {
|
|
|
state.loop { state ->
|
|
|
when (state) {
|
|
|
- is ProducerState.Finished -> throw state.createAlreadyFinishedException(null)
|
|
|
+ is ChannelState.Finished -> throw state.createAlreadyFinishedException(null)
|
|
|
else -> {
|
|
|
- if (compareAndSetState(state, ProducerState.Finished(state, null))) {
|
|
|
+ if (compareAndSetState(state, ChannelState.Finished(state, null))) {
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
@@ -92,20 +92,16 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private fun setStateProducing(state: ProducerState.ProducerReady<T, V>) {
|
|
|
- val deferred = CompletableDeferred<V>(coroutineScope.coroutineContext.job)
|
|
|
- if (!compareAndSetState(state, ProducerState.Producing(state.producer, deferred))) {
|
|
|
- deferred.cancel() // avoid leak
|
|
|
- }
|
|
|
- // loop again
|
|
|
+ private fun setStateProducing(state: ChannelState.ProducerReady<T, V>) {
|
|
|
+ compareAndSetState(state, ChannelState.Producing(state.producer, coroutineScope.coroutineContext.job))
|
|
|
}
|
|
|
|
|
|
private fun finishImpl(exception: Throwable?) {
|
|
|
state.loop { state ->
|
|
|
when (state) {
|
|
|
- is ProducerState.Finished -> {} // ignore
|
|
|
+ is ChannelState.Finished -> {} // ignore
|
|
|
else -> {
|
|
|
- if (compareAndSetState(state, ProducerState.Finished(state, exception))) {
|
|
|
+ if (compareAndSetState(state, ChannelState.Finished(state, exception))) {
|
|
|
val cancellationException = kotlinx.coroutines.CancellationException("Finished", exception)
|
|
|
coroutineScope.cancel(cancellationException)
|
|
|
return
|
|
|
@@ -115,24 +111,31 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private fun compareAndSetState(state: ProducerState<T, V>, newState: ProducerState<T, V>): Boolean {
|
|
|
+ private fun compareAndSetState(state: ChannelState<T, V>, newState: ChannelState<T, V>): Boolean {
|
|
|
return this.state.compareAndSet(state, newState).also {
|
|
|
logger.debug { "CAS: $state -> $newState: $it" }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
override suspend fun receiveOrNull(): V? {
|
|
|
- state.loop { state ->
|
|
|
- when (state) {
|
|
|
- is ProducerState.Consuming -> {
|
|
|
- // value is ready, switch state to Consumed
|
|
|
+ // don't use `.loop`:
|
|
|
+ // java.lang.VerifyError: Bad type on operand stack
|
|
|
+ // net/mamoe/mirai/utils/channels/CoroutineOnDemandReceiveChannel.receiveOrNull(Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @103: getfield
|
|
|
+
|
|
|
+ while (true) {
|
|
|
+ when (val state = state.value) {
|
|
|
+ is ChannelState.Consuming -> {
|
|
|
+ // value is ready, now we consume the value
|
|
|
+
|
|
|
+ if (compareAndSetState(state, ChannelState.Consumed(state.producer, state.producerLatch))) {
|
|
|
+ // value is consumed, no contention, safe to retrieve
|
|
|
|
|
|
- if (compareAndSetState(state, ProducerState.Consumed(state.producer, state.producerLatch))) {
|
|
|
return try {
|
|
|
// This actually won't suspend, since the value is already completed
|
|
|
- // Just to be error-tolerating
|
|
|
+ // Just to be error-tolerating and re-throwing exceptions.
|
|
|
state.value.await()
|
|
|
- } catch (e: Exception) {
|
|
|
+ } catch (e: Throwable) {
|
|
|
+ // Producer failed to produce the previous value with exception
|
|
|
throw ProducerFailureException(cause = e)
|
|
|
}
|
|
|
}
|
|
|
@@ -140,7 +143,7 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
|
|
|
// note: actually, this case should be the first case (for code consistency) in `when`,
|
|
|
// but atomicfu 1.8.10 fails on this.
|
|
|
- is ProducerState.Producing<T, V> -> {
|
|
|
+ is ChannelState.Producing<T, V> -> {
|
|
|
// still producing value
|
|
|
|
|
|
state.deferred.await() // just wait for value, but does not return it.
|
|
|
@@ -151,7 +154,7 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
// Here we will loop again, to atomically switch to Consumed state.
|
|
|
}
|
|
|
|
|
|
- is ProducerState.Finished -> {
|
|
|
+ is ChannelState.Finished -> {
|
|
|
state.exception?.let { err ->
|
|
|
throw ProducerFailureException(cause = err)
|
|
|
}
|
|
|
@@ -166,32 +169,33 @@ internal class CoroutineOnDemandReceiveChannel<T, V>(
|
|
|
override fun expectMore(ticket: T): Boolean {
|
|
|
state.loop { state ->
|
|
|
when (state) {
|
|
|
- is ProducerState.JustInitialized -> {
|
|
|
- val ready = ProducerState.ProducerReady { Producer(ticket) }
|
|
|
+ is ChannelState.JustInitialized -> {
|
|
|
+ // start producer atomically
|
|
|
+ val ready = ChannelState.ProducerReady { Producer(ticket) }
|
|
|
if (compareAndSetState(state, ready)) {
|
|
|
ready.startProducerIfNotYet()
|
|
|
}
|
|
|
// loop again
|
|
|
}
|
|
|
|
|
|
- is ProducerState.ProducerReady -> {
|
|
|
+ is ChannelState.ProducerReady -> {
|
|
|
setStateProducing(state)
|
|
|
}
|
|
|
|
|
|
- is ProducerState.Producing -> return true // ok
|
|
|
+ is ChannelState.Producing -> return true // ok
|
|
|
|
|
|
- is ProducerState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
|
|
|
+ is ChannelState.Consuming -> throw IllegalProducerStateException(state) // a value is already ready
|
|
|
|
|
|
- is ProducerState.Consumed -> {
|
|
|
- if (compareAndSetState(state, ProducerState.ProducerReady { state.producer })) {
|
|
|
+ is ChannelState.Consumed -> {
|
|
|
+ if (compareAndSetState(state, ChannelState.ProducerReady { state.producer })) {
|
|
|
// wake up producer async.
|
|
|
- state.producerLatch.resumeWith(Result.success(ticket))
|
|
|
+ state.producerLatch.complete(ticket)
|
|
|
// loop again to switch state atomically to Producing.
|
|
|
// Do not do switch state directly here — async producer may race with you!
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- is ProducerState.Finished -> return false
|
|
|
+ is ChannelState.Finished -> return false
|
|
|
}
|
|
|
}
|
|
|
}
|