Răsfoiți Sursa

Concurrent event processing

Him188 6 ani în urmă
părinte
comite
3859ce6daf

+ 20 - 9
mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/internal/InternalEventListeners.kt

@@ -10,6 +10,8 @@
 package net.mamoe.mirai.event.internal
 
 import kotlinx.coroutines.*
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
 import net.mamoe.mirai.event.Event
 import net.mamoe.mirai.event.EventDisabled
 import net.mamoe.mirai.event.Listener
@@ -73,6 +75,9 @@ internal class Handler<in E : Event>
             ListeningStatus.LISTENING
         }
     }
+
+    @MiraiInternalAPI
+    override val lock: Mutex = Mutex()
 }
 
 /**
@@ -138,23 +143,29 @@ internal object EventListenerManager {
 
 // inline: NO extra Continuation
 @Suppress("UNCHECKED_CAST")
-internal suspend inline fun Event.broadcastInternal() {
-    if (EventDisabled) return
+internal suspend inline fun Event.broadcastInternal() = coroutineScope {
+    if (EventDisabled) return@coroutineScope
 
     EventLogger.info { "Event broadcast: $this" }
 
-    val listeners = this::class.listeners()
-    callAndRemoveIfRequired(listeners)
+    val listeners = this@broadcastInternal::class.listeners()
+    callAndRemoveIfRequired(this@broadcastInternal, listeners)
     listeners.supertypes.forEach {
-        callAndRemoveIfRequired(it.listeners())
+        callAndRemoveIfRequired(this@broadcastInternal, it.listeners())
     }
 }
 
-private suspend inline fun <E : Event> E.callAndRemoveIfRequired(listeners: EventListeners<E>) {
+@UseExperimental(MiraiInternalAPI::class)
+private fun <E : Event> CoroutineScope.callAndRemoveIfRequired(event: E, listeners: EventListeners<E>) {
     // atomic foreach
-    listeners.forEach {
-        if (it.onEvent(this) == ListeningStatus.STOPPED) {
-            listeners.remove(it) // atomic remove
+    listeners.forEachNode { node ->
+        launch {
+            val listener = node.nodeValue
+            listener.lock.withLock {
+                if (!node.isRemoved() && listener.onEvent(event) == ListeningStatus.STOPPED) {
+                    listeners.remove(listener) // atomic remove
+                }
+            }
         }
     }
 }

+ 7 - 0
mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/subscriber.kt

@@ -13,6 +13,7 @@ import kotlinx.coroutines.CompletableJob
 import kotlinx.coroutines.CoroutineExceptionHandler
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.GlobalScope
+import kotlinx.coroutines.sync.Mutex
 import net.mamoe.mirai.Bot
 import net.mamoe.mirai.event.events.BotEvent
 import net.mamoe.mirai.event.internal.Handler
@@ -51,6 +52,12 @@ enum class ListeningStatus {
  * 取消监听: [complete]
  */
 interface Listener<in E : Event> : CompletableJob {
+    /**
+     * [onEvent] 的锁
+     */
+    @MiraiInternalAPI
+    val lock: Mutex
+
     suspend fun onEvent(event: E): ListeningStatus
 }
 

+ 44 - 32
mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt

@@ -132,7 +132,7 @@ open class LockFreeLinkedList<E> {
         addLastNode(element.asNode(tail))
     }
 
