Him188 6 лет назад
Родитель
Сommit
6590d8ade7

+ 180 - 58
mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/select.kt

@@ -7,6 +7,8 @@
  * https://github.com/mamoe/mirai/blob/master/LICENSE
  */
 
+@file:Suppress("DuplicatedCode")
+
 package net.mamoe.mirai.event
 
 import kotlinx.coroutines.*
@@ -134,10 +136,46 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna
     @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
     override infix fun MessageSelectionTimeoutChecker.reply(block: suspend () -> Any?): Nothing = error("prohibited")
 
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override infix fun MessageSelectionTimeoutChecker.reply(message: String): Nothing = error("prohibited")
+
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override infix fun MessageSelectionTimeoutChecker.reply(message: Message): Nothing = error("prohibited")
+
+    @JvmName("reply3")
+    @Suppress(
+        "INAPPLICABLE_JVM_NAME",
+        "unused",
+        "UNCHECKED_CAST",
+        "INVALID_CHARACTERS",
+        "NAME_CONTAINS_ILLEGAL_CHARS",
+        "FunctionName"
+    )
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override infix fun MessageSelectionTimeoutChecker.`->`(message: String): Nothing = error("prohibited")
+
+    @JvmName("reply3")
+    @Suppress(
+        "INAPPLICABLE_JVM_NAME",
+        "unused",
+        "UNCHECKED_CAST",
+        "INVALID_CHARACTERS",
+        "NAME_CONTAINS_ILLEGAL_CHARS",
+        "FunctionName"
+    )
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override infix fun MessageSelectionTimeoutChecker.`->`(message: Message): Nothing = error("prohibited")
+
     @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
     override infix fun MessageSelectionTimeoutChecker.quoteReply(block: suspend () -> Any?): Nothing =
         error("prohibited")
 
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override infix fun MessageSelectionTimeoutChecker.quoteReply(message: String): Nothing = error("prohibited")
+
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override infix fun MessageSelectionTimeoutChecker.quoteReply(message: Message): Nothing = error("prohibited")
+
     @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
     override fun String.containsReply(reply: String): Nothing = error("prohibited")
 
@@ -172,6 +210,16 @@ abstract class MessageSelectBuilder<M : ContactMessage, R> @PublishedApi interna
     @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
     override fun ListeningFilter.reply(message: Message) = error("prohibited")
 
+    @JvmName("reply3")
+    @Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override fun ListeningFilter.`->`(toReply: String) = error("prohibited")
+
+    @JvmName("reply3")
+    @Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
+    @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
+    override fun ListeningFilter.`->`(message: Message) = error("prohibited")
+
     @Deprecated("Using `reply` DSL in message selection is prohibited", level = DeprecationLevel.HIDDEN)
     override fun ListeningFilter.reply(replier: suspend M.(String) -> Any?) =
         error("prohibited")
@@ -221,7 +269,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
         obtainCurrentCoroutineScope().launch {
             delay(timeoutMillis)
             val deferred = obtainCurrentDeferred() ?: return@launch
-            if (deferred.isActive) {
+            if (deferred.isActive && !deferred.isCompleted) {
                 deferred.completeExceptionally(exception())
             }
         }
@@ -236,7 +284,7 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
         obtainCurrentCoroutineScope().launch {
             delay(timeoutMillis)
             val deferred = obtainCurrentDeferred() ?: return@launch
-            if (deferred.isActive) {
+            if (deferred.isActive && !deferred.isCompleted) {
                 deferred.complete(block())
             }
         }
@@ -281,6 +329,48 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
         }
     }
 
