Просмотр исходного кода

Ensure for all MessageChain subclasses, `equals`, `hashCode` give consistent results.

Him188 3 лет назад
Родитель
Сommit
732e61e37d

+ 2 - 21
mirai-core-api/src/commonMain/kotlin/message/data/CombinedMessage.kt

@@ -13,18 +13,18 @@ import net.mamoe.mirai.message.data.visitor.MessageVisitor
 import net.mamoe.mirai.message.data.visitor.RecursiveMessageVisitor
 import net.mamoe.mirai.message.data.visitor.accept
 import net.mamoe.mirai.utils.MiraiInternalApi
-import net.mamoe.mirai.utils.isSameType
 
 /**
  * One after one, hierarchically.
  * @since 2.12
  */
 @MiraiInternalApi
+@Suppress("EXPOSED_SUPER_CLASS")
 public class CombinedMessage @MessageChainConstructor constructor(
     @MiraiInternalApi public val element: Message,
     @MiraiInternalApi public val tail: Message,
     @MiraiInternalApi public override val hasConstrainSingle: Boolean
-) : MessageChainImpl, List<SingleMessage> {
+) : AbstractMessageChain(), List<SingleMessage> {
     override fun <D, R> accept(visitor: MessageVisitor<D, R>, data: D): R {
         return visitor.visitCombinedMessage(this, data)
     }
@@ -165,25 +165,6 @@ public class CombinedMessage @MessageChainConstructor constructor(
     }
 
 
-    override fun equals(other: Any?): Boolean {
-        if (this === other) return true
-        if (!isSameType(this, other)) return false
-
-        if (element != other.element) return false
-        if (tail != other.tail) return false
-        if (hasConstrainSingle != other.hasConstrainSingle) return false
-
-        return true
-    }
-
-    override fun hashCode(): Int {
-        var result = element.hashCode()
-        result = 31 * result + tail.hashCode()
-        result = 31 * result + hasConstrainSingle.hashCode()
-        return result
-    }
-
-
     ///////////////////////////////////////////////////////////////////////////
     // slow operations
     ///////////////////////////////////////////////////////////////////////////

+ 3 - 3
mirai-core-api/src/commonMain/kotlin/message/data/Message.kt

@@ -164,7 +164,7 @@ public interface Message {
      * - ...
      *
      * @see toString 得到包含 mirai 消息元素代码的, 易读的字符串
-     * @see contentEquals
+     * @see chainEquals
      * @see Message.content Kotlin 扩展
      */
     public fun contentToString(): String
@@ -175,8 +175,8 @@ public interface Message {
      * [strict] 为 `true` 时, 还会额外判断每个消息元素的类型, 顺序和属性. 如 [Image] 会判断 [Image.imageId]
      *
      * **有关 [strict]:** 每个 [Image] 的 [contentToString] 都是 `"[图片]"`,
-     * 在 [strict] 为 `false` 时 [contentEquals] 会得到 `true`,
-     * 而为 `true` 时由于 [Image.imageId] 会被比较, 两张不同的图片的 [contentEquals] 会是 `false`.
+     * 在 [strict] 为 `false` 时 [chainEquals] 会得到 `true`,
+     * 而为 `true` 时由于 [Image.imageId] 会被比较, 两张不同的图片的 [chainEquals] 会是 `false`.
      *
      * @param ignoreCase 为 `true` 时忽略大小写
      */

+ 4 - 4
mirai-core-api/src/commonMain/kotlin/message/data/MessageChain.kt

@@ -363,7 +363,10 @@ public fun emptyMessageChain(): MessageChain = EmptyMessageChain
     replaceWith = ReplaceWith("emptyMessageChain()", "net.mamoe.mirai.message.data.emptyMessageChain")
 )
 @DeprecatedSinceMirai(warningSince = "2.12")
-public object EmptyMessageChain : MessageChain, List<SingleMessage> by emptyList(), MessageChainImpl {
+@Suppress("EXPOSED_SUPER_CLASS")
+public object EmptyMessageChain : MessageChain, List<SingleMessage> by emptyList(),
+    AbstractMessageChain(), DirectSizeAccess, DirectToStringAccess {
+
     override val size: Int get() = 0
 
     override fun toString(): String = ""
@@ -378,9 +381,6 @@ public object EmptyMessageChain : MessageChain, List<SingleMessage> by emptyList
     override fun appendMiraiCodeTo(builder: StringBuilder) {
     }
 
-    override fun equals(other: Any?): Boolean = other === this
-    override fun hashCode(): Int = 1
-
     override fun iterator(): Iterator<SingleMessage> = EmptyMessageChainIterator
 
     @Suppress("DeprecatedCallableAddReplaceWith")

+ 45 - 7
mirai-core-api/src/commonMain/kotlin/message/data/impl.kt

@@ -18,6 +18,8 @@ import net.mamoe.mirai.message.data.Image.Key.IMAGE_ID_REGEX
 import net.mamoe.mirai.message.data.Image.Key.IMAGE_RESOURCE_ID_REGEX_1
 import net.mamoe.mirai.message.data.Image.Key.IMAGE_RESOURCE_ID_REGEX_2
 import net.mamoe.mirai.message.data.visitor.MessageVisitor
+import net.mamoe.mirai.message.data.visitor.RecursiveMessageVisitor
+import net.mamoe.mirai.message.data.visitor.acceptChildren
 import net.mamoe.mirai.utils.MiraiInternalApi
 import net.mamoe.mirai.utils.asImmutable
 import net.mamoe.mirai.utils.castOrNull
@@ -60,20 +62,59 @@ internal fun Message.contentEqualsStrictImpl(another: Message, ignoreCase: Boole
 }
 
 
-internal sealed interface MessageChainImpl : MessageChain {
+internal sealed class AbstractMessageChain : MessageChain {
     /**
      * 去重算法 v1 - 2.12:
      * 在连接时若只有 0-1 方包含 [ConstrainSingle], 则使用 [CombinedMessage] 优化性能. 否则使用旧版复杂去重算法构造 [LinearMessageChainImpl].
      */
     @MiraiInternalApi
-    val hasConstrainSingle: Boolean
+    abstract val hasConstrainSingle: Boolean
+
+    override fun hashCode(): Int {
+        var result = 1
+        acceptChildren(object : RecursiveMessageVisitor<Unit>() {
+//            override fun visitMessageChain(messageChain: MessageChain, data: Unit) {
+//                result = 31 * result + messageChain.hashCode()
+//                // do not call children
+//            }
+
+            // ensure `messageChainOf(messageChainOf(AtAll))` and `messageChainOf(AtAll)` get same hash code.
+            override fun visitSingleMessage(message: SingleMessage, data: Unit) {
+                result = 31 * result + message.hashCode()
+                super.visitSingleMessage(message, data)
+            }
+        })
+
+        return result
+    }
+
+    override fun equals(other: Any?): Boolean {
+        if (other === null) return false
+        if (other !is MessageChain) return false
+        return chainEquals(this, other)
+    }
+
+    private companion object {
+        private fun chainEquals(a: MessageChain, b: MessageChain): Boolean {
+            if (a.size != b.size) return false // Averagely faster even if we may end up counting size.
+
+            val itr1 = a.iterator()
+            val itr2 = b.iterator()
+            for (singleMessage in itr1) {
+                if (!itr2.hasNext()) return false
+                val n = itr2.next()
+                if (singleMessage != n) return false
+            }
+            return true
+        }
+    }
 }
 
 internal val Message.hasConstrainSingle: Boolean
     get() {
         if (this is SingleMessage) return this is ConstrainSingle
         // now `this` is MessageChain
-        return this.castOrNull<MessageChainImpl>()?.hasConstrainSingle ?: true // for external type, assume they do
+        return this.castOrNull<AbstractMessageChain>()?.hasConstrainSingle ?: true // for external type, assume they do
     }
 
 /**
@@ -137,7 +178,7 @@ internal class LinearMessageChainImpl @MessageChainConstructor private construct
     @JvmField
     internal val delegate: List<SingleMessage>,
     override val hasConstrainSingle: Boolean
-) : Message, MessageChain, List<SingleMessage> by delegate, MessageChainImpl,
+) : Message, MessageChain, List<SingleMessage> by delegate, AbstractMessageChain(),
     DirectSizeAccess, DirectToStringAccess {
     override val size: Int get() = delegate.size
     override fun iterator(): Iterator<SingleMessage> = delegate.iterator()
@@ -148,9 +189,6 @@ internal class LinearMessageChainImpl @MessageChainConstructor private construct
     private val contentToStringTemp: String by lazy { this.delegate.joinToString("") { it.contentToString() } }
     override fun contentToString(): String = contentToStringTemp
 
-    override fun hashCode(): Int = delegate.hashCode()
-    override fun equals(other: Any?): Boolean = other is LinearMessageChainImpl && other.delegate == this.delegate
-
     override fun <D> acceptChildren(visitor: MessageVisitor<D, *>, data: D) {
         for (singleMessage in delegate) {
             singleMessage.accept(visitor, data)

+ 3 - 3
mirai-core-api/src/commonTest/kotlin/message.data/MessageChainImplTest.kt

@@ -16,10 +16,10 @@ internal class MessageChainImplTest {
     @OptIn(MessageChainConstructor::class)
     @Test
     fun allInternalImplementationsOfMessageChainAreMessageChainImpl() {
-        assertIs<MessageChainImpl>(CombinedMessage(AtAll, AtAll, false))
-        assertIs<MessageChainImpl>(emptyMessageChain())
+        assertIs<AbstractMessageChain>(CombinedMessage(AtAll, AtAll, false))
+        assertIs<AbstractMessageChain>(emptyMessageChain())
         val linear = LinearMessageChainImpl.create(listOf(AtAll), true)
         assertIs<LinearMessageChainImpl>(linear)
-        assertIs<MessageChainImpl>(linear)
+        assertIs<AbstractMessageChain>(linear)
     }
 }