Prechádzať zdrojové kódy

Migrate ktor websocket extension

ryoii 2 rokov pred
rodič
commit
42e5281075

+ 9 - 1
mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/reverse/client/WsClient.kt

@@ -21,6 +21,8 @@ import kotlinx.coroutines.launch
 import net.mamoe.mirai.api.http.adapter.reverse.Destination
 import net.mamoe.mirai.api.http.adapter.reverse.ReverseWebsocketAdapterSetting
 import net.mamoe.mirai.api.http.adapter.reverse.handleReverseWs
+import net.mamoe.mirai.api.http.adapter.ws.extension.FrameLogExtension
+import net.mamoe.mirai.api.http.context.MahContextHolder
 import net.mamoe.mirai.utils.MiraiLogger
 import net.mamoe.mirai.utils.warning
 import kotlin.coroutines.CoroutineContext
@@ -35,7 +37,13 @@ class WsClient(private var log: MiraiLogger) : CoroutineScope {
     var bindingSessionKey: String? = null
 
     private val client = HttpClient {
-        install(WebSockets)
+        install(WebSockets) {
+            extensions {
+                if (MahContextHolder.debug) {
+                    install(FrameLogExtension)
+                }
+            }
+        }
     }
 
     private var webSocketSession: DefaultClientWebSocketSession? = null

+ 7 - 20
mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/ws/extension/FrameLogExtension.kt

@@ -2,33 +2,25 @@ package net.mamoe.mirai.api.http.adapter.ws.extension
 
 import io.ktor.util.*
 import io.ktor.websocket.*
-import net.mamoe.mirai.api.http.adapter.internal.serializer.jsonParseOrNull
-import net.mamoe.mirai.api.http.adapter.ws.dto.WsIncoming
 import net.mamoe.mirai.utils.MiraiLogger
 
-class FrameLogExtension(configuration: Configuration) :
-    WebSocketExtension<FrameLogExtension.Configuration> {
+class FrameLogExtension: WebSocketExtension<Unit> {
 
-    private val logger = configuration.logger.value
-    private val enable = configuration.enableAccessLog
+    private val logger = MiraiLogger.Factory.create(FrameLogExtension::class, "MAH Access")
     
     override val factory = FrameLogExtension
     override val protocols = emptyList<WebSocketExtensionHeader>()
 
     override fun clientNegotiation(negotiatedProtocols: List<WebSocketExtensionHeader>): Boolean {
-        
         return true
     }
 
     override fun serverNegotiation(requestedProtocols: List<WebSocketExtensionHeader>): List<WebSocketExtensionHeader> {
-        return emptyList()
+        return listOf(WebSocketExtensionHeader("frame-log", emptyList()))
     }
 
     override fun processIncomingFrame(frame: Frame): Frame {
-        if (enable) {
-            val commandWrapper = String(frame.data).jsonParseOrNull<WsIncoming>() ?: return frame
-            logger.debug("[incoming] $commandWrapper")
-        }
+        logger.debug("[incoming] ${(frame as Frame.Text).readText()})")
         return frame
     }
 
@@ -36,20 +28,15 @@ class FrameLogExtension(configuration: Configuration) :
         return frame
     }
 
-    class Configuration {
-        var logger = lazy { MiraiLogger.Factory.create(FrameLogExtension::class, "MAH Access") }
-        var enableAccessLog = false
-    }
-
-    companion object : WebSocketExtensionFactory<Configuration, FrameLogExtension> {
+    companion object : WebSocketExtensionFactory<Unit, FrameLogExtension> {
         override val key = AttributeKey<FrameLogExtension>("FRAME LOG")
         
         override val rsv1: Boolean = false
         override val rsv2: Boolean = false
         override val rsv3: Boolean = false
         
-        override fun install(config: Configuration.() -> Unit): FrameLogExtension {
-            return FrameLogExtension(Configuration().apply(config))
+        override fun install(config: Unit.() -> Unit): FrameLogExtension {
+            return FrameLogExtension()
         }
     }
 

+ 4 - 2
mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/ws/router/base.kt

@@ -27,8 +27,10 @@ import net.mamoe.mirai.api.http.context.MahContextHolder
  */
 fun Application.websocketRouteModule(wsAdapter: WebsocketAdapter) {
     install(WebSockets) {
-        extensions { 
-            install(FrameLogExtension) { enableAccessLog = MahContextHolder.debug }
+        extensions {
+            if (MahContextHolder.debug) {
+                install(FrameLogExtension)
+            }
         }
     }
     wsRouter(wsAdapter)

+ 0 - 12
mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/ws/router/utils.kt

@@ -9,7 +9,6 @@
 
 package net.mamoe.mirai.api.http.adapter.ws.router
 
-import io.ktor.server.application.*
 import io.ktor.server.routing.*
 import io.ktor.server.websocket.*
 import io.ktor.util.*
@@ -19,7 +18,6 @@ import net.mamoe.mirai.api.http.adapter.common.StateCode
 import net.mamoe.mirai.api.http.adapter.internal.serializer.toJson
 import net.mamoe.mirai.api.http.adapter.internal.serializer.toJsonElement
 import net.mamoe.mirai.api.http.adapter.ws.dto.WsOutgoing
-import net.mamoe.mirai.api.http.adapter.ws.extension.FrameLogExtension
 import net.mamoe.mirai.api.http.context.MahContextHolder
 
 
@@ -33,9 +31,6 @@ internal inline fun Route.miraiWebsocket(
         val sessionKey = call.request.headers["sessionKey"] ?: call.parameters["sessionKey"]
         val qq = (call.request.headers["qq"] ?: call.parameters["qq"])?.toLongOrNull()
 
-        // 注入无协商的扩展
-        installExtension(FrameLogExtension)
-
         // 校验
         if (MahContextHolder.enableVerify && MahContextHolder.sessionManager.verifyKey != verifyKey) {
             closeWithCode(StateCode.AuthKeyFail)
@@ -93,10 +88,3 @@ internal suspend fun DefaultWebSocketServerSession.closeWithCode(code: StateCode
     ))
     close(CloseReason(CloseReason.Codes.NORMAL, code.msg))
 }
-
-
-internal fun <T: WebSocketExtension<*>> WebSocketServerSession.installExtension(factory: WebSocketExtensionFactory<*, T>) {
-    application.plugin(WebSockets).extensionsConfig.build().find { it.factory.key == factory.key }?.let {
-        (extensions as MutableList<WebSocketExtension<*>>).add(it)
-    }
-}

+ 49 - 0
mirai-api-http/src/test/kotlin/net/mamoe/mirai/api/http/adapter/ws/extension/FrameLogExtensionTest.kt

@@ -0,0 +1,49 @@
+/*
+ * Copyright 2023 Mamoe Technologies and contributors.
+ *
+ * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
+ * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
+ *
+ * https://github.com/mamoe/mirai/blob/master/LICENSE
+ */
+
+package net.mamoe.mirai.api.http.adapter.ws.extension
+
+import io.ktor.client.plugins.websocket.*
+import io.ktor.server.testing.*
+import io.ktor.server.websocket.*
+import io.ktor.server.websocket.WebSockets
+import io.ktor.websocket.*
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+
+class FrameLogExtensionTest {
+
+    @Test
+    fun testFrameLogExtension() = testApplication {
+        install(WebSockets) {
+            extensions {
+                install(FrameLogExtension)
+            }
+        }
+
+        routing {
+            webSocket("/echo") {
+                assertNotNull(extensionOrNull(FrameLogExtension))
+                for (frame in incoming) {
+                    send(frame)
+                }
+            }
+        }
+
+        val wsClient = createClient { install(io.ktor.client.plugins.websocket.WebSockets) }
+        wsClient.ws("/echo") {
+            outgoing.send(Frame.Text("Hello"))
+
+            val receive = incoming.receive()
+            assertEquals(FrameType.TEXT, receive.frameType)
+            assertEquals("Hello", (receive as Frame.Text).readText())
+        }
+    }
+}