Parcourir la source

Avoid user injection

Karlatemp il y a 4 ans
Parent
commit
e1ca6dd6c9

+ 10 - 4
mirai-core/src/commonMain/kotlin/contact/AbstractUser.kt

@@ -243,19 +243,25 @@ internal suspend fun <C : User> SendMessageHandler<out C>.sendMessageImpl(
     preSendEventConstructor: (C, Message) -> MessagePreSendEvent,
     postSendEventConstructor: (C, MessageChain, Throwable?, MessageReceipt<C>?) -> MessagePostSendEvent<C>,
 ): MessageReceipt<C> {
-    require(!message.isContentEmpty()) { "message is empty" }
+    val isMiraiInternal = if (message is MessageChain) {
+        message.anyIsInstance<MiraiInternalMessageFlag>()
+    } else false
 
-    val chain = contact.broadcastMessagePreSendEvent(message, preSendEventConstructor)
+    require(isMiraiInternal || !message.isContentEmpty()) { "message is empty" }
+
+    val chain = contact.broadcastMessagePreSendEvent(message, isMiraiInternal, preSendEventConstructor)
 
     val result = this
-        .runCatching { sendMessage(message, chain, SendMessageStep.FIRST) }
+        .runCatching { sendMessage(message, chain, isMiraiInternal, SendMessageStep.FIRST) }
 
     if (result.isSuccess) {
         // logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source
         contact.logMessageSent(chain)
     }
 
-    postSendEventConstructor(contact, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
+    if (!isMiraiInternal) {
+        postSendEventConstructor(contact, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
+    }
 
     return result.getOrThrow()
 }

+ 10 - 6
mirai-core/src/commonMain/kotlin/contact/GroupImpl.kt

@@ -154,21 +154,25 @@ internal class GroupImpl constructor(
     }
 
     override suspend fun sendMessage(message: Message): MessageReceipt<Group> {
-        require(!message.isContentEmpty()) { "message is empty" }
+        val isMiraiInternal = if (message is MessageChain) {
+            message.anyIsInstance<MiraiInternalMessageFlag>()
+        } else false
+
+        require(isMiraiInternal || !message.isContentEmpty()) { "message is empty" }
         check(!isBotMuted) { throw BotIsBeingMutedException(this) }
 
-        val chain = broadcastMessagePreSendEvent(message, ::GroupMessagePreSendEvent)
+        val chain = broadcastMessagePreSendEvent(message, isMiraiInternal, ::GroupMessagePreSendEvent)
 
         val result = GroupSendMessageHandler(this)
-            .runCatching { sendMessage(message, chain, SendMessageStep.FIRST) }
+            .runCatching { sendMessage(message, chain, isMiraiInternal, SendMessageStep.FIRST) }
 
         if (result.isSuccess) {
             // logMessageSent(result.getOrNull()?.source?.plus(chain) ?: chain) // log with source
             logMessageSent(chain)
         }
-
-        GroupMessagePostSendEvent(this, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
-
+        if (!isMiraiInternal) {
+            GroupMessagePostSendEvent(this, chain, result.exceptionOrNull(), result.getOrNull()).broadcast()
+        }
         return result.getOrThrow()
     }
 

+ 2 - 0
mirai-core/src/commonMain/kotlin/contact/GroupSendMessageImpl.kt

@@ -24,8 +24,10 @@ import net.mamoe.mirai.message.data.toMessageChain
  */
 internal suspend fun <C : Contact> C.broadcastMessagePreSendEvent(
     message: Message,
+    isMiraiInternal: Boolean,
     eventConstructor: (C, Message) -> MessagePreSendEvent,
 ): MessageChain {
+    if (isMiraiInternal) return message.toMessageChain()
     return kotlin.runCatching {
         eventConstructor(this, message).broadcast()
     }.onSuccess {

+ 10 - 4
mirai-core/src/commonMain/kotlin/contact/SendMessageHandler.kt

@@ -121,6 +121,7 @@ internal abstract class SendMessageHandler<C : Contact> {
         originalMessage: Message,
         transformedMessage: MessageChain,
         finalMessage: MessageChain,
+        isMiraiInternal: Boolean,
         step: SendMessageStep,
     ): MessageReceipt<C> {
         bot.components[MessageSvcSyncer].joinSync()
@@ -140,10 +141,10 @@ internal abstract class SendMessageHandler<C : Contact> {
                         if (resp is MessageSvcPbSendMsg.Response.MessageTooLarge) {
                             return when (step) {
                                 SendMessageStep.FIRST -> {
-                                    sendMessageImpl(originalMessage, transformedMessage, SendMessageStep.LONG_MESSAGE)
+                                    sendMessageImpl(originalMessage, transformedMessage, isMiraiInternal, SendMessageStep.LONG_MESSAGE)
                                 }
                                 SendMessageStep.LONG_MESSAGE -> {
-                                    sendMessageImpl(originalMessage, transformedMessage, SendMessageStep.FRAGMENTED)
+                                    sendMessageImpl(originalMessage, transformedMessage, isMiraiInternal, SendMessageStep.FRAGMENTED)
 
                                 }
                                 else -> {
@@ -312,6 +313,7 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.transformSpecialMessage
 internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
     originalMessage: Message,
     transformedMessage: Message,
+    isMiraiInternal: Boolean,
     step: SendMessageStep,
 ): MessageReceipt<C> = sendMessageImpl(
     originalMessage,
@@ -320,6 +322,7 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
             preConversionTransformedMessage(transformedMessage)
         )
     ),
+    isMiraiInternal,
     step
 )
 
@@ -329,16 +332,19 @@ internal suspend fun <C : Contact> SendMessageHandler<C>.sendMessage(
 private suspend fun <C : Contact> SendMessageHandler<C>.sendMessageImpl(
     originalMessage: Message,
     transformedMessage: MessageChain,
+    isMiraiInternal: Boolean,
     step: SendMessageStep,
 ): MessageReceipt<C> { // Result cannot be in interface.
-    transformedMessage.verifySendingValid()
+    if (!isMiraiInternal && step == SendMessageStep.FIRST) {
+        transformedMessage.verifySendingValid()
+    }
     val chain = transformedMessage.convertToLongMessageIfNeeded(step)
 
     chain.findIsInstance<QuoteReply>()?.source?.ensureSequenceIdAvailable()
 
     postTransformActions(chain)
 
-    return sendMessagePacket(originalMessage, transformedMessage, chain, step)
+    return sendMessagePacket(originalMessage, transformedMessage, chain, isMiraiInternal, step)
 }
 
 internal sealed class UserSendMessageHandler<C : AbstractUser>(

+ 0 - 3
mirai-core/src/commonMain/kotlin/contact/util.kt

@@ -38,9 +38,6 @@ internal fun Message.verifySendingValid() {
     fun fail(msg: String): Nothing = throw IllegalArgumentException(msg)
     when (this) {
         is MessageChain -> {
-            if (contains(MiraiInternalMessageFlag)) {
-                return
-            }
             this.forEach { it.verifySendingValid() }
         }
         is FileMessage -> fail("Sending FileMessage is not in support")

+ 1 - 1
mirai-core/src/jvmTest/kotlin/message/data/MessageReceiptTest.kt

@@ -73,7 +73,7 @@ internal class MessageReceiptTest : AbstractTestWithMiraiImpl() {
                     listOf()
                 }
         }
-        val result = handler.sendMessage(message, message, SendMessageStep.FIRST)
+        val result = handler.sendMessage(message, message, false, SendMessageStep.FIRST)
 
         assertIs<ForwardMessage>(result.source.originalMessage[ForwardMessage])
         assertEquals(message, result.source.originalMessage)