Ver código fonte

Improve PlatformSocket:
- Close channels on close
- Interrupitable `connect`
- Cancel `raed` on continuation cancellation

Him188 5 anos atrás
pai
commit
1963b3768d

+ 7 - 6
mirai-core/src/commonMain/kotlin/utils/PlatformSocket.kt

@@ -12,7 +12,6 @@ package net.mamoe.mirai.internal.utils
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.runInterruptible
 import kotlinx.coroutines.runInterruptible
 import kotlinx.coroutines.suspendCancellableCoroutine
 import kotlinx.coroutines.suspendCancellableCoroutine
-import kotlinx.coroutines.withContext
 import kotlinx.io.core.ByteReadPacket
 import kotlinx.io.core.ByteReadPacket
 import kotlinx.io.core.Closeable
 import kotlinx.io.core.Closeable
 import kotlinx.io.streams.readPacketAtMost
 import kotlinx.io.streams.readPacketAtMost
@@ -42,6 +41,8 @@ internal class PlatformSocket : Closeable {
             socket.close()
             socket.close()
         }
         }
         thread.shutdownNow()
         thread.shutdownNow()
+        kotlin.runCatching { writeChannel.close() }
+        kotlin.runCatching { readChannel.close() }
     }
     }
 
 
     @PublishedApi
     @PublishedApi
@@ -51,7 +52,6 @@ internal class PlatformSocket : Closeable {
     internal lateinit var readChannel: BufferedInputStream
     internal lateinit var readChannel: BufferedInputStream
 
 
     suspend fun send(packet: ByteArray, offset: Int, length: Int) {
     suspend fun send(packet: ByteArray, offset: Int, length: Int) {
-        @Suppress("BlockingMethodInNonBlockingContext")
         runInterruptible(Dispatchers.IO) {
         runInterruptible(Dispatchers.IO) {
             writeChannel.write(packet, offset, length)
             writeChannel.write(packet, offset, length)
             writeChannel.flush()
             writeChannel.flush()
@@ -62,7 +62,6 @@ internal class PlatformSocket : Closeable {
      * @throws SendPacketInternalException
      * @throws SendPacketInternalException
      */
      */
     suspend fun send(packet: ByteReadPacket) {
     suspend fun send(packet: ByteReadPacket) {
-        @Suppress("BlockingMethodInNonBlockingContext")
         runInterruptible(Dispatchers.IO) {
         runInterruptible(Dispatchers.IO) {
             try {
             try {
                 writeChannel.writePacket(packet)
                 writeChannel.writePacket(packet)
@@ -79,18 +78,20 @@ internal class PlatformSocket : Closeable {
      * @throws ReadPacketInternalException
      * @throws ReadPacketInternalException
      */
      */
     suspend fun read(): ByteReadPacket = suspendCancellableCoroutine { cont ->
     suspend fun read(): ByteReadPacket = suspendCancellableCoroutine { cont ->
-        thread.submit {
+        val task = thread.submit {
             kotlin.runCatching {
             kotlin.runCatching {
                 readChannel.readPacketAtMost(Long.MAX_VALUE)
                 readChannel.readPacketAtMost(Long.MAX_VALUE)
             }.let {
             }.let {
                 cont.resumeWith(it)
                 cont.resumeWith(it)
             }
             }
         }
         }
+        cont.invokeOnCancellation {
+            kotlin.runCatching { task.cancel(true) }
+        }
     }
     }
 
 
     suspend fun connect(serverHost: String, serverPort: Int) {
     suspend fun connect(serverHost: String, serverPort: Int) {
-        @Suppress("BlockingMethodInNonBlockingContext")
-        withContext(Dispatchers.IO) {
+        runInterruptible(Dispatchers.IO) {
             socket = Socket(serverHost, serverPort)
             socket = Socket(serverHost, serverPort)
             readChannel = socket.getInputStream().buffered()
             readChannel = socket.getInputStream().buffered()
             writeChannel = socket.getOutputStream().buffered()
             writeChannel = socket.getOutputStream().buffered()