Fix due to the modified WebSocket.cs

master
sta 13 years ago
parent ae5d461a42
commit 94385ea2bc

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

@ -302,35 +302,45 @@ namespace WebSocketSharp {
if (length <= 0) if (length <= 0)
return new byte[]{}; return new byte[]{};
var buffer = new byte[length]; var buffer = new byte[length];
stream.Read(buffer, 0, length); var readLen = stream.Read(buffer, 0, length);
return buffer;
return readLen == length ? buffer : null;
} }
public static byte[] ReadBytes(this Stream stream, long length, int bufferLength) public static byte[] ReadBytes(this Stream stream, long length, int bufferLength)
{ {
var count = length / bufferLength; var count = length / bufferLength;
var rem = length % bufferLength; var rem = length % bufferLength;
var readData = new List<byte>(); var readData = new List<byte>();
var readLen = 0; var readBuffer = new byte[bufferLength];
var buffer = new byte[bufferLength]; long readLen = 0;
var tmpLen = 0;
Action<byte[]> read = (buffer) =>
{
tmpLen = stream.Read(buffer, 0, buffer.Length);
if (tmpLen > 0)
{
readLen += tmpLen;
readData.AddRange(buffer.SubArray(0, tmpLen));
}
};
count.Times(() => count.Times(() =>
{ {
readLen = stream.Read(buffer, 0, bufferLength); read(readBuffer);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
}); });
if (rem > 0) if (rem > 0)
{ {
buffer = new byte[rem]; readBuffer = new byte[rem];
readLen = stream.Read(buffer, 0, (int)rem); read(readBuffer);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
} }
return readData.ToArray(); return readLen == length
? readData.ToArray()
: null;
} }
public static T[] SubArray<T>(this T[] array, int startIndex, int length) public static T[] SubArray<T>(this T[] array, int startIndex, int length)

