浏览代码

[core] Add PacketReplier DSL for core tests

Him188 2 年之前
父节点
当前提交
a988d442e2

+ 8 - 13
mirai-core/src/commonTest/kotlin/network/framework/AbstractCommonNHTest.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2019-2022 Mamoe Technologies and contributors.
+ * Copyright 2019-2023 Mamoe Technologies and contributors.
  *
  * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
  * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
@@ -9,6 +9,7 @@
 
 package net.mamoe.mirai.internal.network.framework
 
+import kotlinx.coroutines.CoroutineScope
 import net.mamoe.mirai.internal.QQAndroidBot
 import net.mamoe.mirai.internal.network.Packet
 import net.mamoe.mirai.internal.network.handler.*
@@ -32,6 +33,9 @@ internal abstract class TestCommonNetworkHandler(
         for (packetReplier in packetRepliers) {
             packetReplier.run {
                 object : PacketReplierContext {
+                    override val coroutineScope: CoroutineScope
+                        get() = CoroutineScope(coroutineContext)
+
                     override fun reply(incoming: IncomingPacket) {
                         collectReceived(incoming)
                     }
@@ -74,19 +78,10 @@ internal abstract class TestCommonNetworkHandler(
     fun addPacketReplier(packetReplier: PacketReplier) {
         packetRepliers.add(packetReplier)
     }
-}
 
-/**
- * 应答器, 模拟服务器返回.
- */
-internal fun interface PacketReplier {
-    fun PacketReplierContext.onSend(packet: OutgoingPacket)
-}
-
-internal interface PacketReplierContext {
-    fun reply(incoming: IncomingPacket)
-    fun reply(incoming: Packet)
-    fun reply(incoming: Throwable)
+    inline fun addPacketReplierDsl(crossinline action: PacketReplierDslBuilder.() -> Unit) {
+        packetRepliers.add(buildPacketReplier(action))
+    }
 }
 
 /**

+ 13 - 1
mirai-core/src/commonTest/kotlin/network/framework/AbstractCommonNHTestWithSelector.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2019-2022 Mamoe Technologies and contributors.
+ * Copyright 2019-2023 Mamoe Technologies and contributors.
  *
  * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
  * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
@@ -38,12 +38,19 @@ internal abstract class AbstractCommonNHTestWithSelector :
             override suspend fun createConnection(): PlatformConn {
                 return conn
             }
+        }.apply {
+            applyToInstances.forEach { it.invoke(this) }
         }
     }
 
     override val factory: NetworkHandlerFactory<TestSelectorNetworkHandler> =
         NetworkHandlerFactory { _, _ -> TestSelectorNetworkHandler(selector, bot) }
 
+
+    private val applyToInstances = mutableListOf<TestCommonNetworkHandler.() -> Unit>()
+    fun onEachNetworkInstance(action: TestCommonNetworkHandler.() -> Unit) {
+        applyToInstances.add(action)
+    }
 }
 
 internal class TestSelectorNetworkHandler(
@@ -54,6 +61,11 @@ internal class TestSelectorNetworkHandler(
     fun currentInstance() = selector.getCurrentInstanceOrCreate()
     fun currentInstanceOrNull() = selector.getCurrentInstanceOrNull()
 
+    private val applyToInstances = mutableListOf<TestCommonNetworkHandler.() -> Unit>()
+    fun onEachInstance(action: TestCommonNetworkHandler.() -> Unit) {
+        applyToInstances.add(action)
+    }
+
     override fun setStateClosed(exception: Throwable?): NetworkHandlerSupport.BaseStateImpl? {
         return selector.getCurrentInstanceOrCreate().setStateClosed(exception)
     }

+ 4 - 1
mirai-core/src/commonTest/kotlin/network/framework/AbstractMockNetworkHandlerTest.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2019-2022 Mamoe Technologies and contributors.
+ * Copyright 2019-2023 Mamoe Technologies and contributors.
  *
  * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
  * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
@@ -34,6 +34,9 @@ import kotlin.test.AfterTest
 import kotlin.test.assertEquals
 
 
+/**
+ * Test with mock [NetworkHandler], and without selector.
+ */
 internal abstract class AbstractMockNetworkHandlerTest : AbstractNetworkHandlerTest() {
     protected open fun createNetworkHandlerContext() = TestNetworkHandlerContext(bot, logger, components)
     protected open fun createNetworkHandler() = TestNetworkHandler(bot, createNetworkHandlerContext())

+ 6 - 1
mirai-core/src/commonTest/kotlin/network/framework/AbstractNetworkHandlerTest.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2019-2022 Mamoe Technologies and contributors.
+ * Copyright 2019-2023 Mamoe Technologies and contributors.
  *
  * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
  * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
@@ -15,6 +15,11 @@ import net.mamoe.mirai.utils.setSystemProp
 import kotlin.test.AfterTest
 import kotlin.test.BeforeTest
 
+/**
+ * @see AbstractCommonNHTest
+ * @see AbstractCommonNHTestWithSelector
+ * @see AbstractMockNetworkHandlerTest
+ */
 internal sealed class AbstractNetworkHandlerTest : AbstractTest() {
     ///////////////////////////////////////////////////////////////////////////
     // Defaults

+ 11 - 29
mirai-core/src/commonTest/kotlin/network/framework/AbstractRealNetworkHandlerTest.kt

@@ -13,15 +13,12 @@ package net.mamoe.mirai.internal.network.framework
 
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.SupervisorJob
-import net.mamoe.mirai.auth.BotAuthInfo
-import net.mamoe.mirai.auth.BotAuthResult
 import net.mamoe.mirai.internal.*
 import net.mamoe.mirai.internal.contact.uin
 import net.mamoe.mirai.internal.network.KeyWithCreationTime
 import net.mamoe.mirai.internal.network.KeyWithExpiry
 import net.mamoe.mirai.internal.network.WLoginSigInfo
 import net.mamoe.mirai.internal.network.WLoginSimpleInfo
-import net.mamoe.mirai.internal.network.auth.BotAuthSessionInternal
 import net.mamoe.mirai.internal.network.component.ComponentKey
 import net.mamoe.mirai.internal.network.component.ConcurrentComponentStorage
 import net.mamoe.mirai.internal.network.component.setAll
@@ -32,6 +29,7 @@ import net.mamoe.mirai.internal.network.handler.NetworkHandler.State
 import net.mamoe.mirai.internal.network.protocol.data.jce.SvcRespRegister
 import net.mamoe.mirai.internal.network.protocol.packet.login.StatSvc
 import net.mamoe.mirai.internal.test.runBlockingUnit
+import net.mamoe.mirai.internal.utils.crypto.QQEcdh
 import net.mamoe.mirai.internal.utils.subLogger
 import net.mamoe.mirai.utils.*
 import network.framework.components.TestEventDispatcherImpl
@@ -116,31 +114,6 @@ internal abstract class AbstractRealNetworkHandlerTest<H : NetworkHandler> : Abs
         set(SsoProcessorContext, SsoProcessorContextImpl(bot))
         set(SsoProcessor, object : TestSsoProcessor(bot) {
             override suspend fun login(handler: NetworkHandler) {
-                val botAuthInfo = object : BotAuthInfo {
-                    override val id: Long get() = bot.id
-                    override val deviceInfo: DeviceInfo
-                        get() = get(SsoProcessorContext).device
-                    override val configuration: BotConfiguration
-                        get() = bot.configuration
-                }
-                val rsp = object : BotAuthResult {}
-
-                val session = object : BotAuthSessionInternal() {
-                    override suspend fun authByPassword(passwordMd5: SecretsProtection.EscapedByteBuffer): BotAuthResult {
-                        return rsp
-                    }
-
-                    override suspend fun authByQRCode(): BotAuthResult {
-                        return rsp
-                    }
-
-                }
-
-                bot.account.authorization.authorize(session, botAuthInfo)
-                bot.account.accountSecretsKeyBuffer = SecretsProtection.EscapedByteBuffer(
-                    bot.account.authorization.calculateSecretsKey(botAuthInfo)
-                )
-
                 nhEvents.add(NHEvent.Login)
                 super.login(handler)
             }
@@ -202,6 +175,15 @@ internal abstract class AbstractRealNetworkHandlerTest<H : NetworkHandler> : Abs
             override fun attachJob(bot: AbstractBot, scope: CoroutineScope) {
             }
         })
+
+        set(EcdhInitialPublicKeyUpdater, object : EcdhInitialPublicKeyUpdater {
+            override suspend fun refreshInitialPublicKeyAndApplyEcdh() {
+            }
+
+            override fun getQQEcdh(): QQEcdh = QQEcdh()
+        })
+
+        set(AccountSecretsManager, MemoryAccountSecretsManager())
         // set(StateObserver, bot.run { stateObserverChain() })
     }
 
@@ -251,7 +233,7 @@ internal fun AbstractRealNetworkHandlerTest<*>.setSsoProcessor(action: suspend S
     }
 }
 
-private fun createWLoginSigInfo(
+internal fun createWLoginSigInfo(
     uin: Long,
     creationTime: Long = currentTimeSeconds(),
     random: Random = Random(1)

+ 135 - 0
mirai-core/src/commonTest/kotlin/network/framework/PacketReplier.kt

@@ -0,0 +1,135 @@
+/*
+ * Copyright 2019-2023 Mamoe Technologies and contributors.
+ *
+ * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
+ * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
+ *
+ * https://github.com/mamoe/mirai/blob/dev/LICENSE
+ */
+
+package net.mamoe.mirai.internal.network.framework
+
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.launch
+import net.mamoe.mirai.internal.network.Packet
+import net.mamoe.mirai.internal.network.protocol.packet.IncomingPacket
+import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacket
+import net.mamoe.mirai.internal.network.protocol.packet.OutgoingPacketFactory
+import kotlin.jvm.JvmName
+
+/**
+ * 应答器, 模拟服务器返回.
+ */
+internal fun interface PacketReplier {
+    fun PacketReplierContext.onSend(packet: OutgoingPacket)
+}
+
+internal inline fun buildPacketReplier(crossinline builderAction: PacketReplierDslBuilder.() -> Unit): PacketReplier {
+    return PacketReplierDslBuilder().apply { builderAction() }.build()
+}
+
+internal interface PacketReplierContext {
+    val context: PacketReplierContext get() = this
+    val coroutineScope: CoroutineScope
+
+    fun reply(incoming: IncomingPacket)
+    fun reply(incoming: Packet)
+    fun reply(incoming: Throwable)
+    fun ignore() {}
+}
+
+
+internal sealed class PacketReplierDecision {
+    data class Reply(val action: PacketReplierContext.(outgoingPacket: OutgoingPacket) -> Unit) :
+        PacketReplierDecision()
+
+    data object Ignore : PacketReplierDecision()
+}
+
+internal class PacketReplierDslBuilder {
+    val decisions: MutableList<PacketReplierDecision> = mutableListOf()
+
+    class On<T : Packet?>(
+        val fromFactories: List<OutgoingPacketFactory<T>>,
+    )
+
+    /**
+     * Expects the next packet to be exactly
+     */
+    fun <T : Packet?> expect(
+        vararg from: OutgoingPacketFactory<T>,
+    ): On<T> = On(from.toList())
+
+    fun <T : Packet?> expect(
+        vararg from: OutgoingPacketFactory<T>,
+        action: PacketReplierContext.(outgoingPacket: OutgoingPacket) -> Unit
+    ): Unit = On(from.toList()).invoke(action)
+
+    operator fun <T : Packet?> On<T>.invoke(
+        action: PacketReplierContext.(outgoingPacket: OutgoingPacket) -> Unit
+    ) {
+        decisions.add(PacketReplierDecision.Reply { outgoing ->
+            fromFactories
+                .find { it.commandName == outgoing.commandName }
+                ?.let {
+                    return@Reply action(this, outgoing)
+                }
+                ?: run {
+                    val factories = fromFactories.joinToString(prefix = "[", postfix = "]") { it.commandName }
+                    throw AssertionError(
+                        "Expected client to send a packet from factories $factories, but client sent ${outgoing.commandName}"
+                    )
+                }
+        })
+    }
+
+    @JvmName("replyPacket")
+    @OverloadResolutionByLambdaReturnType
+    inline infix fun <T : Packet?> On<T>.reply(crossinline lazyIncoming: () -> Packet) {
+        invoke { context.reply(lazyIncoming()) }
+    }
+
+    @JvmName("replyIncomingPacket")
+    @OverloadResolutionByLambdaReturnType
+    inline infix fun <T : Packet?> On<T>.reply(crossinline lazyIncoming: () -> IncomingPacket) {
+        invoke { context.reply(lazyIncoming()) }
+    }
+
+    @JvmName("replyThrowable")
+    @OverloadResolutionByLambdaReturnType
+    inline infix fun <T : Packet?> On<T>.reply(crossinline lazyIncoming: () -> Throwable) {
+        invoke { context.reply(lazyIncoming()) }
+    }
+
+    inline infix fun <T : Packet?> On<T>.ignore(crossinline lazy: () -> Unit) {
+        invoke {
+            lazy()
+            context.ignore()
+        }
+    }
+
+
+    /**
+     * Ignore the next packet.
+     */
+    fun ignore() {
+        decisions.add(PacketReplierDecision.Ignore)
+    }
+
+    internal fun build(): PacketReplier {
+        return PacketReplier { outgoing ->
+            val context = this
+            coroutineScope.launch {
+                when (val decision =
+                    decisions.removeFirstOrNull()
+                        ?: throw AssertionError("Client sent a packet ${outgoing.commandName} while not expected")
+                ) {
+                    is PacketReplierDecision.Ignore -> return@launch
+                    is PacketReplierDecision.Reply -> {
+                        decision.action.invoke(context, outgoing)
+                    }
+                }
+            }
+        }
+    }
+}

+ 6 - 0
mirai-core/src/commonTest/kotlin/network/impl/common/AccountSecretsTest.kt

@@ -17,12 +17,18 @@ import net.mamoe.mirai.internal.network.handler.NetworkHandler
 import net.mamoe.mirai.internal.test.runBlockingUnit
 import net.mamoe.mirai.internal.utils.accountSecretsFile
 import net.mamoe.mirai.utils.DeviceInfo
+import net.mamoe.mirai.utils.SecretsProtection
 import net.mamoe.mirai.utils.getRandomByteArray
 import net.mamoe.mirai.utils.writeBytes
 import kotlin.test.Test
 import kotlin.test.assertEquals
 
 internal class AccountSecretsTest : AbstractCommonNHTest() {
+    init {
+        overrideComponents.remove(AccountSecretsManager)
+        bot.account.accountSecretsKeyBuffer = SecretsProtection.EscapedByteBuffer(ByteArray(16))
+    }
+
     @Test
     fun `can login with no secrets`() = runBlockingUnit {
         val file = bot.configuration.accountSecretsFile()