Selaa lähdekoodia

Handle http request missing fields

ryoii 2 vuotta sitten
vanhempi
sitoutus
5a81385384

+ 5 - 7
mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/http/router/dsl.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 Mamoe Technologies and contributors.
+ * Copyright 2023 Mamoe Technologies and contributors.
  *
  * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
  * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
@@ -29,7 +29,6 @@ import net.mamoe.mirai.api.http.adapter.internal.dto.AuthedDTO
 import net.mamoe.mirai.api.http.adapter.internal.dto.BindDTO
 import net.mamoe.mirai.api.http.adapter.internal.dto.DTO
 import net.mamoe.mirai.api.http.adapter.internal.dto.VerifyDTO
-import net.mamoe.mirai.api.http.adapter.internal.serializer.jsonParseOrNull
 import net.mamoe.mirai.api.http.adapter.internal.serializer.toJson
 import net.mamoe.mirai.api.http.context.MahContext
 import net.mamoe.mirai.api.http.context.MahContextHolder
@@ -74,7 +73,7 @@ internal inline fun Route.routeWithHandle(path: String, method: HttpMethod, cros
 @KtorDsl
 internal inline fun Route.httpVerify(path: String, crossinline body: Strategy<VerifyDTO>) =
     routeWithHandle(path, HttpMethod.Post) {
-        val dto = context.receiveDTO<VerifyDTO>() ?: throw IllegalParamException()
+        val dto = context.receiveDTO<VerifyDTO>()
         this.body(dto)
     }
 
@@ -82,7 +81,7 @@ internal inline fun Route.httpVerify(path: String, crossinline body: Strategy<Ve
 @KtorDsl
 internal inline fun Route.httpBind(path: String, crossinline body: Strategy<BindDTO>) =
     routeWithHandle(path, HttpMethod.Post) {
-        val dto = context.receiveDTO<BindDTO>() ?: throw IllegalParamException()
+        val dto = context.receiveDTO<BindDTO>()
         body(dto)
     }
 
@@ -99,7 +98,7 @@ internal inline fun <reified T : AuthedDTO> Route.httpAuthedPost(
     path: String,
     crossinline body: Strategy<T>
 ) = routeWithHandle(path, HttpMethod.Post) {
-    val dto = context.receiveDTO<T>() ?: throw IllegalParamException()
+    val dto = context.receiveDTO<T>()
 
     getAuthedSession(dto.sessionKey).also { dto.session = it }
     this.body(dto)
@@ -168,8 +167,7 @@ internal suspend fun ApplicationCall.respondJson(json: String, status: HttpStatu
 /**
  * 接收 http body 指定类型 [T] 的 [DTO]
  */
-internal suspend inline fun <reified T : DTO> ApplicationCall.receiveDTO(): T? =
-    receive<String>().jsonParseOrNull()
+internal suspend inline fun <reified T : DTO> ApplicationCall.receiveDTO(): T = receive<T>()
 
 /**
  * 接收 http multi part 值类型

+ 24 - 1
mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/internal/handler/exceptionHandle.kt

@@ -9,6 +9,9 @@
 
 package net.mamoe.mirai.api.http.adapter.internal.handler
 
+import io.ktor.server.plugins.*
+import kotlinx.serialization.ExperimentalSerializationApi
+import kotlinx.serialization.MissingFieldException
 import net.mamoe.mirai.api.http.adapter.common.*
 import net.mamoe.mirai.contact.BotIsBeingMutedException
 import net.mamoe.mirai.contact.MessageTooLargeException
@@ -35,6 +38,26 @@ internal fun Throwable.toStateCode(): StateCode = when (this) {
     is PermissionDeniedException -> StateCode.PermissionDenied
     is BotIsBeingMutedException -> StateCode.BotMuted
     is MessageTooLargeException -> StateCode.MessageTooLarge
+    is BadRequestException -> StateCode.IllegalAccess(findMissingFiled() ?: this.localizedMessage ?: "")
     is IllegalAccessException -> StateCode.IllegalAccess(this.message)
     else -> StateCode.InternalError(this.localizedMessage ?: "", this)
-}
+}
+
+@OptIn(ExperimentalSerializationApi::class)
+internal fun Throwable.findMissingFiled(): String? {
+    if (rootCause is MissingFieldException) {
+        return (rootCause as MissingFieldException)
+            .missingFields
+            .joinToString(prefix = "参数错误,缺少字段: ", separator = ", ")
+    }
+    return null
+}
+
+private val Throwable.rootCause: Throwable?
+    get() {
+        var rootCause: Throwable? = this
+        while (rootCause?.cause != null) {
+            rootCause = rootCause.cause
+        }
+        return rootCause
+    }

+ 153 - 0
mirai-api-http/src/test/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/ContentNegotiationJsonTest.kt

@@ -0,0 +1,153 @@
+/*
+ * Copyright 2023 Mamoe Technologies and contributors.
+ *
+ * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
+ * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
+ *
+ * https://github.com/mamoe/mirai/blob/master/LICENSE
+ */
+package net.mamoe.mirai.api.http.adapter.http.plugin
+
+import io.ktor.client.request.*
+import io.ktor.client.statement.*
+import io.ktor.http.*
+import io.ktor.serialization.kotlinx.json.*
+import io.ktor.server.application.*
+import io.ktor.server.plugins.contentnegotiation.*
+import io.ktor.server.request.*
+import io.ktor.server.response.*
+import io.ktor.server.routing.*
+import io.ktor.server.testing.*
+import kotlinx.serialization.Serializable
+import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.LongTargetDTO
+import net.mamoe.mirai.api.http.adapter.internal.serializer.BuiltinJsonSerializer
+import kotlin.test.Test
+import kotlin.test.assertEquals
+
+class ContentNegotiationJsonTest {
+
+    @Test
+    fun testNormalConvert() = testApplication {
+        install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
+
+        val content = """{"target": 123}"""
+
+        routing {
+            post("/test") {
+                val dto = call.receive<LongTargetDTO>()
+                assertEquals(123, dto.target)
+
+                call.respond(HttpStatusCode.OK)
+            }
+        }
+
+        client.post("/test") {
+            contentType(ContentType.Application.Json)
+            setBody(content)
+        }.also {
+            assertEquals(HttpStatusCode.OK, it.status)
+        }
+    }
+
+    @Test
+    fun testMissingFieldConvert() = testApplication {
+        install(GlobalExceptionHandler)
+        install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
+
+        routing {
+            post("/test") {
+                call.receive<LongTargetDTO>()
+            }
+        }
+
+        client.post("/test") {
+            contentType(ContentType.Application.Json)
+            setBody("{}")
+        }.also {
+            assertEquals(HttpStatusCode.OK, it.status)
+            assertEquals("""{"code":400,"msg":"参数错误,缺少字段: target"}""", it.bodyAsText())
+        }
+    }
+
+    @Serializable
+    internal data class NestDTO(val nest: LongTargetDTO)
+
+    @Test
+    fun testNestMissingFieldConvert() = testApplication {
+        install(GlobalExceptionHandler)
+        install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
+
+        routing {
+            post("/test") {
+                call.receive<NestDTO>()
+            }
+        }
+
+        client.post("/test") {
+            contentType(ContentType.Application.Json)
+            setBody("{}")
+        }.also {
+            assertEquals(HttpStatusCode.OK, it.status)
+            assertEquals("""{"code":400,"msg":"参数错误,缺少字段: nest"}""", it.bodyAsText())
+        }
+    }
+
+    @Test
+    fun testErrorConvert() = testApplication {
+        install(GlobalExceptionHandler)
+        install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
+
+        routing {
+            post("/test") {
+                call.receive<LongTargetDTO>()
+            }
+        }
+
+        client.post("/test") {
+            contentType(ContentType.Application.Json)
+            setBody("---")
+        }.also {
+            assertEquals(HttpStatusCode.OK, it.status)
+            assertEquals("""{"code":400,"msg":"Illegal input"}""", it.bodyAsText())
+        }
+    }
+
+    @Test
+    fun testNullValueContentConvert() = testApplication {
+        install(GlobalExceptionHandler)
+        install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
+
+        routing {
+            post("/test") {
+                call.receive<NestDTO>()
+            }
+        }
+
+        client.post("/test") {
+            contentType(ContentType.Application.Json)
+            setBody("""{"nest": null}""")
+        }.also {
+            assertEquals(HttpStatusCode.OK, it.status)
+            assertEquals("""{"code":400,"msg":"Illegal input"}""", it.bodyAsText())
+        }
+    }
+
+    @Test
+    fun testEmptyContentConvert() = testApplication {
+        install(GlobalExceptionHandler)
+        install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }
+
+        routing {
+            post("/test") {
+                call.receive<LongTargetDTO>()
+            }
+        }
+
+        client.post("/test") {
+            contentType(ContentType.Application.Json)
+        }.also {
+            assertEquals(HttpStatusCode.OK, it.status)
+            assertEquals("""{"code":400,"msg":"Illegal input"}""", it.bodyAsText())
+        }
+    }
+}