Browse Source

[core] fix SequenceBasedRoamingMessageImpl

StageGuard 2 years ago
parent
commit
db420c5b83

+ 24 - 79
mirai-core/src/commonMain/kotlin/contact/roaming/RoamingMessagesImplGroup.kt

@@ -9,109 +9,54 @@
 
 package net.mamoe.mirai.internal.contact.roaming
 
-import kotlinx.coroutines.flow.*
+import kotlinx.coroutines.flow.Flow
 import net.mamoe.mirai.contact.roaming.RoamingMessageFilter
 import net.mamoe.mirai.internal.contact.CommonGroupImpl
-import net.mamoe.mirai.internal.message.toMessageChainOnline
 import net.mamoe.mirai.internal.network.protocol.data.proto.MsgComm
+import net.mamoe.mirai.internal.network.protocol.packet.chat.TroopManagement
 import net.mamoe.mirai.internal.network.protocol.packet.chat.receive.MessageSvcPbGetGroupMsg
 import net.mamoe.mirai.message.data.MessageChain
-import net.mamoe.mirai.message.data.MessageSourceKind
 
 internal class RoamingMessagesImplGroup(
     override val contact: CommonGroupImpl
 ) : SequenceBasedRoamingMessagesImpl() {
-
-    /**
-     * time-based roaming without extending [TimeBasedRoamingMessagesImpl]
-     * because protocol MessageSvc.PbGetGroupMsg doesn't support querying via time.
-     * so this is actually sequence-based roaming.
-     */
     override suspend fun getMessagesIn(
         timeStart: Long,
         timeEnd: Long,
         filter: RoamingMessageFilter?
-    ): Flow<MessageChain> {
-        var currentSeq: Int = getLastMsgSeq() ?: return emptyFlow()
-        var lastOfferedSeq = -1
-
-        return flow {
-            while (true) {
-                val resp = getGroupMsg(currentSeq.toLong()) ?: break
-
-                // the message may be sorted increasing by message time,
-                // if so, additional sortBy will not take cost.
-
-                val messageTimeSequence = resp.msgElem.asSequence().map { it.time }
-
-                val maxTime = messageTimeSequence.max()
-
-
-                // we have fetched all messages
-                // note: maxTime = 0 means all fetched messages were recalled
-                if (maxTime < timeStart && maxTime != 0) break
-
-                emitAll(
-                    resp.msgElem.asSequence()
-                        .filter { lastOfferedSeq == -1 || it.msgHead.msgSeq < lastOfferedSeq }
-                        .filter { it.time in timeStart..timeEnd }
-                        .sortedByDescending { it.msgHead.msgSeq } // Ensure caller receives newer messages first
-                        .filter { filter.apply(it) } // Call filter after sort
-                        .asFlow()
-                        .map { listOf(it).toMessageChainOnline(bot, contact.id, MessageSourceKind.GROUP) }
-                )
-
-
-                currentSeq = resp.msgElem.first().msgHead.msgSeq
-                lastOfferedSeq = currentSeq
-            }
-        }
-    }
-
-    override suspend fun getMessagesBeforeFlow(
-        messageId: Int?,
-        filter: RoamingMessageFilter?
-    ): Flow<MessageChain> {
-        var currentSeq = messageId ?: (getLastMsgSeq() ?: return emptyFlow())
-
-        return flow {
-            while (true) {
-                val resp = getGroupMsg(currentSeq.toLong()) ?: break
-
-                emitAll(
-                    resp.msgElem.asSequence()
-                        .filter { getMessageSourceKindFromC2cCmdOrNull(it.msgHead.c2cCmd) != null } // ignore unsupported messages
-                        .sortedByDescending { it.time } // Ensure caller receiver newer messages first
-                        .filter { filter.apply(it) } // Call filter after sort
-                        .asFlow()
-                        .map { it.toMessageChainOnline(bot) }
-                )
+    ): Flow<MessageChain> = getMessagesImpl(
+        preFilter = { maxTime -> maxTime >= timeStart || maxTime == 0 },
+        preSortFilter = { msg -> msg.msgHead.msgTime in timeStart..timeEnd },
+        filter = filter
+    )
+
+    override suspend fun getLastMsgSeq(): Int? {
+        // Iterate from the newest message to find messages within [timeStart] and [timeEnd]
+        val lastMsgSeqResp = bot.network.sendAndExpect(
+            TroopManagement.GetGroupLastMsgSeq(
+                client = bot.client,
+                groupUin = contact.uin
+            )
+        )
 
-                currentSeq = resp.msgElem.minBy { it.time }.msgHead.msgSeq - 1
-            }
+        return when (lastMsgSeqResp) {
+            TroopManagement.GetGroupLastMsgSeq.Response.Failed -> null
+            is TroopManagement.GetGroupLastMsgSeq.Response.Success -> lastMsgSeqResp.seq
         }
     }
 
-    private suspend fun getGroupMsg(seq: Long): MessageSvcPbGetGroupMsg.Success? {
+    override suspend fun getMsg(seq: Int): List<MsgComm.Msg> {
         val resp = contact.bot.network.sendAndExpect(
             MessageSvcPbGetGroupMsg(
                 client = contact.bot.client,
                 groupUin = contact.uin,
-                messageSequence = seq,
+                messageSequence = seq.toLong(),
                 count = 20 // maximum 20
             )
         )
 
-        if (resp is MessageSvcPbGetGroupMsg.Failed) return null
-        resp as MessageSvcPbGetGroupMsg.Success
-        if (resp.msgElem.isEmpty()) return null
-
-        return resp
+        if (resp is MessageSvcPbGetGroupMsg.Failed) return listOf()
+        resp as MessageSvcPbGetGroupMsg.Success // stupid smart cast
+        return resp.msgElem
     }
-
-    private val MsgComm.Msg.time get() = msgHead.msgTime
-
-    private fun RoamingMessageFilter?.apply(
-        it: MsgComm.Msg
-    ) = this?.invoke(createRoamingMessage(it, listOf())) != false
 }

+ 49 - 18
mirai-core/src/commonMain/kotlin/contact/roaming/SequenceBasedRoamingMessagesImpl.kt

@@ -9,11 +9,12 @@
 
 package net.mamoe.mirai.internal.contact.roaming
 
-import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.*
 import net.mamoe.mirai.contact.roaming.RoamingMessageFilter
-import net.mamoe.mirai.internal.contact.uin
-import net.mamoe.mirai.internal.network.protocol.packet.chat.TroopManagement
+import net.mamoe.mirai.internal.message.toMessageChainOnline
+import net.mamoe.mirai.internal.network.protocol.data.proto.MsgComm
 import net.mamoe.mirai.message.data.MessageChain
+import net.mamoe.mirai.message.data.MessageSourceKind
 import net.mamoe.mirai.utils.Streamable
 
 internal sealed class SequenceBasedRoamingMessagesImpl : AbstractRoamingMessages() {
@@ -30,7 +31,7 @@ internal sealed class SequenceBasedRoamingMessagesImpl : AbstractRoamingMessages
         messageId: Int?,
         filter: RoamingMessageFilter?
     ): Streamable<MessageChain> {
-        val flow = getMessagesBeforeFlow(messageId, filter)
+        val flow = getMessagesImpl(messageId, preSortFilter = { true }, filter = filter)
         return object : Streamable<MessageChain> {
             override fun asFlow(): Flow<MessageChain> {
                 return flow
@@ -42,23 +43,53 @@ internal sealed class SequenceBasedRoamingMessagesImpl : AbstractRoamingMessages
         filter: RoamingMessageFilter?
     ): Flow<MessageChain> = getMessagesBefore().asFlow()
 
-    abstract suspend fun getMessagesBeforeFlow(
-        messageId: Int?,
+
+    /**
+     * get message sequences
+     * @param preFilter: filter before emitting message elements, break loop if false.
+     *  use it to predict if we fetched all messages. param1 is time of newest message.
+     * @param preSortFilter: message element filter, param is msgElem
+     * @param filter: user-defined roaming message filter
+     */
+    internal suspend fun getMessagesImpl(
+        initialSeq: Int? = null,
+        preFilter: (maxTime: Int) -> Boolean = { true },
+        preSortFilter: (msg: MsgComm.Msg) -> Boolean,
         filter: RoamingMessageFilter?
-    ): Flow<MessageChain>
+    ): Flow<MessageChain> {
+        var currentSeq: Int = initialSeq ?: getLastMsgSeq() ?: return emptyFlow()
+        var lastOfferedSeq = -1
+
+        return flow {
+            while (true) {
+                val msgElem = getMsg(currentSeq)
+                if (msgElem.isEmpty()) break
+
+                // the message may be sorted increasing by message time,
+                // if so, additional sortBy will not take cost.
+                val maxTime = msgElem.asSequence().map { it.msgHead.msgTime }.max()
+                if (!preFilter(maxTime)) break
 
-    internal suspend fun getLastMsgSeq(): Int? {
-        // Iterate from the newest message to find messages within [timeStart] and [timeEnd]
-        val lastMsgSeqResp = bot.network.sendAndExpect(
-            TroopManagement.GetGroupLastMsgSeq(
-                client = bot.client,
-                groupUin = contact.uin
-            )
-        )
+                emitAll(
+                    msgElem.asSequence()
+                        .filter { lastOfferedSeq == -1 || it.msgHead.msgSeq < lastOfferedSeq }
+                        .filter(preSortFilter)
+                        .sortedByDescending { it.msgHead.msgSeq } // Ensure caller receives newer messages first
+                        .filter { filter.apply(it) } // Call filter after sort
+                        .asFlow()
+                        .map { listOf(it).toMessageChainOnline(bot, contact.id, MessageSourceKind.GROUP) }
+                )
 
-        return when (lastMsgSeqResp) {
-            TroopManagement.GetGroupLastMsgSeq.Response.Failed -> null
-            is TroopManagement.GetGroupLastMsgSeq.Response.Success -> lastMsgSeqResp.seq
+                currentSeq = msgElem.first().msgHead.msgSeq
+                lastOfferedSeq = currentSeq
+            }
         }
     }
+
+    private fun RoamingMessageFilter?.apply(it: MsgComm.Msg) =
+        this?.invoke(createRoamingMessage(it, listOf())) != false
+
+    internal abstract suspend fun getLastMsgSeq(): Int?
+
+    internal abstract suspend fun getMsg(seq: Int): List<MsgComm.Msg>
 }