@ -28,6 +28,7 @@
using System; using System;
using System.IO; using System.IO;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
@ -35,7 +36,7 @@ namespace WebSocketSharp.Frame
{ {
public class WsFrame : IEnumerable<byte> public class WsFrame : IEnumerable<byte>
{ {
#region Private Static Fields #region Field
private static readonly int _readBufferLen; private static readonly int _readBufferLen;
@ -81,7 +82,7 @@ namespace WebSocketSharp.Frame
#endregion #endregion
#region Private Constructors #region Private Constructor
private WsFrame() private WsFrame()
{ {
@ -109,223 +110,148 @@ namespace WebSocketSharp.Frame
public WsFrame(Fin fin, Opcode opcode, Mask mask, PayloadData payloadData) public WsFrame(Fin fin, Opcode opcode, Mask mask, PayloadData payloadData)
: this() : this()
{ {
Fin = fin; Fin = fin;
Opcode = opcode; Opcode = opcode;
Masked = payloadData.Length != 0 ? mask : Mask.UNMASK;
ulong dataLength = payloadData.Length;
if (dataLength == 0)
{
Masked = Mask.UNMASK;
}
else
{
Masked = mask;
}
if (dataLength < 126)
{
PayloadLen = (byte)dataLength;
}
else if (dataLength < 0x010000)
{
PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)dataLength).ToBytes(ByteOrder.BIG);
}
else
{
PayloadLen = (byte)127;
ExtPayloadLen = dataLength.ToBytes(ByteOrder.BIG);
}
PayloadData = payloadData; PayloadData = payloadData;
if (Masked == Mask.MASK) init();
{
MaskingKey = new byte[4];
var rand = new Random();
rand.NextBytes(MaskingKey);
PayloadData.Mask(MaskingKey);
}
} }
#endregion #endregion
#region Public Static Methods #region Private Methods
public static WsFrame Parse(byte[] src) IEnumerator IEnumerable.GetEnumerator()
{ {
return Parse(src, true); return GetEnumerator();
} }
public static WsFrame Parse(byte[] src, bool unmask) private void init()
{ {
using (MemoryStream ms = new MemoryStream(src)) setPayloadLen(PayloadLength);
{ if (Masked == Mask.MASK)
return Parse(ms, unmask); maskPayloadData();
}
} }
public static WsFrame Parse<TStream>(TStream stream) private void maskPayloadData()
where TStream : System.IO.Stream
{ {
return Parse(stream, true); var key = new byte[4];
var rand = new Random();
rand.NextBytes(key);
MaskingKey = key;
PayloadData.Mask(key);
} }
public static WsFrame Parse<TStream>(TStream stream, bool unmask) private static void readExtPayloadLen(Stream stream, WsFrame frame)
where TStream : System.IO.Stream
{ {
Fin fin; var length = frame.PayloadLen <= 125
Rsv rsv1, rsv2, rsv3; ? 0
Opcode opcode; : frame.PayloadLen == 126 ? 2 : 8;
Mask masked;
byte payloadLen;
byte[] extPayloadLen = new byte[]{};
byte[] maskingKey = new byte[]{};
PayloadData payloadData;
byte[] buffer1, buffer2, buffer3;
int buffer1Len = 2;
int buffer2Len = 0;
ulong buffer3Len = 0;
int maskingKeyLen = 4;
int readLen = 0;
buffer1 = new byte[buffer1Len]; if (length > 0)
readLen = stream.Read(buffer1, 0, buffer1Len);
if (readLen < buffer1Len)
{ {
return null; var extLength = stream.ReadBytes(length);
if (extLength == null)
throw new IOException();
frame.ExtPayloadLen = extLength;
} }
}
private static WsFrame readHeader(Stream stream)
{
var header = stream.ReadBytes(2);
if (header == null)
return null;
// FIN // FIN
fin = (buffer1[0] & 0x80) == 0x80 Fin fin = (header[0] & 0x80) == 0x80 ? Fin.FINAL : Fin.MORE;
? Fin.FINAL
: Fin.MORE;
// RSV1 // RSV1
rsv1 = (buffer1[0] & 0x40) == 0x40 Rsv rsv1 = (header[0] & 0x40) == 0x40 ? Rsv.ON : Rsv.OFF;
? Rsv.ON
: Rsv.OFF;
// RSV2 // RSV2
rsv2 = (buffer1[0] & 0x20) == 0x20 Rsv rsv2 = (header[0] & 0x20) == 0x20 ? Rsv.ON : Rsv.OFF;
? Rsv.ON
: Rsv.OFF;
// RSV3 // RSV3
rsv3 = (buffer1[0] & 0x10) == 0x10 Rsv rsv3 = (header[0] & 0x10) == 0x10 ? Rsv.ON : Rsv.OFF;
? Rsv.ON // Opcode
: Rsv.OFF; Opcode opcode = (Opcode)(header[0] & 0x0f);
// opcode
opcode = (Opcode)(buffer1[0] & 0x0f);
// MASK // MASK
masked = (buffer1[1] & 0x80) == 0x80 Mask masked = (header[1] & 0x80) == 0x80 ? Mask.MASK : Mask.UNMASK;
? Mask.MASK
: Mask.UNMASK;
// Payload len // Payload len
payloadLen = (byte)(buffer1[1] & 0x7f); byte payloadLen = (byte)(header[1] & 0x7f);
// Extended payload length
if (payloadLen <= 125) return new WsFrame {
{ Fin = fin,
buffer3Len = payloadLen; Rsv1 = rsv1,
} Rsv2 = rsv2,
else if (payloadLen == 126) Rsv3 = rsv3,
{ Opcode = opcode,
buffer2Len = 2; Masked = masked,
} PayloadLen = payloadLen};
else }
private static void readMaskingKey(Stream stream, WsFrame frame)
{
if (frame.Masked == Mask.MASK)
{ {
buffer2Len = 8; var maskingKey = stream.ReadBytes(4);
if (maskingKey == null)
throw new IOException();
frame.MaskingKey = maskingKey;
} }
}
if (buffer2Len > 0) private static void readPayloadData(Stream stream, WsFrame frame, bool unmask)
{ {
buffer2 = new byte[buffer2Len]; ulong length = frame.PayloadLen <= 125
readLen = stream.Read(buffer2, 0, buffer2Len); ? frame.PayloadLen
: frame.PayloadLen == 126
? frame.ExtPayloadLen.To<ushort>(ByteOrder.BIG)
: frame.ExtPayloadLen.To<ulong>(ByteOrder.BIG);
if (readLen < buffer2Len) var buffer = length <= (ulong)_readBufferLen
{ ? stream.ReadBytes((int)length)
return null; : stream.ReadBytes((long)length, _readBufferLen);
}
extPayloadLen = buffer2; if (buffer == null)
switch (buffer2Len) throw new IOException();
{
case 2:
buffer3Len = extPayloadLen.To<ushort>(ByteOrder.BIG);
break;
case 8:
buffer3Len = extPayloadLen.To<ulong>(ByteOrder.BIG);
break;
}
}
if (buffer3Len > PayloadData.MaxLength) PayloadData payloadData;
{ if (frame.Masked == Mask.MASK)
throw new WsReceivedTooBigMessageException();
}
// Masking-key
if (masked == Mask.MASK)
{
maskingKey = new byte[maskingKeyLen];
readLen = stream.Read(maskingKey, 0, maskingKeyLen);
if (readLen < maskingKeyLen)
{
return null;
}
}
// Payload Data
if (buffer3Len == 0)
{
buffer3 = new byte[]{};
}
else if (buffer3Len <= (ulong)_readBufferLen)
{ {
buffer3 = new byte[buffer3Len]; payloadData = new PayloadData(buffer, true);
readLen = stream.Read(buffer3, 0, (int)buffer3Len); if (unmask == true)
if (readLen < (int)buffer3Len)
{ {
return null; payloadData.Mask(frame.MaskingKey);
frame.Masked = Mask.UNMASK;
frame.MaskingKey = new byte[]{};
} }
} }
else else
{ {
buffer3 = stream.ReadBytes((long)buffer3Len, _readBufferLen); payloadData = new PayloadData(buffer);
if ((ulong)buffer3.LongLength < buffer3Len)
{
return null;
}
} }
if (masked == Mask.MASK) frame.PayloadData = payloadData;
}
private void setPayloadLen(ulong length)
{
if (length < 126)
{ {
payloadData = new PayloadData(buffer3, true); PayloadLen = (byte)length;
if (unmask == true) return;
{
payloadData.Mask(maskingKey);
masked = Mask.UNMASK;
maskingKey = new byte[]{};
}
} }
else
if (length < 0x010000)
{ {
payloadData = new PayloadData(buffer3); PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)length).ToBytes(ByteOrder.BIG);
return;
} }
return new WsFrame PayloadLen = (byte)127;
{ ExtPayloadLen = length.ToBytes(ByteOrder.BIG);
Fin = fin,
Rsv1 = rsv1,
Rsv2 = rsv2,
Rsv3 = rsv3,
Opcode = opcode,
Masked = masked,
PayloadLen = payloadLen,
ExtPayloadLen = extPayloadLen,
MaskingKey = maskingKey,
PayloadData = payloadData
};
} }
#endregion #endregion
@ -335,14 +261,38 @@ namespace WebSocketSharp.Frame
public IEnumerator<byte> GetEnumerator() public IEnumerator<byte> GetEnumerator()
{ {
foreach (byte b in ToBytes()) foreach (byte b in ToBytes())
{
yield return b; yield return b;
}
public static WsFrame Parse(byte[] src)
{
return Parse(src, true);
}
public static WsFrame Parse(byte[] src, bool unmask)
{
using (MemoryStream ms = new MemoryStream(src))
{
return Parse(ms, unmask);
} }
} }
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() public static WsFrame Parse(Stream stream)
{ {
return GetEnumerator(); return Parse(stream, true);
}
public static WsFrame Parse(Stream stream, bool unmask)
{
var frame = readHeader(stream);
if (frame == null)
return null;
readExtPayloadLen(stream, frame);
readMaskingKey(stream, frame);
readPayloadData(stream, frame, unmask);
return frame;
} }
public void Print() public void Print()
@ -462,33 +412,27 @@ namespace WebSocketSharp.Frame
public byte[] ToBytes() public byte[] ToBytes()
{ {
var bytes = new List<byte>(); var buffer = new List<byte>();
int first16 = (int)Fin; int header = (int)Fin;
first16 = (first16 << 1) + (int)Rsv1; header = (header << 1) + (int)Rsv1;
first16 = (first16 << 1) + (int)Rsv2; header = (header << 1) + (int)Rsv2;
first16 = (first16 << 1) + (int)Rsv3; header = (header << 1) + (int)Rsv3;
first16 = (first16 << 4) + (int)Opcode; header = (header << 4) + (int)Opcode;
first16 = (first16 << 1) + (int)Masked; header = (header << 1) + (int)Masked;
first16 = (first16 << 7) + (int)PayloadLen; header = (header << 7) + (int)PayloadLen;
bytes.AddRange(((ushort)first16).ToBytes(ByteOrder.BIG)); buffer.AddRange(((ushort)header).ToBytes(ByteOrder.BIG));
if (PayloadLen >= 126) if (PayloadLen >= 126)
{ buffer.AddRange(ExtPayloadLen);
bytes.AddRange(ExtPayloadLen);
}
if (Masked == Mask.MASK) if (Masked == Mask.MASK)
{ buffer.AddRange(MaskingKey);
bytes.AddRange(MaskingKey);
}
if (PayloadLen > 0) if (PayloadLen > 0)
{ buffer.AddRange(PayloadData.ToBytes());
bytes.AddRange(PayloadData.ToBytes());
}
return bytes.ToArray(); return buffer.ToArray();
} }
public override string ToString() public override string ToString()

@ -31,13 +31,13 @@
using System; using System;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
using System.Reflection; using System.Reflection;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp.Net { namespace WebSocketSharp.Net {
@ -122,7 +122,7 @@ namespace WebSocketSharp.Net {
if (!secure) { if (!secure) {
stream = net_stream; stream = net_stream;
} else { } else {
var ssl_stream = new SslStream(net_stream); var ssl_stream = new SslStream(net_stream, false);
ssl_stream.AuthenticateAsServer(cert); ssl_stream.AuthenticateAsServer(cert);
stream = ssl_stream; stream = ssl_stream;
} }

@ -0,0 +1,83 @@
#region MIT License
/**
* SslStream.cs
*
* The MIT License
*
* Copyright (c) 2012 sta.blockhead
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#endregion
using System;
using System.Net.Security;
using System.Net.Sockets;
namespace WebSocketSharp.Net.Security {
public class SslStream : System.Net.Security.SslStream
{
#region Constructors
public SslStream(NetworkStream innerStream)
: base(innerStream)
{
}
public SslStream(NetworkStream innerStream, bool leaveInnerStreamOpen)
: base(innerStream, leaveInnerStreamOpen)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback
) : base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback,
LocalCertificateSelectionCallback userCertificateSelectionCallback
) : base(
innerStream,
leaveInnerStreamOpen,
userCertificateValidationCallback,
userCertificateSelectionCallback
)
{
}
#endregion
#region Property
public bool DataAvailable {
get {
return ((NetworkStream)InnerStream).DataAvailable;
}
}
#endregion
}
}

@ -643,6 +643,19 @@ namespace WebSocketSharp {
return frame; return frame;
} }
private WsFrame readFrameWithTimeout()
{
if (!_wsStream.DataAvailable)
{
var timeout = 1 * 100;
Thread.Sleep(timeout);
if (!_wsStream.DataAvailable)
return null;
}
return readFrame();
}
private string[] readHandshake() private string[] readHandshake()
{ {
return _wsStream.ReadHandshake(); return _wsStream.ReadHandshake();
@ -650,7 +663,7 @@ namespace WebSocketSharp {
private MessageEventArgs receive() private MessageEventArgs receive()
{ {
var frame = readFrame(); var frame = _isClient ? readFrame() : readFrameWithTimeout();
if (frame == null) if (frame == null)
return null; return null;

@ -31,11 +31,11 @@ using System.Collections.Generic;
using System.Configuration; using System.Configuration;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using System.Text; using System.Text;
using WebSocketSharp.Frame; using WebSocketSharp.Frame;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp namespace WebSocketSharp
{ {
@ -51,7 +51,7 @@ namespace WebSocketSharp
#endregion #endregion
#region Constructor #region Constructors
public WsStream(NetworkStream innerStream) public WsStream(NetworkStream innerStream)
{ {
@ -65,7 +65,16 @@ namespace WebSocketSharp
#endregion #endregion
#region Public Property #region Properties
public bool DataAvailable {
get {
if (_innerStreamType == typeof(SslStream))
return ((SslStream)_innerStream).DataAvailable;
return ((NetworkStream)_innerStream).DataAvailable;
}
}
public bool IsSecure { public bool IsSecure {
get { return _isSecure; } get { return _isSecure; }
@ -73,7 +82,7 @@ namespace WebSocketSharp
#endregion #endregion
#region Private Methods #region Private Method
private void init(Stream innerStream) private void init(Stream innerStream)
{ {
@ -98,7 +107,7 @@ namespace WebSocketSharp
if (port == 443) if (port == 443)
{ {
RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) => System.Net.Security.RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) =>
{ {
// FIXME: Always returns true // FIXME: Always returns true
return true; return true;
@ -120,7 +129,7 @@ namespace WebSocketSharp
var port = ((IPEndPoint)client.Client.LocalEndPoint).Port; var port = ((IPEndPoint)client.Client.LocalEndPoint).Port;
if (port == 443) if (port == 443)
{ {
var sslStream = new SslStream(netStream); var sslStream = new SslStream(netStream, false);
var certPath = ConfigurationManager.AppSettings["ServerCertPath"]; var certPath = ConfigurationManager.AppSettings["ServerCertPath"];
sslStream.AuthenticateAsServer(new X509Certificate2(certPath)); sslStream.AuthenticateAsServer(new X509Certificate2(certPath));

@ -110,6 +110,7 @@
<Compile Include="Server\IWebSocketServer.cs" /> <Compile Include="Server\IWebSocketServer.cs" />
<Compile Include="Net\Sockets\TcpListenerWebSocketContext.cs" /> <Compile Include="Net\Sockets\TcpListenerWebSocketContext.cs" />
<Compile Include="Server\WebSocketServerBase.cs" /> <Compile Include="Server\WebSocketServerBase.cs" />
<Compile Include="Net\Security\SslStream.cs" />
</ItemGroup> </ItemGroup>
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" /> <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
<ItemGroup> <ItemGroup>
@ -117,5 +118,6 @@
<Folder Include="Server\" /> <Folder Include="Server\" />
<Folder Include="Net\" /> <Folder Include="Net\" />
<Folder Include="Net\Sockets\" /> <Folder Include="Net\Sockets\" />
<Folder Include="Net\Security\" />
</ItemGroup> </ItemGroup>
</Project> </Project>
Loading…
Cancel
Save