diff --git a/MainForm.cs b/MainForm.cs index 172920d..fc1eaf0 100644 --- a/MainForm.cs +++ b/MainForm.cs @@ -236,7 +236,7 @@ namespace MeshCentralRouter if (arg.ToLower() == "-all") { inaddrany = true; } if (arg.ToLower() == "-inaddrany") { inaddrany = true; } if (arg.ToLower() == "-tray") { notifyIcon.Visible = true; this.ShowInTaskbar = false; this.MinimizeBox = false; } - if (arg.ToLower() == "-nonative") { webSocketClient.nativeWebSocketFirst = false; } + if (arg.ToLower() == "-native") { webSocketClient.nativeWebSocketFirst = true; } if (arg.Length > 6 && arg.Substring(0, 6).ToLower() == "-host:") { serverNameComboBox.Text = arg.Substring(6); argflags |= 1; } if (arg.Length > 6 && arg.Substring(0, 6).ToLower() == "-user:") { userNameTextBox.Text = arg.Substring(6); argflags |= 2; } if (arg.Length > 6 && arg.Substring(0, 6).ToLower() == "-pass:") { passwordTextBox.Text = arg.Substring(6); argflags |= 4; } diff --git a/WebSocketClient.cs b/WebSocketClient.cs index ccaa446..0e3fd73 100644 --- a/WebSocketClient.cs +++ b/WebSocketClient.cs @@ -64,6 +64,7 @@ namespace MeshCentralRouter public int pongTimeSeconds = 0; private System.Threading.Timer pingTimer = null; private System.Threading.Timer pongTimer = null; + private System.Threading.Timer connectTimer = null; private bool pendingSendCall = false; private MemoryStream pendingSendBuffer = null; private bool readPaused = false; @@ -73,7 +74,7 @@ namespace MeshCentralRouter public TLSCertificateCheck TLSCertCheck = TLSCertificateCheck.Verify; public X509Certificate2 tlsCert = null; public X509Certificate2 failedTlsCert = null; - static public bool nativeWebSocketFirst = true; + static public bool nativeWebSocketFirst = false; private SemaphoreSlim receiveLock = new SemaphoreSlim(1, 1); // Outside variables @@ -143,6 +144,7 @@ namespace MeshCentralRouter } if (pingTimer != null) { pingTimer.Dispose(); pingTimer = null; } if (pongTimer != null) { pongTimer.Dispose(); pongTimer = null; } + if (connectTimer != null) { try { connectTimer.Dispose(); } catch (Exception) { } connectTimer = null; } if (wsstream != null) { try { wsstream.Close(); } catch (Exception) { } try { wsstream.Dispose(); } catch (Exception) { } wsstream = null; } if (wsclient != null) { wsclient = null; } if (pendingSendBuffer != null) { pendingSendBuffer.Dispose(); pendingSendBuffer = null; } @@ -179,15 +181,15 @@ namespace MeshCentralRouter CTS = null; } - public bool Start(Uri url, string tlsCertFingerprint, string tlsCertFingerprint2) + public bool Start(Uri url, string tlsCertFingerprint, string tlsCertFingerprint2, bool force = false) { - if (state != ConnectionStates.Disconnected) return false; + if ((force == false) && (state != ConnectionStates.Disconnected)) return false; SetState(ConnectionStates.Connecting); this.url = url; if (tlsCertFingerprint != null) { this.tlsCertFingerprint = tlsCertFingerprint.ToUpper(); } if (tlsCertFingerprint2 != null) { this.tlsCertFingerprint2 = tlsCertFingerprint2.ToUpper(); } - //if (nativeWebSocketFirst) { try { ws = new ClientWebSocket(); } catch (Exception) { } } + if (nativeWebSocketFirst) { try { ws = new ClientWebSocket(); } catch (Exception) { } } if (ws != null) { // Use Windows native websockets @@ -240,12 +242,18 @@ namespace MeshCentralRouter if (h.StartsWith("[") && h.EndsWith("]")) { h = h.Substring(1, h.Length - 2); } wsclient.BeginConnect(h, url.Port, new AsyncCallback(OnConnectSink), this); } + + // Start a timer that will fallback to native sockets automatically. + // For some proxy types, native websockets are the only way to connect. + if (connectTimer != null) { try { connectTimer.Dispose(); } catch (Exception) { } connectTimer = null; } + connectTimer = new System.Threading.Timer(new System.Threading.TimerCallback(ConnectTimerCallback), null, 3000, 3000); } return true; } private void OnConnectSink(IAsyncResult ar) { + if (connectTimer != null) { try { connectTimer.Dispose(); } catch (Exception) { } connectTimer = null; } if (wsclient == null) return; // Accept the connection @@ -256,7 +264,14 @@ namespace MeshCentralRouter catch (Exception ex) { Log("Websocket TCP failed to connect: " + ex.ToString()); - Dispose(); + if (nativeWebSocketFirst == false) + { + ConnectTimerCallback(null); + } + else + { + Dispose(); + } return; } @@ -468,6 +483,22 @@ namespace MeshCentralRouter private void PongTimerCallback(object state) { SendPong(null, 0, 0); } + private void ConnectTimerCallback(object state) { + // Switch from C# sockets to native sockets + if ((nativeWebSocketFirst == false) && (this.state == ConnectionStates.Connecting)) + { + Log("Switching to native Websocket"); + if (pingTimer != null) { try { pingTimer.Dispose(); } catch (Exception) { } pingTimer = null; } + if (pongTimer != null) { try { pongTimer.Dispose(); } catch (Exception) { } pongTimer = null; } + if (connectTimer != null) { try { connectTimer.Dispose(); } catch (Exception) { } connectTimer = null; } + if (wsstream != null) { try { wsstream.Close(); } catch (Exception) { } try { wsstream.Dispose(); } catch (Exception) { } wsstream = null; } + if (wsclient != null) { wsclient = null; } + if (pendingSendBuffer != null) { pendingSendBuffer.Dispose(); pendingSendBuffer = null; } + nativeWebSocketFirst = true; + Start(this.url, this.tlsCertFingerprint, this.tlsCertFingerprint2, true); + } + } + private int ProcessBuffer(byte[] buffer, int offset, int len) { TlsDump("InRaw", buffer, offset, len);