From e2a99924a4f9ad1c14a7a7f2c939650f8f3f17bb Mon Sep 17 00:00:00 2001 From: Ylian Saint-Hilaire Date: Fri, 2 Jul 2021 12:40:36 -0700 Subject: [PATCH] Fixed TCP tunneling flow control. --- MeshMapper.cs | 41 +++++++++-- WebSocketClient.cs | 168 +++++++++++++++++++++++++-------------------- 2 files changed, 128 insertions(+), 81 deletions(-) diff --git a/MeshMapper.cs b/MeshMapper.cs index 9b71e7d..7f2417c 100644 --- a/MeshMapper.cs +++ b/MeshMapper.cs @@ -223,7 +223,9 @@ namespace MeshCentralRouter wc.onStateChanged += Wc_onStateChanged; wc.onBinaryData += Wc_onBinaryData; wc.onStringData += Wc_onStringData; + wc.onSendOk += Wc_onSendOk; } + private void ConnectWS(UdpClient client, int counter) { webSocketClient wc = new webSocketClient(); @@ -237,6 +239,27 @@ namespace MeshCentralRouter wc.onStateChanged += Wc_onStateChanged; wc.onBinaryData += Wc_onBinaryData; wc.onStringData += Wc_onStringData; + wc.onSendOk += Wc_onSendOk; + } + + private void Wc_onSendOk(webSocketClient sender) + { + if (sender.tag.GetType() == typeof(TcpClient)) + { + // This is a TCP client, if it's not reading now, start reading + if (sender.tag2 == null) return; + object[] args = sender.tag2; + sender.tag2 = null; + MeshMapper mm = (MeshMapper)args[0]; + webSocketClient wc = (webSocketClient)args[1]; + TcpClient client = (TcpClient)args[2]; + byte[] buf = (byte[])args[3]; + try { client.GetStream().BeginRead(buf, 0, buf.Length, new AsyncCallback(ClientEndReadWS), new object[] { mm, wc, client, buf }); } catch (Exception) { } + } + if ((sender.tag.GetType() == typeof(UdpClient)) && (sender.endpoint != null)) + { + // This is a UDP socket, do nothing since it's always reading + } } private void Wc_onStateChanged(webSocketClient sender, webSocketClient.ConnectionStates state) @@ -299,8 +322,7 @@ namespace MeshCentralRouter { case "ping": { - // Send pong back - try { sender.SendString("{\"ctrlChannel\":\"102938\",\"type\":\"ping\"}"); } catch (Exception) { } + // We can't respond to a ping with a pong in this case since it will be relayed and corrupt the data channel. break; } case "pong": @@ -323,7 +345,10 @@ namespace MeshCentralRouter { // Write: WS --> TCP TcpClient client = (TcpClient)sender.tag; - if (client != null) { try { client.GetStream().Write(data, offset, length); } catch (Exception) { } } + if (client != null) { + sender.Pause(); // Pause reading from the websocket until the data is sent on the TCP client + client.GetStream().BeginWrite(data, offset, length, new AsyncCallback(ClientEndWrite), sender); + } } if ((sender.tag.GetType() == typeof(UdpClient)) && (sender.endpoint != null)) { @@ -342,6 +367,11 @@ namespace MeshCentralRouter } } + private void ClientEndWrite(IAsyncResult ar) + { + // TCP Client finished sending data, read more from the websocket + ((webSocketClient)ar.AsyncState).Resume(); + } // Read from the local client private void ClientEndReadWS(IAsyncResult ar) @@ -369,8 +399,8 @@ namespace MeshCentralRouter try { mm.bytesToServer += len; - mm.bytesToServerCompressed += wc.SendBinary(buf, 0, len); // TODO: Do Async - try { client.GetStream().BeginRead(buf, 0, buf.Length, new AsyncCallback(ClientEndReadWS), new object[] { mm, wc, client, buf }); } catch (Exception) { } + wc.tag2 = args; // When the websocket SendOK is triggered, read more data from the TCP client. + mm.bytesToServerCompressed += wc.SendBinary(buf, 0, len); } catch (Exception) { @@ -380,6 +410,7 @@ namespace MeshCentralRouter } else { + Debug("#" + counter + ": ClientEndRead(" + len + ") - Disconnect"); ShutdownClients(client, null, wc, counter); return; } diff --git a/WebSocketClient.cs b/WebSocketClient.cs index 61fd8a9..b504478 100644 --- a/WebSocketClient.cs +++ b/WebSocketClient.cs @@ -63,9 +63,11 @@ namespace MeshCentralRouter private bool readPaused = false; private bool shouldRead = false; private RNGCryptoServiceProvider CryptoRandom = new RNGCryptoServiceProvider(); + private object mainLock = new object(); // Outside variables public object tag = null; + public object[] tag2 = null; public int id = 0; public bool tunneling = false; public IPEndPoint endpoint; @@ -351,17 +353,20 @@ namespace MeshCentralRouter readBuffer = readBuffer2; } - // Receive more data - if (readPaused == false) + lock (mainLock) { - if (wsstream != null) + // Receive more data + if (readPaused == false) { - try { wsstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnTlsDataSink), this); } catch (Exception) { } + if (wsstream != null) + { + try { wsstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnTlsDataSink), this); } catch (Exception) { } + } + } + else + { + shouldRead = true; } - } - else - { - shouldRead = true; } } private void WriteWebSocketAsyncDone(IAsyncResult ar) @@ -609,77 +614,82 @@ namespace MeshCentralRouter // Fragment op code (129 = text, 130 = binary) public int SendFragment(byte[] data, int offset, int len, byte op) { - if (state != ConnectionStates.Connected) return 0; - byte[] buf; - - // If deflate is active, attempt to compress the data here. - if ((deflateMemory != null) && (len > 32) && (AllowCompression)) + lock (mainLock) { - deflateMemory.SetLength(0); - deflateMemory.Write(inflateStart, 0, 14); - DeflateStream deflate = new DeflateStream(deflateMemory, CompressionMode.Compress, true); - deflate.Write(data, offset, len); - deflate.Dispose(); - deflate = null; - if (deflateMemory.Length < len) + if (state != ConnectionStates.Connected) return 0; + byte[] buf; + + // If deflate is active, attempt to compress the data here. + if ((deflateMemory != null) && (len > 32) && (AllowCompression)) + { + deflateMemory.SetLength(0); + deflateMemory.Write(inflateStart, 0, 14); + DeflateStream deflate = new DeflateStream(deflateMemory, CompressionMode.Compress, true); + deflate.Write(data, offset, len); + deflate.Dispose(); + deflate = null; + if (deflateMemory.Length < len) + { + // Use the compressed data + int newlen = (int)deflateMemory.Length; + buf = deflateMemory.GetBuffer(); + len = newlen - 14; + op |= 0x40; // Add compression op + } + else + { + // Don't use the compress data + // Convert the string into a buffer with 4 byte of header space. + buf = new byte[14 + len]; + Array.Copy(data, offset, buf, 14, len); + } + } + else { - // Use the compressed data - int newlen = (int)deflateMemory.Length; - buf = deflateMemory.GetBuffer(); - len = newlen - 14; - op |= 0x40; // Add compression op - } else { - // Don't use the compress data // Convert the string into a buffer with 4 byte of header space. buf = new byte[14 + len]; - Array.Copy(data, offset, buf, 14, len); + if (len > 0) { Array.Copy(data, offset, buf, 14, len); } } - } - else - { - // Convert the string into a buffer with 4 byte of header space. - buf = new byte[14 + len]; - if (len > 0) { Array.Copy(data, offset, buf, 14, len); } - } - // Check that everything is ok - if (len < 0) { Dispose(); return 0; } + // Check that everything is ok + if (len < 0) { Dispose(); return 0; } - // Set the mask to a cryptographic random value and XOR the data - byte[] rand = new byte[4]; - CryptoRandom.GetBytes(rand); - Array.Copy(rand, 0, buf, 10, 4); - for (int x = 0; x < len; x++) { buf[x + 14] ^= rand[x % 4]; } + // Set the mask to a cryptographic random value and XOR the data + byte[] rand = new byte[4]; + CryptoRandom.GetBytes(rand); + Array.Copy(rand, 0, buf, 10, 4); + for (int x = 0; x < len; x++) { buf[x + 14] ^= rand[x % 4]; } - if (len < 126) - { - // Small fragment - buf[8] = op; - buf[9] = (byte)((len & 0x7F) + 128); // Add 128 to indicate the mask is present - SendData(buf, 8, len + 6); - } - else if (len < 65535) - { - // Medium fragment - buf[6] = op; - buf[7] = 126 + 128; // Add 128 to indicate the mask is present - buf[8] = (byte)((len >> 8) & 0xFF); - buf[9] = (byte)(len & 0xFF); - SendData(buf, 6, len + 8); - } - else - { - // Large fragment - buf[0] = op; - buf[1] = 127 + 128; // Add 128 to indicate the mask is present - buf[6] = (byte)((len >> 24) & 0xFF); - buf[7] = (byte)((len >> 16) & 0xFF); - buf[8] = (byte)((len >> 8) & 0xFF); - buf[9] = (byte)(len & 0xFF); - SendData(buf, 0, len + 14); - } + if (len < 126) + { + // Small fragment + buf[8] = op; + buf[9] = (byte)((len & 0x7F) + 128); // Add 128 to indicate the mask is present + SendData(buf, 8, len + 6); + } + else if (len < 65535) + { + // Medium fragment + buf[6] = op; + buf[7] = 126 + 128; // Add 128 to indicate the mask is present + buf[8] = (byte)((len >> 8) & 0xFF); + buf[9] = (byte)(len & 0xFF); + SendData(buf, 6, len + 8); + } + else + { + // Large fragment + buf[0] = op; + buf[1] = 127 + 128; // Add 128 to indicate the mask is present + buf[6] = (byte)((len >> 24) & 0xFF); + buf[7] = (byte)((len >> 16) & 0xFF); + buf[8] = (byte)((len >> 8) & 0xFF); + buf[9] = (byte)(len & 0xFF); + SendData(buf, 0, len + 14); + } - return len; + return len; + } } private void SendData(byte[] buf) { SendData(buf, 0, buf.Length); } @@ -722,17 +732,23 @@ namespace MeshCentralRouter public void Pause() { - readPaused = true; + lock (mainLock) + { + readPaused = true; + } } public void Resume() { - if (readPaused == false) return; - readPaused = false; - if (shouldRead == true) + lock (mainLock) { - shouldRead = false; - try { wsstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnTlsDataSink), this); } catch (Exception) { } + if (readPaused == false) return; + readPaused = false; + if (shouldRead == true) + { + shouldRead = false; + try { wsstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnTlsDataSink), this); } catch (Exception) { } + } } }