Procházet zdrojové kódy

Fix `ByteArray.loadAs`

Him188 před 4 roky
rodič
revize
be4423c993

+ 20 - 18
mirai-core/src/commonMain/kotlin/utils/io/serialization/utils.kt

@@ -71,8 +71,8 @@ internal fun <T : JceStruct> BytePacketBuilder.writeJceRequestPacket(
         version = version.toShort(),
         servantName = servantName,
         funcName = funcName,
-        sBuffer = jceRequestSBuffer(name, serializer, body)
-    )
+        sBuffer = jceRequestSBuffer(name, serializer, body),
+    ),
 )
 
 /**
@@ -110,16 +110,18 @@ private fun <K, V> Map<K, V>.singleValue(): V = this.entries.single().value
 internal fun <R> ByteReadPacket.decodeUniRequestPacketAndDeserialize(name: String? = null, block: (ByteArray) -> R): R {
     val request = this.readJceStruct(RequestPacket.serializer())
 
-    return block(if (name == null) when (request.version?.toInt() ?: 3) {
-        2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.singleValue().singleValue()
-        3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.singleValue()
-        else -> error("unsupported version ${request.version}")
-    } else when (request.version?.toInt() ?: 3) {
-        2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.getOrElse(name) { error("cannot find $name") }
-            .singleValue()
-        3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.getOrElse(name) { error("cannot find $name") }
-        else -> error("unsupported version ${request.version}")
-    })
+    return block(
+        if (name == null) when (request.version?.toInt() ?: 3) {
+            2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.singleValue().singleValue()
+            3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.singleValue()
+            else -> error("unsupported version ${request.version}")
+        } else when (request.version?.toInt() ?: 3) {
+            2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.getOrElse(name) { error("cannot find $name") }
+                .singleValue()
+            3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.getOrElse(name) { error("cannot find $name") }
+            else -> error("unsupported version ${request.version}")
+        },
+    )
 }
 
 internal fun <T : JceStruct> T.toByteArray(
@@ -143,8 +145,8 @@ internal fun <T : ProtoBuf> BytePacketBuilder.writeOidb(
             command = command,
             serviceType = serviceType,
             clientVersion = clientVersion,
-            bodybuffer = v.toByteArray(serializer)
-        )
+            bodybuffer = v.toByteArray(serializer),
+        ),
     )
 }
 
@@ -160,8 +162,8 @@ internal fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>):
  */
 internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>, offset: Int = 0): T {
     if (offset != 0) {
-        require(this.size >= offset) { "size < offset" }
-        return this.copyOfRange(offset, this.lastIndex).loadAs(deserializer)
+        require(offset in offset..this.lastIndex) { "invalid offset: $offset" }
+        return this.copyOfRange(offset, this.size).loadAs(deserializer)
     }
     return KtProtoBuf.decodeFromByteArray(deserializer, this)
 }
@@ -239,8 +241,8 @@ internal fun <T : JceStruct> jceRequestSBuffer(
 ): ByteArray {
     return RequestDataVersion3(
         mapOf(
-            name to JCE_STRUCT_HEAD_OF_TAG_0 + jceStruct.toByteArray(serializer) + JCE_STRUCT_TAIL_OF_TAG_0
-        )
+            name to JCE_STRUCT_HEAD_OF_TAG_0 + jceStruct.toByteArray(serializer) + JCE_STRUCT_TAIL_OF_TAG_0,
+        ),
     ).toByteArray(RequestDataVersion3.serializer())
 }