diff --git a/FileViewer.cs b/FileViewer.cs index b872d11..b6fd7fd 100644 --- a/FileViewer.cs +++ b/FileViewer.cs @@ -327,7 +327,8 @@ namespace MeshCentralRouter wc.onStateChanged += Wc_onStateChanged; wc.onBinaryData += Wc_onBinaryData; wc.onStringData += Wc_onStringData; - wc.Start(u, server.wshash); + wc.TLSCertCheck = webSocketClient.TLSCertificateCheck.Fingerprint; + wc.Start(u, server.wshash, null); } private void Wc_onStateChanged(webSocketClient sender, webSocketClient.ConnectionStates wsstate) @@ -1250,7 +1251,7 @@ namespace MeshCentralRouter localFilePath = (string)uploadFileArray[uploadFileArrayPtr]; localFileName = Path.GetFileName(localFilePath); } - try { uploadFileStream = File.OpenRead(localFilePath); } catch (Exception ex) + try { uploadFileStream = File.OpenRead(localFilePath); } catch (Exception) { // Display the error if (transferStatusForm != null) { transferStatusForm.addErrorMessage(String.Format(Translate.T(Properties.Resources.UnableToOpenFileX), localFileName)); } diff --git a/KVMViewer.cs b/KVMViewer.cs index 0233f7d..120ce9c 100644 --- a/KVMViewer.cs +++ b/KVMViewer.cs @@ -184,7 +184,8 @@ namespace MeshCentralRouter wc.onStateChanged += Wc_onStateChanged; wc.onBinaryData += Wc_onBinaryData; wc.onStringData += Wc_onStringData; - wc.Start(u, server.wshash); + wc.TLSCertCheck = webSocketClient.TLSCertificateCheck.Fingerprint; + wc.Start(u, server.wshash, null); } private void Wc_onStateChanged(webSocketClient sender, webSocketClient.ConnectionStates wsstate) diff --git a/MainForm.cs b/MainForm.cs index 4f710f2..172920d 100644 --- a/MainForm.cs +++ b/MainForm.cs @@ -170,11 +170,32 @@ namespace MeshCentralRouter [DllImport("user32.dll", CharSet = CharSet.Auto)] private static extern Int32 SendMessage(IntPtr hWnd, int msg, int wParam, [MarshalAs(UnmanagedType.LPWStr)]string lParam); + private bool RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors) + { + if (meshcentral.ignoreCert) return true; + if (meshcentral.connectionState < 2) + { + // Normal certificate check + if (chain.Build(new X509Certificate2(certificate)) == true) { meshcentral.certHash = webSocketClient.GetMeshKeyHash(certificate); return true; } + if ((meshcentral.okCertHash != null) && ((meshcentral.okCertHash == certificate.GetCertHashString()) || (meshcentral.okCertHash == webSocketClient.GetMeshKeyHash(certificate)) || (meshcentral.okCertHash == webSocketClient.GetMeshCertHash(certificate)))) { meshcentral.certHash = webSocketClient.GetMeshKeyHash(certificate); return true; } + if ((meshcentral.okCertHash2 != null) && ((meshcentral.okCertHash2 == certificate.GetCertHashString()) || (meshcentral.okCertHash2 == webSocketClient.GetMeshKeyHash(certificate)) || (meshcentral.okCertHash2 == webSocketClient.GetMeshCertHash(certificate)))) { meshcentral.certHash = webSocketClient.GetMeshKeyHash(certificate); return true; } + meshcentral.certHash = null; + meshcentral.disconnectMsg = "cert"; + meshcentral.disconnectCert = new X509Certificate2(certificate); + } + else + { + if ((meshcentral.certHash != null) && ((meshcentral.certHash == certificate.GetCertHashString()) || (meshcentral.certHash == webSocketClient.GetMeshKeyHash(certificate)) || (meshcentral.certHash == webSocketClient.GetMeshCertHash(certificate)))) { return true; } + } + return false; + } + public MainForm(string[] args) { // Set TLS 1.2 ServicePointManager.SecurityProtocol = SecurityProtocolType.Tls12; + ServicePointManager.ServerCertificateValidationCallback = new System.Net.Security.RemoteCertificateValidationCallback(RemoteCertificateValidationCallback); this.args = args; InitializeComponent(); @@ -215,6 +236,7 @@ namespace MeshCentralRouter if (arg.ToLower() == "-all") { inaddrany = true; } if (arg.ToLower() == "-inaddrany") { inaddrany = true; } if (arg.ToLower() == "-tray") { notifyIcon.Visible = true; this.ShowInTaskbar = false; this.MinimizeBox = false; } + if (arg.ToLower() == "-nonative") { webSocketClient.nativeWebSocketFirst = false; } if (arg.Length > 6 && arg.Substring(0, 6).ToLower() == "-host:") { serverNameComboBox.Text = arg.Substring(6); argflags |= 1; } if (arg.Length > 6 && arg.Substring(0, 6).ToLower() == "-user:") { userNameTextBox.Text = arg.Substring(6); argflags |= 2; } if (arg.Length > 6 && arg.Substring(0, 6).ToLower() == "-pass:") { passwordTextBox.Text = arg.Substring(6); argflags |= 4; } diff --git a/MeshCentralRouter.csproj b/MeshCentralRouter.csproj index 7ba5a07..f327c63 100644 --- a/MeshCentralRouter.csproj +++ b/MeshCentralRouter.csproj @@ -10,7 +10,7 @@ Properties MeshCentralRouter MeshCentralRouter - v4.5 + v4.7.2 512 MeshCentralRouter.Program MeshServer.ico diff --git a/MeshCentralServer.cs b/MeshCentralServer.cs index 3e40085..ecd131d 100644 --- a/MeshCentralServer.cs +++ b/MeshCentralServer.cs @@ -17,11 +17,7 @@ limitations under the License. using System; using System.IO; using System.Web; -using System.Text; using System.Collections; -using System.Net.Sockets; -using System.Net.Security; -using System.Windows.Forms; using System.Collections.Generic; using System.Security.Cryptography; using System.Deployment.Application; @@ -39,7 +35,7 @@ namespace MeshCentralRouter private string user = null; private string pass = null; private string token = null; - private xwebclient wc = null; + private webSocketClient wc = null; //private System.Timers.Timer procTimer = new System.Timers.Timer(5000); private int constate = 0; public Dictionary nodes = null; @@ -70,6 +66,8 @@ namespace MeshCentralRouter public int features = 0; // Bit flags of server features public int features2 = 0; // Bit flags of server features + public int connectionState { get { return constate; } } + // Mesh Rights /* const MESHRIGHT_EDITMESH = 1; @@ -146,13 +144,25 @@ namespace MeshCentralRouter this.token = token; this.wsurl = wsurl; - wc = new xwebclient(); + // Setup extra headers if needed + Dictionary extraHeaders = new Dictionary(); + if (user != null && pass != null && token != null) { + extraHeaders.Add("x-meshauth", Base64Encode(user) + "," + Base64Encode(pass) + "," + Base64Encode(token)); + } else if (user != null && pass != null) { + extraHeaders.Add("x-meshauth", Base64Encode(user) + "," + Base64Encode(pass)); + } + + wc = new webSocketClient(); + wc.extraHeaders = extraHeaders; + wc.onStateChanged += new webSocketClient.onStateChangedHandler(changeStateEx); + wc.onStringData += new webSocketClient.onStringDataHandler(processServerData); //Debug("#" + counter + ": Connecting web socket to: " + wsurl.ToString()); - wc.Start(this, wsurl, user, pass, token, wshash); + wc.TLSCertCheck = webSocketClient.TLSCertificateCheck.Verify; + wc.Start(wsurl, okCertHash, okCertHash2); if (debug || tlsdump) { try { File.AppendAllText("debug.log", "Connect-" + wsurl + "\r\n"); } catch (Exception) { } } - wc.xdebug = debug; - wc.xtlsdump = tlsdump; - wc.xignoreCert = ignoreCert; + wc.debug = debug; + wc.tlsdump = tlsdump; + wc.TLSCertCheck = (ignoreCert) ? webSocketClient.TLSCertificateCheck.Ignore : webSocketClient.TLSCertificateCheck.Verify; } public void disconnect() @@ -170,7 +180,7 @@ namespace MeshCentralRouter if (wc != null) { if (debug) { try { File.AppendAllText("debug.log", "sendCommand: " + cmd + "\r\n"); } catch (Exception) { } } - wc.WriteStringWebSocket(cmd); + wc.SendString(cmd); } } @@ -178,8 +188,8 @@ namespace MeshCentralRouter { if (wc != null) { if (debug) { try { File.AppendAllText("debug.log", "RefreshCookies\r\n"); } catch (Exception) { } } - wc.WriteStringWebSocket("{\"action\":\"authcookie\"}"); - wc.WriteStringWebSocket("{\"action\":\"logincookie\"}"); + wc.SendString("{\"action\":\"authcookie\"}"); + wc.SendString("{\"action\":\"logincookie\"}"); } } @@ -188,11 +198,11 @@ namespace MeshCentralRouter if (wc != null) { if (debug) { try { File.AppendAllText("debug.log", "SetRdpPort\r\n"); } catch (Exception) { } } - wc.WriteStringWebSocket("{\"action\":\"changedevice\",\"nodeid\":\"" + node.nodeid + "\",\"rdpport\":" + port + "}"); + wc.SendString("{\"action\":\"changedevice\",\"nodeid\":\"" + node.nodeid + "\",\"rdpport\":" + port + "}"); } } - public void processServerData(string data) + public void processServerData(webSocketClient sender, string data, int orglen) { if (debug) { try { File.AppendAllText("debug.log", "ServerData-" + data + "\r\n"); } catch (Exception) { } } @@ -223,7 +233,7 @@ namespace MeshCentralRouter case "ping": { // Send pong back - if (wc != null) { wc.WriteStringWebSocket("{\"action\":\"pong\"}"); } + if (wc != null) { wc.SendString("{\"action\":\"pong\"}"); } break; } case "close": @@ -245,12 +255,12 @@ namespace MeshCentralRouter if (serverinfo.ContainsKey("features2") && (serverinfo["features2"].GetType() == typeof(int))) { features2 = (int)serverinfo["features2"]; } // Ask for a lot of things from the server - wc.WriteStringWebSocket("{\"action\":\"usergroups\"}"); - wc.WriteStringWebSocket("{\"action\":\"meshes\"}"); - wc.WriteStringWebSocket("{\"action\":\"nodes\"}"); - wc.WriteStringWebSocket("{\"action\":\"authcookie\"}"); - wc.WriteStringWebSocket("{\"action\":\"logincookie\"}"); - wc.WriteStringWebSocket("{\"action\":\"meshToolInfo\",\"name\":\"MeshCentralRouter\"}"); + wc.SendString("{\"action\":\"usergroups\"}"); + wc.SendString("{\"action\":\"meshes\"}"); + wc.SendString("{\"action\":\"nodes\"}"); + wc.SendString("{\"action\":\"authcookie\"}"); + wc.SendString("{\"action\":\"logincookie\"}"); + wc.SendString("{\"action\":\"meshToolInfo\",\"name\":\"MeshCentralRouter\"}"); break; } case "authcookie": @@ -258,6 +268,12 @@ namespace MeshCentralRouter authCookie = jsonAction["cookie"].ToString(); rauthCookie = jsonAction["rcookie"].ToString(); changeState(2); + + if (sender.RemoteCertificate != null) + { + certHash = webSocketClient.GetMeshCertHash(new X509Certificate2(sender.RemoteCertificate)); + } + break; } case "logincookie": @@ -362,7 +378,7 @@ namespace MeshCentralRouter mesh.links = newlinks; meshes[meshid] = mesh; } - wc.WriteStringWebSocket("{\"action\":\"nodes\"}"); + wc.SendString("{\"action\":\"nodes\"}"); if ((onNodesChanged != null) && (nodes != null)) onNodesChanged(false); break; } @@ -427,7 +443,7 @@ namespace MeshCentralRouter } else { - wc.WriteStringWebSocket("{\"action\":\"nodes\"}"); + wc.SendString("{\"action\":\"nodes\"}"); } break; } @@ -449,8 +465,8 @@ namespace MeshCentralRouter } case "usergroupchange": { - wc.WriteStringWebSocket("{\"action\":\"usergroups\"}"); - wc.WriteStringWebSocket("{\"action\":\"nodes\"}"); + wc.SendString("{\"action\":\"usergroups\"}"); + wc.SendString("{\"action\":\"nodes\"}"); break; } } @@ -697,6 +713,16 @@ namespace MeshCentralRouter public event onStateChangedHandler onStateChanged; public void changeState(int newState) { if (constate != newState) { constate = newState; if (onStateChanged != null) { onStateChanged(constate); } } } + private void changeStateEx(webSocketClient sender, webSocketClient.ConnectionStates newState) + { + if (newState == webSocketClient.ConnectionStates.Disconnected) { + if (sender.failedTlsCert != null) { certHash = null; disconnectMsg = "cert"; disconnectCert = sender.failedTlsCert; } + changeState(0); + } + if (newState == webSocketClient.ConnectionStates.Connecting) { changeState(1); } + if (newState == webSocketClient.ConnectionStates.Connected) { } + } + public delegate void onNodeListChangedHandler(bool fullRefresh); public event onNodeListChangedHandler onNodesChanged; public delegate void onLoginTokenChangedHandler(); @@ -708,431 +734,16 @@ namespace MeshCentralRouter public delegate void toolUpdateHandler(string url, string hash, int size, string serverhash); public event toolUpdateHandler onToolUpdate; - public class xwebclient : IDisposable + public string Base64Encode(string plainText) { - private MeshCentralServer parent = null; - private TcpClient wsclient = null; - private SslStream wsstream = null; - private NetworkStream wsrawstream = null; - private int state = 0; - private Uri url = null; - private byte[] readBuffer = new Byte[500]; - private int readBufferLen = 0; - private int accopcodes = 0; - private bool accmask = false; - private int acclen = 0; - private bool proxyInUse = false; - private string user = null; - private string pass = null; - private string token = null; - public bool xdebug = false; - public bool xtlsdump = false; - public bool xignoreCert = false; - - public void Dispose() { - try { wsstream.Close(); } catch (Exception) { } - try { wsstream.Dispose(); } catch (Exception) { } - wsstream = null; - wsclient = null; - state = -1; - parent.changeState(0); - parent.wshash = null; - } - - public void Debug(string msg) { if (xdebug) { try { File.AppendAllText("debug.log", "Debug-" + msg + "\r\n"); } catch (Exception) { } } } - public void TlsDump(string direction, byte[] data, int offset, int len) { if (xtlsdump) { try { File.AppendAllText("debug.log", direction + ": " + BitConverter.ToString(data, offset, len).Replace("-", string.Empty) + "\r\n"); } catch (Exception) { } } } - - public bool Start(MeshCentralServer parent, Uri url, string user, string pass, string token, string fingerprint) - { - if (state != 0) return false; - parent.changeState(1); - state = 1; - this.parent = parent; - this.url = url; - this.user = user; - this.pass = pass; - this.token = token; - Uri proxyUri = null; - - // Check if we need to use a HTTP proxy (Auto-proxy way) - try { - RegistryKey registryKey = Registry.CurrentUser.OpenSubKey("Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings", true); - Object x = registryKey.GetValue("AutoConfigURL", null); - if ((x != null) && (x.GetType() == typeof(string))) { - string proxyStr = GetProxyForUrlUsingPac("http" + ((url.Port == 80) ? "" : "s") + "://" + url.Host + ":" + url.Port, x.ToString()); - if (proxyStr != null) { proxyUri = new Uri("http://" + proxyStr); } - } - } catch (Exception) { proxyUri = null; } - - // Check if we need to use a HTTP proxy (Normal way) - if (proxyUri == null) { - var proxy = System.Net.HttpWebRequest.GetSystemWebProxy(); - proxyUri = proxy.GetProxy(url); - if ((url.Host.ToLower() == proxyUri.Host.ToLower()) && (url.Port == proxyUri.Port)) { proxyUri = null; } - } - - if (proxyUri != null) - { - // Proxy in use - proxyInUse = true; - wsclient = new TcpClient(); - Debug("Connecting with proxy in use: " + proxyUri.ToString()); - wsclient.BeginConnect(proxyUri.Host, proxyUri.Port, new AsyncCallback(OnConnectSink), this); - } - else - { - // No proxy in use - proxyInUse = false; - wsclient = new TcpClient(); - Debug("Connecting without proxy"); - wsclient.BeginConnect(url.Host, url.Port, new AsyncCallback(OnConnectSink), this); - } - return true; - } - - private void OnConnectSink(IAsyncResult ar) - { - if (wsclient == null) return; - - // Accept the connection - try - { - wsclient.EndConnect(ar); - } catch (Exception ex) { - Debug("Websocket TCP failed to connect: " + ex.ToString()); - Dispose(); - return; - } - - if (proxyInUse == true) - { - // Send proxy connection request - wsrawstream = wsclient.GetStream(); - byte[] proxyRequestBuf = UTF8Encoding.UTF8.GetBytes("CONNECT " + url.Host + ":" + url.Port + " HTTP/1.1\r\nHost: " + url.Host + ":" + url.Port + "\r\n\r\n"); - TlsDump("OutRaw", proxyRequestBuf, 0, proxyRequestBuf.Length); - try { wsrawstream.Write(proxyRequestBuf, 0, proxyRequestBuf.Length); } catch (Exception ex) { Debug(ex.ToString()); } - wsrawstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnProxyResponseSink), this); - } - else - { - // Start TLS connection - Debug("Websocket TCP connected, doing TLS..."); - wsstream = new SslStream(wsclient.GetStream(), false, VerifyServerCertificate, null); - wsstream.BeginAuthenticateAsClient(url.Host, null, System.Security.Authentication.SslProtocols.Tls12, false, new AsyncCallback(OnTlsSetupSink), this); - } - } - - private void OnProxyResponseSink(IAsyncResult ar) - { - if (wsrawstream == null) return; - - int len = 0; - try { len = wsrawstream.EndRead(ar); } catch (Exception) { } - if (len == 0) - { - // Disconnect - Debug("Websocket proxy disconnected, length = 0."); - Dispose(); - return; - } - - TlsDump("InRaw", readBuffer, 0, readBufferLen); - - readBufferLen += len; - string proxyResponse = UTF8Encoding.UTF8.GetString(readBuffer, 0, readBufferLen); - if (proxyResponse.IndexOf("\r\n\r\n") >= 0) - { - // We get a full proxy response, we should get something like "HTTP/1.1 200 Connection established\r\n\r\n" - if (proxyResponse.StartsWith("HTTP/1.1 200 ")) - { - // All good, start TLS setup. - readBufferLen = 0; - Debug("Websocket TCP connected, doing TLS..."); - wsstream = new SslStream(wsrawstream, false, VerifyServerCertificate, null); - wsstream.BeginAuthenticateAsClient(url.Host, null, System.Security.Authentication.SslProtocols.Tls12, false, new AsyncCallback(OnTlsSetupSink), this); - } - else - { - // Invalid response - Debug("Proxy connection failed: " + proxyResponse); - Dispose(); - } - } else { - if (readBufferLen == readBuffer.Length) - { - // Buffer overflow - Debug("Proxy connection failed"); - Dispose(); - } - else - { - // Read more proxy data - wsrawstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnProxyResponseSink), this); - } - } - } - - public string Base64Encode(string plainText) - { - var plainTextBytes = System.Text.Encoding.UTF8.GetBytes(plainText); - return System.Convert.ToBase64String(plainTextBytes); - } - - public string Base64Decode(string base64EncodedData) - { - var base64EncodedBytes = System.Convert.FromBase64String(base64EncodedData); - return System.Text.Encoding.UTF8.GetString(base64EncodedBytes); - } - - private void OnTlsSetupSink(IAsyncResult ar) - { - if (wsstream == null) return; - - // Accept the connection - try - { - wsstream.EndAuthenticateAsClient(ar); - } - catch (Exception ex) - { - // Disconnect - if (ex.InnerException != null) { - MessageBox.Show(ex.Message + ", Inner: " + ex.InnerException.ToString(), "MeshCentral Router"); - } else { - MessageBox.Show(ex.Message, "MeshCentral Router"); - } - Debug("Websocket TLS failed: " + ex.ToString()); - Dispose(); - return; - } - - // Fetch remote certificate - parent.wshash = wsstream.RemoteCertificate.GetCertHashString(); - - // Setup extra headers if needed - string extraHeaders = ""; - if (user != null && pass != null && token != null) { extraHeaders = "x-meshauth: " + Base64Encode(user) + "," + Base64Encode(pass) + "," + Base64Encode(token) + "\r\n"; } - else if (user != null && pass != null) { extraHeaders = "x-meshauth: " + Base64Encode(user) + "," + Base64Encode(pass) + "\r\n"; } - - // Send the HTTP headers - Debug("Websocket TLS setup, sending HTTP header..."); - string header = "GET " + url.PathAndQuery + " HTTP/1.1\r\nHost: " + url.Host + "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n" + extraHeaders + "\r\n"; - try { wsstream.Write(UTF8Encoding.UTF8.GetBytes(header)); } catch (Exception ex) { Debug(ex.ToString()); } - - // Start receiving data - wsstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnTlsDataSink), this); - } - - private void OnTlsDataSink(IAsyncResult ar) - { - if (wsstream == null) return; - - int len = 0; - try { len = wsstream.EndRead(ar); } catch (Exception) { } - if (len == 0) - { - // Disconnect - Debug("Websocket disconnected, length = 0."); - Dispose(); - return; - } - //parent.Debug("#" + counter + ": Websocket got new data: " + len); - readBufferLen += len; - TlsDump("In", readBuffer, 0, len); - - // Consume all of the data - int consumed = 0; - int ptr = 0; - do - { - consumed = ProcessBuffer(readBuffer, ptr, readBufferLen - ptr); - if (consumed < 0) { Dispose(); return; } // Error, close the connection - ptr += consumed; - } while ((consumed > 0) && ((readBufferLen - consumed) > 0)); - - // Move the data forward - if ((ptr > 0) && (readBufferLen - ptr) > 0) { - //Console.Write("MOVE FORWARD\r\n"); - Array.Copy(readBuffer, ptr, readBuffer, 0, (readBufferLen - ptr)); - } - readBufferLen = (readBufferLen - ptr); - - // If the buffer is too small, double the size here. - if (readBuffer.Length - readBufferLen == 0) - { - Debug("Increasing the read buffer size from " + readBuffer.Length + " to " + (readBuffer.Length * 2) + "."); - byte[] readBuffer2 = new byte[readBuffer.Length * 2]; - Array.Copy(readBuffer, 0, readBuffer2, 0, readBuffer.Length); - readBuffer = readBuffer2; - } - - // Receive more data - try { wsstream.BeginRead(readBuffer, readBufferLen, readBuffer.Length - readBufferLen, new AsyncCallback(OnTlsDataSink), this); } catch (Exception) { } - } - - private int ProcessBuffer(byte[] buffer, int offset, int len) - { - string ss = UTF8Encoding.UTF8.GetString(buffer, offset, len); - - if (state == 1) - { - // Look for the end of the http header - string header = UTF8Encoding.UTF8.GetString(buffer, offset, len); - int i = header.IndexOf("\r\n\r\n"); - if (i == -1) return 0; - Dictionary parsedHeader = ParseHttpHeader(header.Substring(0, i)); - if ((parsedHeader == null) || (parsedHeader["_Path"] != "101")) { Debug("Websocket bad header."); return -1; } // Bad header, close the connection - Debug("Websocket got setup upgrade header."); - state = 2; - return len; // TODO: Technically we need to return the header length before UTF8 convert. - } else if (state == 2) { - // Parse a websocket fragment header - if (len < 2) return 0; - int headsize = 2; - accopcodes = buffer[offset]; - accmask = ((buffer[offset + 1] & 0x80) != 0); - acclen = (buffer[offset + 1] & 0x7F); - - if ((accopcodes & 0x0F) == 8) - { - // Close the websocket - Debug("Websocket got closed fragment."); - return -1; - } - - if (acclen == 126) - { - if (len < 4) return 0; - headsize = 4; - acclen = (buffer[offset + 2] << 8) + (buffer[offset + 3]); - } - else if (acclen == 127) - { - if (len < 10) return 0; - headsize = 10; - acclen = (buffer[offset + 6] << 24) + (buffer[offset + 7] << 16) + (buffer[offset + 8] << 8) + (buffer[offset + 9]); - Debug("Websocket receive large fragment: " + acclen); - } - if (accmask == true) - { - // TODO: Do unmasking here. - headsize += 4; - } - //parent.Debug("#" + counter + ": Websocket frag header - FIN: " + ((accopcodes & 0x80) != 0) + ", OP: " + (accopcodes & 0x0F) + ", LEN: " + acclen + ", MASK: " + accmask); - state = 3; - return headsize; - } - else if (state == 3) - { - // Parse a websocket fragment data - if (len < acclen) return 0; - //Console.Write("WSREAD: " + acclen + "\r\n"); - ProcessWsBuffer(buffer, offset, acclen, accopcodes); - state = 2; - return acclen; - } - return 0; - } - - private void ProcessWsBuffer(byte[] data, int offset, int len, int op) - { - Debug("Websocket got data."); - //try { parent.processServerData(UTF8Encoding.UTF8.GetString(data, offset, len)); } catch (Exception ex) { } - parent.processServerData(UTF8Encoding.UTF8.GetString(data, offset, len)); - } - - private Dictionary ParseHttpHeader(string header) - { - string[] lines = header.Replace("\r\n", "\r").Split('\r'); - if (lines.Length < 2) { return null; } - string[] directive = lines[0].Split(' '); - Dictionary values = new Dictionary(); - values["_Action"] = directive[0]; - values["_Path"] = directive[1]; - values["_Protocol"] = directive[2]; - for (int i = 1; i < lines.Length; i++) - { - var j = lines[i].IndexOf(":"); - values[lines[i].Substring(0, j).ToLower()] = lines[i].Substring(j + 1).Trim(); - } - return values; - } - - // Return a modified base64 SHA384 hash string of the certificate public key - public static string GetMeshKeyHash(X509Certificate cert) - { - return ByteArrayToHexString(new SHA384Managed().ComputeHash(cert.GetPublicKey())); - } - - // Return a modified base64 SHA384 hash string of the certificate - public static string GetMeshCertHash(X509Certificate cert) - { - return ByteArrayToHexString(new SHA384Managed().ComputeHash(cert.GetRawCertData())); - } - - public static string ByteArrayToHexString(byte[] Bytes) - { - StringBuilder Result = new StringBuilder(Bytes.Length * 2); - string HexAlphabet = "0123456789ABCDEF"; - foreach (byte B in Bytes) { Result.Append(HexAlphabet[(int)(B >> 4)]); Result.Append(HexAlphabet[(int)(B & 0xF)]); } - return Result.ToString(); - } - - private bool VerifyServerCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) - { - parent.certHash = GetMeshKeyHash(certificate); - Debug("Verify cert: " + parent.certHash); - if (xignoreCert) return true; - if (chain.Build(new X509Certificate2(certificate)) == true) return true; - - // Check that the remote certificate is the expected one - if ((parent.okCertHash != null) && (parent.okCertHash == certificate.GetCertHashString())) return true; - - // Check that the remote certificate is the expected one - if ((parent.okCertHash2 != null) && ((parent.okCertHash2 == GetMeshKeyHash(certificate)) || (parent.okCertHash2 == GetMeshCertHash(certificate)))) { return true; } - - parent.certHash = null; - parent.disconnectMsg = "cert"; - parent.disconnectCert = new X509Certificate2(certificate); - return false; - } - - public void WriteStringWebSocket(string data) - { - // Convert the string into a buffer with 4 byte of header space. - int len = UTF8Encoding.UTF8.GetByteCount(data); - byte[] buf = new byte[4 + len]; - UTF8Encoding.UTF8.GetBytes(data, 0, data.Length, buf, 4); - len = buf.Length - 4; - - // Check that everything is ok - if ((state < 2) || (len < 1) || (len > 65535)) { Dispose(); return; } - - //Console.Write("Length: " + len + "\r\n"); - //System.Threading.Thread.Sleep(0); - - if (len < 126) - { - // Small fragment - buf[2] = 130; // Fragment op code (129 = text, 130 = binary) - buf[3] = (byte)(len & 0x7F); - //try { wsstream.BeginWrite(buf, 2, len + 2, new AsyncCallback(WriteWebSocketAsyncDone), args); } catch (Exception) { Dispose(); return; } - TlsDump("Out", buf, 2, len + 2); - try { wsstream.Write(buf, 2, len + 2); } catch (Exception ex) { Debug(ex.ToString()); } - } - else - { - // Large fragment - buf[0] = 130; // Fragment op code (129 = text, 130 = binary) - buf[1] = 126; - buf[2] = (byte)((len >> 8) & 0xFF); - buf[3] = (byte)(len & 0xFF); - //try { wsstream.BeginWrite(buf, 0, len + 4, new AsyncCallback(WriteWebSocketAsyncDone), args); } catch (Exception) { Dispose(); return; } - TlsDump("Out", buf, 0, len + 4); - try { wsstream.Write(buf, 0, len + 4); } catch (Exception ex) { Debug(ex.ToString()); } - } - } + var plainTextBytes = System.Text.Encoding.UTF8.GetBytes(plainText); + return System.Convert.ToBase64String(plainTextBytes); + } + public string Base64Decode(string base64EncodedData) + { + var base64EncodedBytes = System.Convert.FromBase64String(base64EncodedData); + return System.Text.Encoding.UTF8.GetString(base64EncodedBytes); } } diff --git a/MeshMapper.cs b/MeshMapper.cs index 2a9476f..a4bf003 100644 --- a/MeshMapper.cs +++ b/MeshMapper.cs @@ -216,7 +216,6 @@ namespace MeshCentralRouter Uri wsurl = new Uri(url + "&auth=" + Uri.EscapeDataString(parent.authCookie)); Debug("#" + counter + ": Connecting web socket to: " + wsurl.ToString()); wc.debug = xdebug; - wc.Start(wsurl, certhash); wc.tag = client; wc.id = counter; wc.tunneling = false; @@ -224,6 +223,8 @@ namespace MeshCentralRouter wc.onBinaryData += Wc_onBinaryData; wc.onStringData += Wc_onStringData; wc.onSendOk += Wc_onSendOk; + wc.TLSCertCheck = webSocketClient.TLSCertificateCheck.Fingerprint; + wc.Start(wsurl, certhash, null); } private void ConnectWS(UdpClient client, int counter) @@ -232,7 +233,6 @@ namespace MeshCentralRouter Uri wsurl = new Uri(url + "&auth=" + Uri.EscapeDataString(parent.authCookie)); Debug("#" + counter + ": Connecting web socket to: " + wsurl.ToString()); wc.debug = xdebug; - wc.Start(wsurl, certhash); wc.tag = client; wc.id = counter; wc.tunneling = false; @@ -240,6 +240,8 @@ namespace MeshCentralRouter wc.onBinaryData += Wc_onBinaryData; wc.onStringData += Wc_onStringData; wc.onSendOk += Wc_onSendOk; + wc.TLSCertCheck = webSocketClient.TLSCertificateCheck.Fingerprint; + wc.Start(wsurl, certhash, null); } private void Wc_onSendOk(webSocketClient sender) diff --git a/WebSocketClient.cs b/WebSocketClient.cs index b504478..c3d8534 100644 --- a/WebSocketClient.cs +++ b/WebSocketClient.cs @@ -18,9 +18,12 @@ using System; using System.IO; using System.Net; using System.Text; +using System.Threading; using System.Net.Sockets; using System.Net.Security; +using System.Net.WebSockets; using System.IO.Compression; +using System.Threading.Tasks; using System.Collections.Generic; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; @@ -30,6 +33,9 @@ namespace MeshCentralRouter { public class webSocketClient : IDisposable { + private ClientWebSocket ws = null; // Native Windows WebSocket + private CancellationTokenSource CTS; + public bool AllowCompression = true; private TcpClient wsclient = null; private SslStream wsstream = null; @@ -44,10 +50,11 @@ namespace MeshCentralRouter private int acclen = 0; private bool proxyInUse = false; private string tlsCertFingerprint = null; + private string tlsCertFingerprint2 = null; //private ConnectionErrors lastError = ConnectionErrors.NoError; public bool debug = false; - public bool xignoreCert = false; - public string extraHeaders = null; + public bool tlsdump = false; + public Dictionary extraHeaders = null; private MemoryStream inflateMemory; private DeflateStream inflate; private MemoryStream deflateMemory; @@ -64,6 +71,10 @@ namespace MeshCentralRouter private bool shouldRead = false; private RNGCryptoServiceProvider CryptoRandom = new RNGCryptoServiceProvider(); private object mainLock = new object(); + public TLSCertificateCheck TLSCertCheck = TLSCertificateCheck.Verify; + public X509Certificate2 failedTlsCert = null; + static public bool nativeWebSocketFirst = true; + // Outside variables public object tag = null; @@ -79,11 +90,20 @@ namespace MeshCentralRouter Connected = 2 } + public enum TLSCertificateCheck + { + Ignore = 0, + Fingerprint = 1, + Verify = 2 + } + public enum ConnectionErrors { NoError = 0 } + private void TlsDump(string direction, byte[] data, int offset, int len) { if (tlsdump) { try { File.AppendAllText("debug.log", direction + ": " + BitConverter.ToString(data, offset, len).Replace("-", string.Empty) + "\r\n"); } catch (Exception) { } } } + public delegate void onBinaryDataHandler(webSocketClient sender, byte[] data, int offset, int length, int orglen); public event onBinaryDataHandler onBinaryData; public delegate void onStringDataHandler(webSocketClient sender, string data, int orglen); @@ -97,12 +117,7 @@ namespace MeshCentralRouter public ConnectionStates State { get { return state; } } - public X509Certificate RemoteCertificate { - get - { - try { return wsstream.RemoteCertificate; } catch (Exception) { return null; } - } - } + public X509Certificate RemoteCertificate { get { try { return wsstream.RemoteCertificate; } catch (Exception) { return null; } } } private void SetState(ConnectionStates newstate) { @@ -113,6 +128,17 @@ namespace MeshCentralRouter public void Dispose() { + if (ws != null) + { + if (ws.State == WebSocketState.Open) + { + CTS.CancelAfter(TimeSpan.FromSeconds(2)); + ws.CloseOutputAsync(WebSocketCloseStatus.Empty, "", CancellationToken.None); + ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + } + try { if (ws != null) { ws.Dispose(); ws = null; } } catch (Exception) { } + try { if (CTS != null) { CTS.Dispose(); CTS = null; } } catch (Exception) { } + } if (pingTimer != null) { pingTimer.Dispose(); pingTimer = null; } if (pongTimer != null) { pongTimer.Dispose(); pongTimer = null; } if (wsstream != null) { try { wsstream.Close(); } catch (Exception) { } try { wsstream.Dispose(); } catch (Exception) { } wsstream = null; } @@ -128,56 +154,91 @@ namespace MeshCentralRouter if (debug) { try { File.AppendAllText("debug.log", DateTime.Now.ToString("HH:mm:tt.ffff") + ": WebSocket: " + msg + "\r\n"); } catch (Exception) { } } } - public bool Start(Uri url, string tlsCertFingerprint) + private async Task ConnectAsync(Uri url) + { + if (CTS != null) CTS.Dispose(); + CTS = new CancellationTokenSource(); + try { await ws.ConnectAsync(url, CTS.Token); } catch (Exception) { SetState(0); return; } + await Task.Factory.StartNew(ReceiveLoop, CTS.Token, TaskCreationOptions.LongRunning, TaskScheduler.Default); + } + + public async Task DisconnectAsync() + { + if (ws == null) return; + if (ws.State == WebSocketState.Open) + { + CTS.CancelAfter(TimeSpan.FromSeconds(2)); + await ws.CloseOutputAsync(WebSocketCloseStatus.Empty, "", CancellationToken.None); + await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + } + ws.Dispose(); + ws = null; + CTS.Dispose(); + CTS = null; + } + + public bool Start(Uri url, string tlsCertFingerprint, string tlsCertFingerprint2) { if (state != ConnectionStates.Disconnected) return false; SetState(ConnectionStates.Connecting); this.url = url; if (tlsCertFingerprint != null) { this.tlsCertFingerprint = tlsCertFingerprint.ToUpper(); } - Uri proxyUri = null; + if (tlsCertFingerprint2 != null) { this.tlsCertFingerprint2 = tlsCertFingerprint2.ToUpper(); } - Log("Websocket Start, URL=" + ((url == null) ? "(NULL)" : url.ToString())); - - // Check if we need to use a HTTP proxy (Auto-proxy way) - try + if (nativeWebSocketFirst) { try { ws = new ClientWebSocket(); } catch (Exception) { } } + if (ws != null) { - RegistryKey registryKey = Registry.CurrentUser.OpenSubKey("Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings", true); - Object x = registryKey.GetValue("AutoConfigURL", null); - if ((x != null) && (x.GetType() == typeof(string))) - { - string proxyStr = GetProxyForUrlUsingPac("http" + ((url.Port == 80) ? "" : "s") + "://" + url.Host + ":" + url.Port, x.ToString()); - if (proxyStr != null) { proxyUri = new Uri("http://" + proxyStr); } - } - } - catch (Exception) { proxyUri = null; } - - // Check if we need to use a HTTP proxy (Normal way) - if (proxyUri == null) - { - var proxy = System.Net.HttpWebRequest.GetSystemWebProxy(); - proxyUri = proxy.GetProxy(url); - if ((url.Host.ToLower() == proxyUri.Host.ToLower()) && (url.Port == proxyUri.Port)) { proxyUri = null; } - } - - if (proxyUri != null) - { - // Proxy in use - Log("Websocket proxyUri: " + proxyUri.ToString()); - proxyInUse = true; - wsclient = new TcpClient(); - wsclient.BeginConnect(proxyUri.Host, proxyUri.Port, new AsyncCallback(OnConnectSink), this); + // Use Windows native websockets + Log("Websocket (native) Start, URL=" + ((url == null) ? "(NULL)" : url.ToString())); + if (extraHeaders != null) { foreach (var key in extraHeaders.Keys) { ws.Options.SetRequestHeader(key, extraHeaders[key]); } } + Task t = ConnectAsync(url); } else { - // No proxy in use - Log("Websocket noProxy"); - proxyInUse = false; - wsclient = new TcpClient(); - string h = url.Host; - if (h.StartsWith("[") && h.EndsWith("]")) { h = h.Substring(1, h.Length - 2); } - wsclient.BeginConnect(h, url.Port, new AsyncCallback(OnConnectSink), this); - } + // Use C# coded websockets + Uri proxyUri = null; + Log("Websocket Start, URL=" + ((url == null) ? "(NULL)" : url.ToString())); + // Check if we need to use a HTTP proxy (Auto-proxy way) + try + { + RegistryKey registryKey = Registry.CurrentUser.OpenSubKey("Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings", true); + Object x = registryKey.GetValue("AutoConfigURL", null); + if ((x != null) && (x.GetType() == typeof(string))) + { + string proxyStr = GetProxyForUrlUsingPac("http" + ((url.Port == 80) ? "" : "s") + "://" + url.Host + ":" + url.Port, x.ToString()); + if (proxyStr != null) { proxyUri = new Uri("http://" + proxyStr); } + } + } + catch (Exception) { proxyUri = null; } + + // Check if we need to use a HTTP proxy (Normal way) + if (proxyUri == null) + { + var proxy = System.Net.HttpWebRequest.GetSystemWebProxy(); + proxyUri = proxy.GetProxy(url); + if ((url.Host.ToLower() == proxyUri.Host.ToLower()) && (url.Port == proxyUri.Port)) { proxyUri = null; } + } + + if (proxyUri != null) + { + // Proxy in use + Log("Websocket proxyUri: " + proxyUri.ToString()); + proxyInUse = true; + wsclient = new TcpClient(); + wsclient.BeginConnect(proxyUri.Host, proxyUri.Port, new AsyncCallback(OnConnectSink), this); + } + else + { + // No proxy in use + Log("Websocket noProxy"); + proxyInUse = false; + wsclient = new TcpClient(); + string h = url.Host; + if (h.StartsWith("[") && h.EndsWith("]")) { h = h.Substring(1, h.Length - 2); } + wsclient.BeginConnect(h, url.Port, new AsyncCallback(OnConnectSink), this); + } + } return true; } @@ -296,13 +357,20 @@ namespace MeshCentralRouter pendingSendBuffer = new MemoryStream(); pendingSendCall = false; + // Build extra headers + string extraHeadersStr = ""; + if (extraHeaders != null) + { + foreach (string key in extraHeaders.Keys) { extraHeadersStr += key + ": " + extraHeaders[key] + "\r\n"; } + } + // Send the HTTP headers Log("Websocket TLS setup, sending HTTP header..."); string header; if (AllowCompression) { - header = "GET " + url.PathAndQuery + " HTTP/1.1\r\nHost: " + url.Host + "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n" + extraHeaders + "\r\n"; + header = "GET " + url.PathAndQuery + " HTTP/1.1\r\nHost: " + url.Host + "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n" + extraHeadersStr + "\r\n"; } else { - header = "GET " + url.PathAndQuery + " HTTP/1.1\r\nHost: " + url.Host + "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n" + extraHeaders + "\r\n"; + header = "GET " + url.PathAndQuery + " HTTP/1.1\r\nHost: " + url.Host + "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n" + extraHeadersStr + "\r\n"; } SendData(UTF8Encoding.UTF8.GetBytes(header)); @@ -397,6 +465,7 @@ namespace MeshCentralRouter private int ProcessBuffer(byte[] buffer, int offset, int len) { + TlsDump("InRaw", buffer, offset, len); string ss = UTF8Encoding.UTF8.GetString(buffer, offset, len); if (state == ConnectionStates.Connecting) @@ -503,12 +572,14 @@ namespace MeshCentralRouter case 0x01: // This is a text frame { Log("Websocket got string data, len = " + len); + TlsDump("InStr", data, offset, len); if (onStringData != null) { onStringData(this, UTF8Encoding.UTF8.GetString(data, offset, len), orglen); } break; } case 0x02: // This is a birnay frame { Log("Websocket got binary data, len = " + len); + TlsDump("InBin", data, offset, len); if (onBinaryData != null) { onBinaryData(this, data, offset, len, orglen); } break; } @@ -564,20 +635,42 @@ namespace MeshCentralRouter private bool VerifyServerCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { - if (tlsCertFingerprint == null) return true; - if ((tlsCertFingerprint.Length == 32) && (certificate.GetCertHashString().Equals(tlsCertFingerprint))) { return true; } - if (tlsCertFingerprint.Length == 96) - { - if (GetMeshCertHash(certificate).Equals(tlsCertFingerprint)) { return true; } - if (GetMeshKeyHash(certificate).Equals(tlsCertFingerprint)) { return true; } - } - string hash1 = GetMeshCertHash(certificate); string hash2 = certificate.GetCertHashString(); - Log("VerifyServerCertificate: tlsCertFingerprint = " + tlsCertFingerprint); - Log("VerifyServerCertificate: Hash1 = " + hash1); - Log("VerifyServerCertificate: Hash2 = " + hash2); - return ((tlsCertFingerprint == GetMeshKeyHash(certificate)) || (tlsCertFingerprint == certificate.GetCertHashString())); + //Debug("Verify cert: " + hash1); + + if (TLSCertCheck == TLSCertificateCheck.Ignore) + { + // Ignore certificate check + return true; + } + else if (TLSCertCheck == TLSCertificateCheck.Fingerprint) + { + // Fingerprint certificate check + if (tlsCertFingerprint == null) return true; + if ((tlsCertFingerprint.Length == 32) && (certificate.GetCertHashString().Equals(tlsCertFingerprint))) { return true; } + if (tlsCertFingerprint.Length == 96) + { + if (GetMeshCertHash(certificate).Equals(tlsCertFingerprint)) { return true; } + if (GetMeshKeyHash(certificate).Equals(tlsCertFingerprint)) { return true; } + } + + Log("VerifyServerCertificate: tlsCertFingerprint = " + tlsCertFingerprint); + Log("VerifyServerCertificate: Hash1 = " + hash1); + Log("VerifyServerCertificate: Hash2 = " + hash2); + return ((tlsCertFingerprint == GetMeshKeyHash(certificate)) || (tlsCertFingerprint == certificate.GetCertHashString())); + } + else + { + // Normal certificate check + if (chain.Build(new X509Certificate2(certificate)) == true) return true; + + // Check that the remote certificate is the expected one + if ((tlsCertFingerprint != null) && ((tlsCertFingerprint == certificate.GetCertHashString()) || (tlsCertFingerprint == GetMeshKeyHash(certificate)) || (tlsCertFingerprint == GetMeshCertHash(certificate)))) { return true; } + if ((tlsCertFingerprint2 != null) && ((tlsCertFingerprint2 == certificate.GetCertHashString()) || (tlsCertFingerprint2 == GetMeshKeyHash(certificate)) || (tlsCertFingerprint2 == GetMeshCertHash(certificate)))) { return true; } + failedTlsCert = new X509Certificate2(certificate); + return false; + } } public int SendString(string data) @@ -611,84 +704,103 @@ namespace MeshCentralRouter return SendFragment(null, 0, 0, 138); } + Task pendingSend = null; + // Fragment op code (129 = text, 130 = binary) public int SendFragment(byte[] data, int offset, int len, byte op) { - lock (mainLock) + TlsDump("Out(" + op + ")", data, offset, len); + if (ws != null) { - if (state != ConnectionStates.Connected) return 0; - byte[] buf; - - // If deflate is active, attempt to compress the data here. - if ((deflateMemory != null) && (len > 32) && (AllowCompression)) + // Using native websocket + lock (this) { - deflateMemory.SetLength(0); - deflateMemory.Write(inflateStart, 0, 14); - DeflateStream deflate = new DeflateStream(deflateMemory, CompressionMode.Compress, true); - deflate.Write(data, offset, len); - deflate.Dispose(); - deflate = null; - if (deflateMemory.Length < len) + if ((pendingSend != null) && (pendingSend.IsCompleted == false)) { pendingSend.Wait(); } + ArraySegment arr = new ArraySegment(data, offset, len); + WebSocketMessageType msgType = ((op == 129) ? WebSocketMessageType.Text : WebSocketMessageType.Binary); + pendingSend = ws.SendAsync(arr, msgType, true, CTS.Token); + } + return len; + } + else + { + // Using C# websocket + lock (mainLock) + { + if (state != ConnectionStates.Connected) return 0; + byte[] buf; + + // If deflate is active, attempt to compress the data here. + if ((deflateMemory != null) && (len > 32) && (AllowCompression)) { - // Use the compressed data - int newlen = (int)deflateMemory.Length; - buf = deflateMemory.GetBuffer(); - len = newlen - 14; - op |= 0x40; // Add compression op + deflateMemory.SetLength(0); + deflateMemory.Write(inflateStart, 0, 14); + DeflateStream deflate = new DeflateStream(deflateMemory, CompressionMode.Compress, true); + deflate.Write(data, offset, len); + deflate.Dispose(); + deflate = null; + if (deflateMemory.Length < len) + { + // Use the compressed data + int newlen = (int)deflateMemory.Length; + buf = deflateMemory.GetBuffer(); + len = newlen - 14; + op |= 0x40; // Add compression op + } + else + { + // Don't use the compress data + // Convert the string into a buffer with 4 byte of header space. + buf = new byte[14 + len]; + Array.Copy(data, offset, buf, 14, len); + } } else { - // Don't use the compress data // Convert the string into a buffer with 4 byte of header space. buf = new byte[14 + len]; - Array.Copy(data, offset, buf, 14, len); + if (len > 0) { Array.Copy(data, offset, buf, 14, len); } } - } - else - { - // Convert the string into a buffer with 4 byte of header space. - buf = new byte[14 + len]; - if (len > 0) { Array.Copy(data, offset, buf, 14, len); } - } - // Check that everything is ok - if (len < 0) { Dispose(); return 0; } + // Check that everything is ok + if (len < 0) { Dispose(); return 0; } - // Set the mask to a cryptographic random value and XOR the data - byte[] rand = new byte[4]; - CryptoRandom.GetBytes(rand); - Array.Copy(rand, 0, buf, 10, 4); - for (int x = 0; x < len; x++) { buf[x + 14] ^= rand[x % 4]; } + // Set the mask to a cryptographic random value and XOR the data + byte[] rand = new byte[4]; + CryptoRandom.GetBytes(rand); + Array.Copy(rand, 0, buf, 10, 4); + for (int x = 0; x < len; x++) { buf[x + 14] ^= rand[x % 4]; } - if (len < 126) - { - // Small fragment - buf[8] = op; - buf[9] = (byte)((len & 0x7F) + 128); // Add 128 to indicate the mask is present - SendData(buf, 8, len + 6); - } - else if (len < 65535) - { - // Medium fragment - buf[6] = op; - buf[7] = 126 + 128; // Add 128 to indicate the mask is present - buf[8] = (byte)((len >> 8) & 0xFF); - buf[9] = (byte)(len & 0xFF); - SendData(buf, 6, len + 8); - } - else - { - // Large fragment - buf[0] = op; - buf[1] = 127 + 128; // Add 128 to indicate the mask is present - buf[6] = (byte)((len >> 24) & 0xFF); - buf[7] = (byte)((len >> 16) & 0xFF); - buf[8] = (byte)((len >> 8) & 0xFF); - buf[9] = (byte)(len & 0xFF); - SendData(buf, 0, len + 14); - } + if (len < 126) + { + // Small fragment + buf[8] = op; + buf[9] = (byte)((len & 0x7F) + 128); // Add 128 to indicate the mask is present + SendData(buf, 8, len + 6); + } + else if (len < 65535) + { + // Medium fragment + buf[6] = op; + buf[7] = 126 + 128; // Add 128 to indicate the mask is present + buf[8] = (byte)((len >> 8) & 0xFF); + buf[9] = (byte)(len & 0xFF); + SendData(buf, 6, len + 8); + } + else + { + // Large fragment + buf[0] = op; + buf[1] = 127 + 128; // Add 128 to indicate the mask is present + buf[6] = (byte)((len >> 24) & 0xFF); + buf[7] = (byte)((len >> 16) & 0xFF); + buf[8] = (byte)((len >> 8) & 0xFF); + buf[9] = (byte)(len & 0xFF); + SendData(buf, 0, len + 14); + } - return len; + return len; + } } } @@ -696,6 +808,7 @@ namespace MeshCentralRouter private void SendData(byte[] buf, int off, int len) { + TlsDump("OutRaw", buf, off, len); if (pendingSendCall) { lock (pendingSendBuffer) { pendingSendBuffer.Write(buf, off, len); } @@ -752,6 +865,50 @@ namespace MeshCentralRouter } } + private async Task ReceiveLoop() + { + SetState(ConnectionStates.Connected); + var loopToken = CTS.Token; + MemoryStream outputStream = null; + WebSocketReceiveResult receiveResult = null; + var buffer = new byte[8192]; + ArraySegment bufferEx = new ArraySegment(buffer); + try + { + while (!loopToken.IsCancellationRequested) + { + outputStream = new MemoryStream(8192); + do + { + receiveResult = await ws.ReceiveAsync(bufferEx, CTS.Token); + if (receiveResult.MessageType != WebSocketMessageType.Close) + outputStream.Write(buffer, 0, receiveResult.Count); + } + while (!receiveResult.EndOfMessage); + if (receiveResult.MessageType == WebSocketMessageType.Close) break; + outputStream.Position = 0; + if (receiveResult.MessageType == WebSocketMessageType.Text) + { + Log("Websocket got string data, len = " + (int)outputStream.Length); + TlsDump("InStr", outputStream.GetBuffer(), 0, (int)outputStream.Length); + if (onStringData != null) { onStringData(this, UTF8Encoding.UTF8.GetString(outputStream.GetBuffer(), 0, (int)outputStream.Length), (int)outputStream.Length); } + } + else if (receiveResult.MessageType == WebSocketMessageType.Binary) + { + Log("Websocket got binary data, len = " + (int)outputStream.Length); + TlsDump("InBin", outputStream.GetBuffer(), 0, (int)outputStream.Length); + if (onBinaryData != null) { onBinaryData(this, outputStream.GetBuffer(), 0, (int)outputStream.Length, (int)outputStream.Length); } + } + } + } + catch (TaskCanceledException) { } + finally + { + outputStream?.Dispose(); + SetState(0); + } + } + } } diff --git a/app.config b/app.config index e56d694..37bb3f3 100644 --- a/app.config +++ b/app.config @@ -39,4 +39,4 @@ - +