+    @Suppress("unused", "UNCHECKED_CAST")
+    open infix fun MessageSelectionTimeoutChecker.reply(message: Message) {
+        return timeout(this.timeoutMillis) {
+            ownerMessagePacket.reply(message)
+            Unit as R
+        }
+    }
+
+    @Suppress("unused", "UNCHECKED_CAST")
+    open infix fun MessageSelectionTimeoutChecker.reply(message: String) {
+        return timeout(this.timeoutMillis) {
+            ownerMessagePacket.reply(message)
+            Unit as R
+        }
+    }
+
+    @JvmName("reply3")
+    @Suppress(
+        "INAPPLICABLE_JVM_NAME",
+        "unused",
+        "UNCHECKED_CAST",
+        "INVALID_CHARACTERS",
+        "NAME_CONTAINS_ILLEGAL_CHARS",
+        "FunctionName"
+    )
+    open infix fun MessageSelectionTimeoutChecker.`->`(message: Message) {
+        return this.reply(message)
+    }
+
+    @JvmName("reply3")
+    @Suppress(
+        "INAPPLICABLE_JVM_NAME",
+        "unused",
+        "UNCHECKED_CAST",
+        "INVALID_CHARACTERS",
+        "NAME_CONTAINS_ILLEGAL_CHARS",
+        "FunctionName"
+    )
+    open infix fun MessageSelectionTimeoutChecker.`->`(message: String) {
+        return this.reply(message)
+    }
+
     /**
      * 在超时后引用回复原消息
      *
@@ -297,6 +387,22 @@ abstract class MessageSelectBuilderUnit<M : ContactMessage, R> @PublishedApi int
         }
     }
 
+    @Suppress("unused", "UNCHECKED_CAST")
+    open infix fun MessageSelectionTimeoutChecker.quoteReply(message: Message) {
+        return timeout(this.timeoutMillis) {
+            ownerMessagePacket.quoteReply(message)
+            Unit as R
+        }
+    }
+
+    @Suppress("unused", "UNCHECKED_CAST")
+    open infix fun MessageSelectionTimeoutChecker.quoteReply(message: String) {
+        return timeout(this.timeoutMillis) {
+            ownerMessagePacket.quoteReply(message)
+            Unit as R
+        }
+    }
+
     /**
      * 当其他条件都不满足时回复原消息.
      *
@@ -359,16 +465,24 @@ internal suspend inline fun <R> withTimeoutOrCoroutineScope(
 ): R {
     require(timeoutMillis == -1L || timeoutMillis > 0) { "timeoutMillis must be -1 or > 0 " }
 
-    return if (timeoutMillis == -1L) {
-        coroutineScope(block)
-    } else {
-        withTimeout(timeoutMillis, block)
+    return withContext(ExceptionHandlerIgnoringCancellationException) {
+        if (timeoutMillis == -1L) {
+            coroutineScope(block)
+        } else {
+            withTimeout(timeoutMillis, block)
+        }
     }
 }
 
 @PublishedApi
 internal val SELECT_MESSAGE_STUB = Any()
 
+@PublishedApi
+internal val ExceptionHandlerIgnoringCancellationException = CoroutineExceptionHandler { _, throwable ->
+    if (throwable !is CancellationException) {
+        throw throwable
+    }
+}
 
 @PublishedApi
 @BuilderInference
@@ -379,7 +493,10 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
     @BuilderInference
     crossinline selectBuilder: @MessageDsl MessageSelectBuilderUnit<T, R>.() -> Unit
 ): R = withTimeoutOrCoroutineScope(timeoutMillis) {
-    val deferred = CompletableDeferred<R>()
+    var deferred: CompletableDeferred<R>? = CompletableDeferred()
+    coroutineContext[Job]!!.invokeOnCompletion {
+        deferred?.cancel()
+    }
 
     // ensure sequential invoking
     val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
@@ -421,14 +538,13 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
 
     // we don't have any way to reduce duplication yet,
     // until local functions are supported in inline functions
-    @Suppress("DuplicatedCode")
-    subscribeAlways<T> { event ->
+    @Suppress("DuplicatedCode") val subscribeAlways = subscribeAlways<T> { event ->
         if (!this.isContextIdenticalWith(this@selectMessagesImpl))
             return@subscribeAlways
 
         val toString = event.message.toString()
         listeners.forEach { (filter, listener) ->
-            if (deferred.isCompleted || !isActive)
+            if (deferred?.isCompleted == true || !isActive)
                 return@subscribeAlways
 
             if (filter.invoke(event, toString)) {
@@ -436,12 +552,12 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
                 val value = listener.invoke(event, toString)
                 if (value !== SELECT_MESSAGE_STUB) {
                     @Suppress("UNCHECKED_CAST")
-                    deferred.complete(value as R)
+                    deferred?.complete(value as R)
                     return@subscribeAlways
                 } else if (isUnit) { // value === stub
                     // unit mode: we can directly complete this selection
                     @Suppress("UNCHECKED_CAST")
-                    deferred.complete(Unit as R)
+                    deferred?.complete(Unit as R)
                 }
             }
         }
@@ -450,17 +566,21 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
             val value = listener.invoke(event, toString)
             if (value !== SELECT_MESSAGE_STUB) {
                 @Suppress("UNCHECKED_CAST")
-                deferred.complete(value as R)
+                deferred?.complete(value as R)
                 return@subscribeAlways
             } else if (isUnit) { // value === stub
                 // unit mode: we can directly complete this selection
                 @Suppress("UNCHECKED_CAST")
-                deferred.complete(Unit as R)
+                deferred?.complete(Unit as R)
             }
         }
     }
 
-    deferred.await().also { coroutineContext[Job]!!.cancelChildren() }
+    deferred!!.await().also {
+        subscribeAlways.complete()
+        deferred = null
+        coroutineContext.cancelChildren()
+    }
 }
 
 @Suppress("unused")
@@ -468,50 +588,43 @@ internal suspend inline fun <reified T : ContactMessage, R> T.selectMessagesImpl
 internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesImpl(
     timeoutMillis: Long = -1,
     crossinline selectBuilder: @MessageDsl MessageSelectBuilder<T, Boolean>.() -> Unit
-) {
-    withTimeoutOrCoroutineScope(timeoutMillis) {
-        var deferred: CompletableDeferred<Boolean>? = CompletableDeferred()
+) = withTimeoutOrCoroutineScope(timeoutMillis) {
+    var deferred: CompletableDeferred<Boolean>? = CompletableDeferred()
+    coroutineContext[Job]!!.invokeOnCompletion {
+        deferred?.cancel()
+    }
 
-        // ensure sequential invoking
-        val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
-        val defaltListeners: MutableList<MessageListener<T, Any?>> = mutableListOf()
+    // ensure sequential invoking
+    val listeners: MutableList<Pair<T.(String) -> Boolean, MessageListener<T, Any?>>> = mutableListOf()
+    val defaultListeners: MutableList<MessageListener<T, Any?>> = mutableListOf()
 
-        // https://youtrack.jetbrains.com/issue/KT-37716
-        val outside = { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
-            listeners += filter to listener
+    // https://youtrack.jetbrains.com/issue/KT-37716
+    val outside = { filter: T.(String) -> Boolean, listener: MessageListener<T, Any?> ->
+        listeners += filter to listener
+    }
+    object : MessageSelectBuilder<T, Boolean>(
+        this@whileSelectMessagesImpl,
+        SELECT_MESSAGE_STUB,
+        outside
+    ) {
+        override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
+        override fun obtainCurrentDeferred(): CompletableDeferred<Boolean>? = deferred
+        override fun default(onEvent: MessageListener<T, Boolean>) {
+            defaultListeners += onEvent
         }
-        object : MessageSelectBuilder<T, Boolean>(
-            this@whileSelectMessagesImpl,
-            SELECT_MESSAGE_STUB,
-            outside
-        ) {
-            override fun obtainCurrentCoroutineScope(): CoroutineScope = this@withTimeoutOrCoroutineScope
-            override fun obtainCurrentDeferred(): CompletableDeferred<Boolean>? = deferred
-            override fun default(onEvent: MessageListener<T, Boolean>) {
-                defaltListeners += onEvent
-            }
-        }.apply(selectBuilder)
+    }.apply(selectBuilder)
 
-        // ensure atomic completing
-        subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event ->
-            if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl))
-                return@subscribeAlways
+    // ensure atomic completing
+    val subscribeAlways = subscribeAlways<T>(concurrency = Listener.ConcurrencyKind.LOCKED) { event ->
+        if (!this.isContextIdenticalWith(this@whileSelectMessagesImpl))
+            return@subscribeAlways
 
-            val toString = event.message.toString()
-            listeners.forEach { (filter, listener) ->
-                if (deferred?.isCompleted != false || !isActive)
-                    return@subscribeAlways
+        val toString = event.message.toString()
+        listeners.forEach { (filter, listener) ->
+            if (deferred?.isCompleted != false || !isActive)
+                return@subscribeAlways
 
-                if (filter.invoke(event, toString)) {
-                    listener.invoke(event, toString).let { value ->
-                        if (value !== SELECT_MESSAGE_STUB) {
-                            deferred?.complete(value as Boolean)
-                            return@subscribeAlways // accept the first value only
-                        }
-                    }
-                }
-            }
-            defaltListeners.forEach { listener ->
+            if (filter.invoke(event, toString)) {
                 listener.invoke(event, toString).let { value ->
                     if (value !== SELECT_MESSAGE_STUB) {
                         deferred?.complete(value as Boolean)
@@ -520,11 +633,20 @@ internal suspend inline fun <reified T : ContactMessage> T.whileSelectMessagesIm
                 }
             }
         }
-
-        while (deferred?.await() == true) {
-            deferred = CompletableDeferred()
+        defaultListeners.forEach { listener ->
+            listener.invoke(event, toString).let { value ->
+                if (value !== SELECT_MESSAGE_STUB) {
+                    deferred?.complete(value as Boolean)
+                    return@subscribeAlways // accept the first value only
+                }
+            }
         }
-        deferred = null
-        coroutineContext[Job]!!.cancelChildren()
     }
+
+    while (deferred?.await() == true) {
+        deferred = CompletableDeferred()
+    }
+    subscribeAlways.complete()
+    deferred = null
+    coroutineContext.cancelChildren()
 }

+ 14 - 0
mirai-core/src/commonMain/kotlin/net.mamoe.mirai/event/subscribeMessages.kt

@@ -331,6 +331,20 @@ open class MessageSubscribersBuilder<M : ContactMessage, out Ret, R : RR, RR>(
         return content(filter) { reply(message);[email protected] }
     }
 
+    @JvmName("reply3")
+    @Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
+    @SinceMirai("0.33.0")
+    open infix fun ListeningFilter.`->`(toReply: String): Ret {
+        return this.reply(toReply)
+    }
+
+    @JvmName("reply3")
+    @Suppress("INAPPLICABLE_JVM_NAME", "INVALID_CHARACTERS", "NAME_CONTAINS_ILLEGAL_CHARS", "FunctionName")
+    @SinceMirai("0.33.0")
+    open infix fun ListeningFilter.`->`(message: Message): Ret {
+        return this.reply(message)
+    }
+
     @SinceMirai("0.29.0")
     open infix fun ListeningFilter.reply(replier: (@MessageDsl suspend M.(String) -> Any?)): Ret {
         return content(filter) {