-    private fun addLastNode(node: Node<E>) {
+    private fun addLastNode(node: LockFreeLinkedListNode<E>) {
         while (true) {
             val tail = head.iterateBeforeFirst { it === tail } // find the last node.
             if (tail.nextNodeRef.compareAndSet(this.tail, node)) { // ensure the last node is the last node
@@ -146,9 +146,9 @@ open class LockFreeLinkedList<E> {
      */
     @Suppress("DuplicatedCode")
     open fun addAll(iterable: Iterable<E>) {
-        var firstNode: Node<E>? = null
+        var firstNode: LockFreeLinkedListNode<E>? = null
 
-        var currentNode: Node<E>? = null
+        var currentNode: LockFreeLinkedListNode<E>? = null
         iterable.forEach {
             val nextNode = it.asNode(tail)
             if (firstNode == null) {
@@ -166,9 +166,9 @@ open class LockFreeLinkedList<E> {
      */
     @Suppress("DuplicatedCode")
     open fun addAll(iterable: Sequence<E>) {
-        var firstNode: Node<E>? = null
+        var firstNode: LockFreeLinkedListNode<E>? = null
 
-        var currentNode: Node<E>? = null
+        var currentNode: LockFreeLinkedListNode<E>? = null
         iterable.forEach {
             val nextNode = it.asNode(tail)
             if (firstNode == null) {
@@ -190,7 +190,7 @@ open class LockFreeLinkedList<E> {
         val node = LazyNode(tail, supplier)
 
         while (true) {
-            var current: Node<E> = head
+            var current: LockFreeLinkedListNode<E> = head
 
             findLastNode@ while (true) {
                 if (current.isValidElementNode() && filter(current.nodeValue))
@@ -208,13 +208,14 @@ open class LockFreeLinkedList<E> {
     }
 
     @PublishedApi // limitation by atomicfu
-    internal fun <E> Node<E>.compareAndSetNextNodeRef(expect: Node<E>, update: Node<E>) = this.nextNodeRef.compareAndSet(expect, update)
+    internal fun <E> LockFreeLinkedListNode<E>.compareAndSetNextNodeRef(expect: LockFreeLinkedListNode<E>, update: LockFreeLinkedListNode<E>) =
+        this.nextNodeRef.compareAndSet(expect, update)
 
     override fun toString(): String = joinToString()
 
     @Suppress("unused")
     internal fun getLinkStructure(): String = buildString {
-        head.childIterateReturnsLastSatisfying<Node<*>>({
+        head.childIterateReturnsLastSatisfying<LockFreeLinkedListNode<*>>({
             append(it.toString())
             append(" <- ")
             it.nextNode
@@ -240,7 +241,7 @@ open class LockFreeLinkedList<E> {
 
 
             // physically remove: try to fix the link
-            var next: Node<E> = toRemove.nextNode
+            var next: LockFreeLinkedListNode<E> = toRemove.nextNode
             while (next !== tail && next.isRemoved()) {
                 next = next.nextNode
             }
@@ -269,7 +270,7 @@ open class LockFreeLinkedList<E> {
 
 
             // physically remove: try to fix the link
-            var next: Node<E> = toRemove.nextNode
+            var next: LockFreeLinkedListNode<E> = toRemove.nextNode
             while (next !== tail && next.isRemoved()) {
                 next = next.nextNode
             }
@@ -282,7 +283,7 @@ open class LockFreeLinkedList<E> {
     /**
      * 动态计算的大小
      */
-    val size: Int get() = head.countChildIterate<Node<E>>({ it.nextNode }, { it !is Tail }) - 1 // empty head is always included
+    val size: Int get() = head.countChildIterate<LockFreeLinkedListNode<E>>({ it.nextNode }, { it !is Tail }) - 1 // empty head is always included
 
     open operator fun contains(element: E): Boolean {
         forEach { if (it == element) return true }
@@ -295,7 +296,7 @@ open class LockFreeLinkedList<E> {
     open fun isEmpty(): Boolean = head.allMatching { it.isValidElementNode().not() }
 
     inline fun forEach(block: (E) -> Unit) {
-        var node: Node<E> = head
+        var node: LockFreeLinkedListNode<E> = head
         while (true) {
             if (node === tail) return
             node.letValueIfValid(block)
@@ -303,6 +304,15 @@ open class LockFreeLinkedList<E> {
         }
     }
 
+    inline fun forEachNode(block: (LockFreeLinkedListNode<E>) -> Unit) {
+        var node: LockFreeLinkedListNode<E> = head
+        while (true) {
+            if (node === tail) return
+            node.letValueIfValid { block(node) }
+            node = node.nextNode
+        }
+    }
+
     @Suppress("unused")
     open fun clear() {
         val first = head.nextNode
@@ -638,14 +648,14 @@ open class LockFreeLinkedList<E> {
 // region internal
 
 @Suppress("NOTHING_TO_INLINE")
-private inline fun <E> E.asNode(nextNode: Node<E>): Node<E> = Node(nextNode, this)
+private inline fun <E> E.asNode(nextNode: LockFreeLinkedListNode<E>): LockFreeLinkedListNode<E> = LockFreeLinkedListNode(nextNode, this)
 
 /**
  * Self-iterate using the [iterator], until [mustBeTrue] returns `false`.
  * Returns the element at the last time when the [mustBeTrue] returns `true`
  */
 @PublishedApi
-internal inline fun <N : Node<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N {
+internal inline fun <N : LockFreeLinkedListNode<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N {
     if (!mustBeTrue(this)) return this
     var value: N = this
 
@@ -703,9 +713,9 @@ private inline fun <E> E.countChildIterate(iterator: (E) -> E, mustBeTrue: (E) -
 
 @PublishedApi
 internal class LazyNode<E> @PublishedApi internal constructor(
-    nextNode: Node<E>,
+    nextNode: LockFreeLinkedListNode<E>,
     private val valueComputer: () -> E
-) : Node<E>(nextNode, null) {
+) : LockFreeLinkedListNode<E>(nextNode, null) {
     private val initialized = atomic(false)
 
     private val value: AtomicRef<E?> = atomic(null)
@@ -727,20 +737,19 @@ internal class LazyNode<E> @PublishedApi internal constructor(
 }
 
 @PublishedApi
-internal class Head<E>(nextNode: Node<E>) : Node<E>(nextNode, null) {
+internal class Head<E>(nextNode: LockFreeLinkedListNode<E>) : LockFreeLinkedListNode<E>(nextNode, null) {
     override fun toString(): String = "Head"
     override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Head")
 }
 
 @PublishedApi
-internal open class Tail<E> : Node<E>(null, null) {
+internal open class Tail<E> : LockFreeLinkedListNode<E>(null, null) {
     override fun toString(): String = "Tail"
     override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Tail")
 }
 
-@PublishedApi
-internal open class Node<E>(
-    nextNode: Node<E>?,
+open class LockFreeLinkedListNode<E>(
+    nextNode: LockFreeLinkedListNode<E>?,
     private var initialNodeValue: E?
 ) {
     /*
@@ -754,10 +763,11 @@ internal open class Node<E>(
 
     open val nodeValue: E get() = initialNodeValue ?: error("Internal error: nodeValue is not initialized")
 
-    val removed = atomic(false)
+    @PublishedApi
+    internal val removed = atomic(false)
 
     @Suppress("LeakingThis")
-    val nextNodeRef: AtomicRef<Node<E>> = atomic(nextNode ?: this)
+    internal val nextNodeRef: AtomicRef<LockFreeLinkedListNode<E>> = atomic(nextNode ?: this)
 
     inline fun <R> letValueIfValid(block: (E) -> R): R? {
         if (!this.isValidElementNode()) {
@@ -770,7 +780,8 @@ internal open class Node<E>(
     /**
      * Short cut for accessing [nextNodeRef]
      */
-    var nextNode: Node<E>
+    @PublishedApi
+    internal var nextNode: LockFreeLinkedListNode<E>
         get() = nextNodeRef.value
         set(value) {
             nextNodeRef.value = value
@@ -779,7 +790,7 @@ internal open class Node<E>(
     /**
      * Returns the former node of the last node whence [filter] returns true
      */
-    inline fun iterateBeforeFirst(filter: (Node<E>) -> Boolean): Node<E> =
+    inline fun iterateBeforeFirst(filter: (LockFreeLinkedListNode<E>) -> Boolean): LockFreeLinkedListNode<E> =
         this.childIterateReturnsLastSatisfying({ it.nextNode }, { !filter(it) })
 
     /**
@@ -788,7 +799,8 @@ internal open class Node<E>(
      * Head, which is this, is also being tested.
      * [Tail], is not being tested.
      */
-    inline fun allMatching(condition: (Node<E>) -> Boolean): Boolean = this.childIterateReturnsLastSatisfying({ it.nextNode }, condition) !is Tail
+    inline fun allMatching(condition: (LockFreeLinkedListNode<E>) -> Boolean): Boolean =
+        this.childIterateReturnsLastSatisfying({ it.nextNode }, condition) !is Tail
 
     /**
      * Stop on and returns the former element of the element that is [equals] to the [element]
@@ -796,23 +808,23 @@ internal open class Node<E>(
      * E.g.: for `head <- 1 <- 2 <- 3 <- tail`, `iterateStopOnNodeValue(2)` returns the node whose value is 1
      */
     @Suppress("NOTHING_TO_INLINE")
-    internal inline fun iterateBeforeNodeValue(element: E): Node<E> = this.iterateBeforeFirst { it.isValidElementNode() && it.nodeValue == element }
+    internal inline fun iterateBeforeNodeValue(element: E): LockFreeLinkedListNode<E> =
+        this.iterateBeforeFirst { it.isValidElementNode() && it.nodeValue == element }
 
 }
 
-@PublishedApi // DO NOT INLINE: ATOMIC OPERATION
-internal fun <E> Node<E>.isRemoved() = this.removed.value
+fun <E> LockFreeLinkedListNode<E>.isRemoved() = this.removed.value
 
 @PublishedApi
 @Suppress("NOTHING_TO_INLINE")
-internal inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved()
+internal inline fun LockFreeLinkedListNode<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved()
 
 @PublishedApi
 @Suppress("NOTHING_TO_INLINE")
-internal inline fun Node<*>.isHead(): Boolean = this is Head
+internal inline fun LockFreeLinkedListNode<*>.isHead(): Boolean = this is Head
 
 @PublishedApi
 @Suppress("NOTHING_TO_INLINE")
-internal inline fun Node<*>.isTail(): Boolean = this is Tail
+internal inline fun LockFreeLinkedListNode<*>.isTail(): Boolean = this is Tail
 
 // end region