diff --git a/websocket-sharp/Net/HttpConnection.cs b/websocket-sharp/Net/HttpConnection.cs index a1886568..43f3a007 100644 --- a/websocket-sharp/Net/HttpConnection.cs +++ b/websocket-sharp/Net/HttpConnection.cs @@ -65,19 +65,19 @@ namespace WebSocketSharp.Net private HttpListenerContext _context; private bool _contextWasBound; private StringBuilder _currentLine; - private EndPointListener _epListener; private InputState _inputState; private RequestStream _inputStream; private HttpListener _lastListener; private LineState _lineState; + private EndPointListener _listener; private ResponseStream _outputStream; private int _position; private ListenerPrefix _prefix; private MemoryStream _requestBuffer; private int _reuses; - private bool _secure; private Socket _socket; private Stream _stream; + private object _sync; private int _timeout; private Timer _timer; @@ -88,11 +88,10 @@ namespace WebSocketSharp.Net public HttpConnection (Socket socket, EndPointListener listener) { _socket = socket; - _epListener = listener; - _secure = listener.IsSecure; + _listener = listener; var netStream = new NetworkStream (socket, false); - if (_secure) { + if (listener.IsSecure) { var sslStream = new SslStream (netStream, false); sslStream.AuthenticateAsServer (listener.Certificate); _stream = sslStream; @@ -101,6 +100,7 @@ namespace WebSocketSharp.Net _stream = netStream; } + _sync = new object (); _timeout = 90000; // 90k ms for first request, 15k ms from then on. _timer = new Timer (onTimeout, this, Timeout.Infinite, Timeout.Infinite); @@ -119,7 +119,7 @@ namespace WebSocketSharp.Net public bool IsSecure { get { - return _secure; + return _listener.IsSecure; } } @@ -161,21 +161,53 @@ namespace WebSocketSharp.Net #region Private Methods - private void closeSocket () + private void close () { if (_socket == null) return; + lock (_sync) { + if (_socket == null) + return; + + disposeTimer (); + disposeRequestBuffer (); + disposeStream (); + closeSocket (); + } + + unbind (); + removeConnection (); + } + + private void closeSocket () + { try { - _socket.Close (); + _socket.Shutdown (SocketShutdown.Both); } catch { } - finally { - _socket = null; - } - removeConnection (); + _socket.Close (); + _socket = null; + } + + private void disposeRequestBuffer () + { + if (_requestBuffer == null) + return; + + _requestBuffer.Dispose (); + _requestBuffer = null; + } + + private void disposeStream () + { + if (_stream == null) + return; + + _stream.Dispose (); + _stream = null; } private void disposeTimer () @@ -183,17 +215,14 @@ namespace WebSocketSharp.Net if (_timer == null) return; - var timer = _timer; - _timer = null; - try { - timer.Change (Timeout.Infinite, Timeout.Infinite); + _timer.Change (Timeout.Infinite, Timeout.Infinite); } catch { } - if (timer != null) - timer.Dispose (); + _timer.Dispose (); + _timer = null; } private void init () @@ -212,80 +241,78 @@ namespace WebSocketSharp.Net private static void onRead (IAsyncResult asyncResult) { var conn = (HttpConnection) asyncResult.AsyncState; + if (conn._socket == null) + return; - var read = -1; - try { - conn._timer.Change (Timeout.Infinite, Timeout.Infinite); - read = conn._stream.EndRead (asyncResult); - conn._requestBuffer.Write (conn._buffer, 0, read); - if (conn._requestBuffer.Length > 32768) { - conn.SendError ("Bad request", 400); - conn.Close (true); + lock (conn._sync) { + if (conn._socket == null) + return; + var read = -1; + try { + conn._timer.Change (Timeout.Infinite, Timeout.Infinite); + read = conn._stream.EndRead (asyncResult); + conn._requestBuffer.Write (conn._buffer, 0, read); + if (conn._requestBuffer.Length > 32768) { + conn.SendError ("Bad request", 400); + conn.Close (true); + + return; + } + } + catch { + var requestBuffer = conn._requestBuffer; + if (requestBuffer != null && requestBuffer.Length > 0) + conn.SendError (); + + conn.close (); return; } - } - catch { - if (conn._requestBuffer != null && conn._requestBuffer.Length > 0) - conn.SendError (); - if (conn._socket != null) { - conn.disposeTimer (); - conn.closeSocket (); - conn.unbind (); + if (read <= 0) { + conn.close (); + return; } - return; - } + if (conn.processInput (conn._requestBuffer.GetBuffer ())) { + if (!conn._context.HaveError) { + conn._context.Request.FinishInitialization (); + } + else { + conn.SendError (); + conn.Close (true); - if (read <= 0) { - conn.disposeTimer (); - conn.closeSocket (); - conn.unbind (); + return; + } - return; - } + if (!conn._listener.BindContext (conn._context)) { + conn.SendError ("Invalid host", 400); + conn.Close (true); - if (conn.processInput (conn._requestBuffer.GetBuffer ())) { - if (!conn._context.HaveError) { - conn._context.Request.FinishInitialization (); - } - else { - conn.SendError (); - conn.Close (true); + return; + } - return; - } + var listener = conn._context.Listener; + if (conn._lastListener != listener) { + conn.removeConnection (); + listener.AddConnection (conn); + conn._lastListener = listener; + } - if (!conn._epListener.BindContext (conn._context)) { - conn.SendError ("Invalid host", 400); - conn.Close (true); + conn._contextWasBound = true; + listener.RegisterContext (conn._context); return; } - var listener = conn._context.Listener; - if (conn._lastListener != listener) { - conn.removeConnection (); - listener.AddConnection (conn); - conn._lastListener = listener; - } - - conn._contextWasBound = true; - listener.RegisterContext (conn._context); - - return; + conn._stream.BeginRead (conn._buffer, 0, _bufferSize, onRead, conn); } - - conn._stream.BeginRead (conn._buffer, 0, _bufferSize, onRead, conn); } private static void onTimeout (object state) { var conn = (HttpConnection) state; - conn.disposeTimer (); - conn.closeSocket (); - conn.unbind (); + conn.close (); } // true -> Done processing. @@ -363,7 +390,7 @@ namespace WebSocketSharp.Net private void removeConnection () { if (_lastListener == null) - _epListener.RemoveConnection (this); + _listener.RemoveConnection (this); else _lastListener.RemoveConnection (this); } @@ -371,7 +398,7 @@ namespace WebSocketSharp.Net private void unbind () { if (_contextWasBound) { - _epListener.UnbindContext (_context); + _listener.UnbindContext (_context); _contextWasBound = false; } } @@ -382,7 +409,13 @@ namespace WebSocketSharp.Net internal void Close (bool force) { - if (_socket != null) { + if (_socket == null) + return; + + lock (_sync) { + if (_socket == null) + return; + if (_outputStream != null) { _outputStream.Close (); _outputStream = null; @@ -400,6 +433,7 @@ namespace WebSocketSharp.Net (!_chunked || (_chunked && !res.ForceCloseChunked))) { // Don't close. Keep working. _reuses++; + disposeRequestBuffer (); unbind (); init (); BeginReadRequest (); @@ -407,24 +441,7 @@ namespace WebSocketSharp.Net return; } - var socket = _socket; - _socket = null; - - disposeTimer (); - - try { - socket.Shutdown (SocketShutdown.Both); - } - catch { - } - - if (socket != null) - socket.Close (); - - unbind (); - removeConnection (); - - return; + close (); } } @@ -437,17 +454,15 @@ namespace WebSocketSharp.Net if (_buffer == null) _buffer = new byte [_bufferSize]; - try { - if (_reuses == 1) - _timeout = 15000; + if (_reuses == 1) + _timeout = 15000; + try { _timer.Change (_timeout, Timeout.Infinite); _stream.BeginRead (_buffer, 0, _bufferSize, onRead, this); } catch { - disposeTimer (); - closeSocket (); - unbind (); + close (); } } @@ -458,11 +473,16 @@ namespace WebSocketSharp.Net public RequestStream GetRequestStream (bool chunked, long contentlength) { - if (_inputStream == null) { + if (_inputStream != null || _socket == null) + return _inputStream; + + lock (_sync) { + if (_socket == null) + return _inputStream; + var buffer = _requestBuffer.GetBuffer (); var length = buffer.Length; - - _requestBuffer = null; + disposeRequestBuffer (); if (chunked) { _chunked = true; _context.Response.SendChunked = true; @@ -473,21 +493,28 @@ namespace WebSocketSharp.Net _inputStream = new RequestStream ( _stream, buffer, _position, length - _position, contentlength); } - } - return _inputStream; + return _inputStream; + } } public ResponseStream GetResponseStream () { // TODO: Can we get this stream before reading the input? - if (_outputStream == null) { + + if (_outputStream != null || _socket == null) + return _outputStream; + + lock (_sync) { + if (_socket == null) + return _outputStream; + var listener = _context.Listener; var ignore = listener == null ? true : listener.IgnoreWriteExceptions; _outputStream = new ResponseStream (_stream, _context.Response, ignore); - } - return _outputStream; + return _outputStream; + } } public void SendError () @@ -497,21 +524,29 @@ namespace WebSocketSharp.Net public void SendError (string message, int status) { - try { - var res = _context.Response; - res.StatusCode = status; - res.ContentType = "text/html"; + if (_socket == null) + return; - var description = status.GetStatusDescription (); - var error = message != null && message.Length > 0 - ? String.Format ("