AyCode.Core/AyCode.Services/SignalRs/AcBinaryHubProtocol.cs

616 lines
18 KiB
C#

using System.Buffers;
using System.Buffers.Binary;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using AyCode.Core.Serializers.Binaries;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;
namespace AyCode.Services.SignalRs;
/// <summary>
/// Custom SignalR hub protocol using AcBinarySerializer for wire format.
/// Eliminates JSON+Base64 overhead by serializing all HubMessages directly to binary.
///
/// Wire format per message:
/// [4 bytes: payload length (little-endian)] [payload bytes]
///
/// Payload structure:
/// [1 byte: message type] [message-specific fields serialized via AcBinary]
///
/// Message types map 1:1 to SignalR HubMessageType values.
/// Arguments are serialized individually with a VarUInt length prefix each,
/// enabling deferred deserialization via IHubProtocol's binder pattern.
///
/// Each payload is written through BufferWriterBinaryOutput into a per-thread scratch buffer and
/// then copied to the IBufferWriter supplied by SignalR behind its length prefix. This respects the
/// IBufferWriter contract (memory obtained from GetMemory/GetSpan is never written after Advance)
/// and guarantees that a message whose serialization fails leaves no partial bytes in the output.
///
/// Binding problems on the read side (unknown hub method, unknown stream, argument or result that
/// cannot be deserialized) are reported as InvocationBindingFailureMessage, StreamBindingFailureMessage
/// or an error CompletionMessage instead of throwing out of TryParseMessage.
/// </summary>
public sealed class AcBinaryHubProtocol : IHubProtocol
{
    private const int LengthPrefixSize = 4;

    // Initial capacity of a newly created payload scratch buffer.
    private const int InitialPayloadBufferSize = 256;

    // Scratch buffers that grew beyond this size are dropped instead of cached, so one large
    // message does not keep a large array alive on the thread indefinitely.
    private const int MaxCachedPayloadBufferSize = 64 * 1024;

    // Message type markers (matching HubMessageType enum values)
    private const byte MsgInvocation = 1;
    private const byte MsgStreamItem = 2;
    private const byte MsgCompletion = 3;
    private const byte MsgStreamInvocation = 4;
    private const byte MsgCancelInvocation = 5;
    private const byte MsgPing = 6;
    private const byte MsgClose = 7;
    private const byte MsgAck = 8;
    private const byte MsgSequence = 9;

    // Per-thread payload scratch buffer reused across WriteMessage calls; null while checked out.
    [ThreadStatic]
    private static ArrayBufferWriter<byte>? t_payloadBuffer;

    private volatile AcBinarySerializerOptions _options;

    public AcBinaryHubProtocol() : this(AcBinarySerializerOptions.Default) { }

    public AcBinaryHubProtocol(AcBinarySerializerOptions options)
    {
        _options = options;
    }

    /// <summary>
    /// Runtime-replaceable serializer options.
    /// Thread-safe: the field is volatile and is read once per message, so a replacement takes effect
    /// on the next message and never changes options half-way through one.
    /// </summary>
    public AcBinarySerializerOptions Options
    {
        get => _options;
        set => _options = value;
    }

    public string Name => "acbinary";

    public int Version => 1;

    public TransferFormat TransferFormat => TransferFormat.Binary;

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public bool IsVersionSupported(int version) => version <= Version;

    #region WriteMessage

    /// <summary>Serializes <paramref name="message"/> (length prefix included) into a new buffer.</summary>
    public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
    {
        var writer = new ArrayBufferWriter<byte>(InitialPayloadBufferSize);
        WriteMessage(message, writer);
        return writer.WrittenMemory;
    }

    /// <summary>
    /// Writes one framed message: a 4-byte little-endian payload length followed by the payload.
    /// </summary>
    /// <exception cref="HubException">The message type is not supported by this protocol.</exception>
    public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
    {
        // Snapshot once so every argument of this message is serialized with the same options.
        var options = _options;

        var payload = RentPayloadBuffer();
        try
        {
            // Serialize completely before touching `output`: if anything throws, the output is untouched.
            WritePayload(payload, message, options);

            // Two contract-conforming writes: prefix first, then the payload bytes.
            var prefix = output.GetSpan(LengthPrefixSize);
            BinaryPrimitives.WriteInt32LittleEndian(prefix, payload.WrittenCount);
            output.Advance(LengthPrefixSize);
            output.Write(payload.WrittenSpan);
        }
        finally
        {
            ReturnPayloadBuffer(payload);
        }
    }

    /// <summary>Writes the unframed payload (type byte + fields) of <paramref name="message"/>.</summary>
    private static void WritePayload(IBufferWriter<byte> payload, HubMessage message, AcBinarySerializerOptions options)
    {
        var w = new BufferWriterBinaryOutput(payload);
        switch (message)
        {
            case InvocationMessage m:
                WriteInvocation(w, m, options);
                break;
            case StreamInvocationMessage m:
                WriteStreamInvocation(w, m, options);
                break;
            case StreamItemMessage m:
                WriteStreamItem(w, m, options);
                break;
            case CompletionMessage m:
                WriteCompletion(w, m, options);
                break;
            case CancelInvocationMessage m:
                WriteCancelInvocation(w, m);
                break;
            case PingMessage:
                w.WriteByte(MsgPing);
                break;
            case CloseMessage m:
                WriteClose(w, m);
                break;
            case AckMessage m:
                WriteAck(w, m);
                break;
            case SequenceMessage m:
                WriteSequence(w, m);
                break;
            default:
                throw new HubException($"Unexpected message type: {message.GetType().Name}");
        }

        // Push bytes still pending inside BufferWriterBinaryOutput into the scratch buffer.
        w.Flush();
    }

    /// <summary>Takes this thread's cached scratch buffer, or creates one if none is cached.</summary>
    private static ArrayBufferWriter<byte> RentPayloadBuffer()
    {
        var buffer = t_payloadBuffer;
        if (buffer == null)
            return new ArrayBufferWriter<byte>(InitialPayloadBufferSize);

        // Detach while in use so a nested call on the same thread can never share it.
        t_payloadBuffer = null;
        return buffer;
    }

    /// <summary>Resets <paramref name="buffer"/> and caches it for the next message on this thread.</summary>
    private static void ReturnPayloadBuffer(ArrayBufferWriter<byte> buffer)
    {
        if (buffer.Capacity > MaxCachedPayloadBufferSize)
            return;

        // Clear() (rather than the .NET 8-only ResetWrittenCount) keeps older target frameworks compiling.
        buffer.Clear();
        t_payloadBuffer = buffer;
    }

    private static void WriteInvocation(BufferWriterBinaryOutput w, InvocationMessage m, AcBinarySerializerOptions options)
    {
        w.WriteByte(MsgInvocation);
        WriteNullableString(w, m.InvocationId);
        WriteString(w, m.Target);
        WriteArguments(w, m.Arguments, options);
        WriteStringArray(w, m.StreamIds);
        WriteHeaders(w, m.Headers);
    }

    private static void WriteStreamInvocation(BufferWriterBinaryOutput w, StreamInvocationMessage m, AcBinarySerializerOptions options)
    {
        w.WriteByte(MsgStreamInvocation);
        WriteString(w, m.InvocationId!);
        WriteString(w, m.Target);
        WriteArguments(w, m.Arguments, options);
        WriteStringArray(w, m.StreamIds);
        WriteHeaders(w, m.Headers);
    }

    private static void WriteStreamItem(BufferWriterBinaryOutput w, StreamItemMessage m, AcBinarySerializerOptions options)
    {
        w.WriteByte(MsgStreamItem);
        WriteString(w, m.InvocationId!);
        WriteArgument(w, m.Item, options);
        WriteHeaders(w, m.Headers);
    }

    private static void WriteCompletion(BufferWriterBinaryOutput w, CompletionMessage m, AcBinarySerializerOptions options)
    {
        w.WriteByte(MsgCompletion);
        WriteString(w, m.InvocationId!);
        WriteNullableString(w, m.Error);

        // Result presence flag: 0 = no result, 1 = result follows as a length-prefixed argument.
        var hasResult = m.HasResult;
        w.WriteByte(hasResult ? (byte)1 : (byte)0);
        if (hasResult)
            WriteArgument(w, m.Result, options);

        WriteHeaders(w, m.Headers);
    }

    private static void WriteCancelInvocation(BufferWriterBinaryOutput w, CancelInvocationMessage m)
    {
        w.WriteByte(MsgCancelInvocation);
        WriteString(w, m.InvocationId!);
        WriteHeaders(w, m.Headers);
    }

    private static void WriteClose(BufferWriterBinaryOutput w, CloseMessage m)
    {
        w.WriteByte(MsgClose);
        WriteNullableString(w, m.Error);
        w.WriteByte(m.AllowReconnect ? (byte)1 : (byte)0);
    }

    private static void WriteAck(BufferWriterBinaryOutput w, AckMessage m)
    {
        w.WriteByte(MsgAck);
        w.WriteRaw(m.SequenceId);
    }

    private static void WriteSequence(BufferWriterBinaryOutput w, SequenceMessage m)
    {
        w.WriteByte(MsgSequence);
        w.WriteRaw(m.SequenceId);
    }

    #endregion

    #region TryParseMessage

    /// <summary>
    /// Parses one framed message from <paramref name="input"/>. Returns false (without consuming input)
    /// when the frame is incomplete. Returns false after consuming the frame when the message type is
    /// unknown or the payload is empty.
    /// </summary>
    /// <exception cref="System.IO.InvalidDataException">The frame header or payload is malformed.</exception>
    public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, [NotNullWhen(true)] out HubMessage? message)
    {
        message = null;
        if (input.Length < LengthPrefixSize)
            return false;

        // Read the little-endian length prefix (may straddle segments).
        int payloadLength;
        if (input.FirstSpan.Length >= LengthPrefixSize)
        {
            payloadLength = BinaryPrimitives.ReadInt32LittleEndian(input.FirstSpan);
        }
        else
        {
            Span<byte> lenBuf = stackalloc byte[LengthPrefixSize];
            input.Slice(0, LengthPrefixSize).CopyTo(lenBuf);
            payloadLength = BinaryPrimitives.ReadInt32LittleEndian(lenBuf);
        }

        // Only a corrupt or hostile peer produces a negative length; fail with a clear error
        // instead of an unrelated ArgumentOutOfRangeException from Slice/ArrayPool.
        if (payloadLength < 0)
            throw new System.IO.InvalidDataException($"Invalid message length: {payloadLength}.");

        // long arithmetic so LengthPrefixSize + a length near int.MaxValue cannot wrap around.
        var totalLength = (long)LengthPrefixSize + payloadLength;
        if (input.Length < totalLength)
            return false;

        var payload = input.Slice(LengthPrefixSize, payloadLength);
        var options = _options;

        // Linearize payload for span-based reading
        byte[]? rentedBuffer = null;
        try
        {
            ReadOnlySpan<byte> span;
            if (payload.IsSingleSegment)
            {
                span = payload.FirstSpan;
            }
            else
            {
                rentedBuffer = ArrayPool<byte>.Shared.Rent(payloadLength);
                payload.CopyTo(rentedBuffer);
                span = rentedBuffer.AsSpan(0, payloadLength);
            }

            message = ParseMessage(span, binder, options);
        }
        finally
        {
            if (rentedBuffer != null)
                ArrayPool<byte>.Shared.Return(rentedBuffer);
        }

        input = input.Slice(totalLength);
        return message != null;
    }

    private static HubMessage? ParseMessage(ReadOnlySpan<byte> span, IInvocationBinder binder, AcBinarySerializerOptions options)
    {
        if (span.IsEmpty)
            return null;

        var reader = new SpanReader(span);
        var msgType = reader.ReadByte();
        return msgType switch
        {
            MsgInvocation => ParseInvocation(ref reader, binder, options),
            MsgStreamInvocation => ParseStreamInvocation(ref reader, binder, options),
            MsgStreamItem => ParseStreamItem(ref reader, binder, options),
            MsgCompletion => ParseCompletion(ref reader, binder, options),
            MsgCancelInvocation => ParseCancelInvocation(ref reader),
            MsgPing => PingMessage.Instance,
            MsgClose => ParseClose(ref reader),
            MsgAck => new AckMessage(reader.ReadInt64()),
            MsgSequence => new SequenceMessage(reader.ReadInt64()),
            // Unknown types are skipped so newer peers can introduce message kinds.
            _ => null
        };
    }

    private static HubMessage ParseInvocation(ref SpanReader r, IInvocationBinder binder, AcBinarySerializerOptions options)
    {
        var invocationId = r.ReadNullableString();
        var target = r.ReadString();

        object?[] args;
        try
        {
            args = ReadArguments(ref r, binder.GetParameterTypes(target), options);
        }
        catch (Exception ex)
        {
            // Unknown method or undeserializable argument: let the hub report the error to the caller
            // instead of tearing down the whole connection.
            return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
        }

        var streamIds = r.ReadStringArray();
        var headers = ReadHeaders(ref r);

        // A null invocationId / null streamIds are accepted by this constructor.
        return new InvocationMessage(invocationId, target, args, streamIds) { Headers = headers };
    }

    private static HubMessage ParseStreamInvocation(ref SpanReader r, IInvocationBinder binder, AcBinarySerializerOptions options)
    {
        var invocationId = r.ReadString();
        var target = r.ReadString();

        object?[] args;
        try
        {
            args = ReadArguments(ref r, binder.GetParameterTypes(target), options);
        }
        catch (Exception ex)
        {
            return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
        }

        var streamIds = r.ReadStringArray();
        var headers = ReadHeaders(ref r);
        return new StreamInvocationMessage(invocationId, target, args, streamIds) { Headers = headers };
    }

    private static HubMessage ParseStreamItem(ref SpanReader r, IInvocationBinder binder, AcBinarySerializerOptions options)
    {
        var invocationId = r.ReadString();
        var itemBytes = r.ReadArgumentBytes();

        object? item;
        try
        {
            var itemType = binder.GetStreamItemType(invocationId);
            item = DeserializeArgument(itemBytes, itemType, options);
        }
        catch (Exception ex)
        {
            // Unknown stream id or undeserializable item: reported to the stream tracker, not thrown.
            return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
        }

        var headers = ReadHeaders(ref r);
        return new StreamItemMessage(invocationId, item) { Headers = headers };
    }

    private static HubMessage ParseCompletion(ref SpanReader r, IInvocationBinder binder, AcBinarySerializerOptions options)
    {
        var invocationId = r.ReadString();
        var error = r.ReadNullableString();
        var hasResult = r.ReadByte() == 1;

        object? result = null;
        if (hasResult)
        {
            // Always consume the result bytes so the headers that follow stay aligned.
            var resultBytes = r.ReadArgumentBytes();
            var resultType = TryGetReturnType(binder, invocationId);
            if (resultType != null)
            {
                try
                {
                    result = DeserializeArgument(resultBytes, resultType, options);
                }
                catch (Exception ex)
                {
                    error = $"Error trying to deserialize result to {resultType.Name}. {ex.Message}";
                    hasResult = false;
                }
            }
        }

        var headers = ReadHeaders(ref r);

        CompletionMessage msg;
        if (error != null)
            msg = CompletionMessage.WithError(invocationId, error);
        else if (hasResult)
            msg = CompletionMessage.WithResult(invocationId, result);
        else
            msg = CompletionMessage.Empty(invocationId);

        if (headers != null)
            msg.Headers = headers;
        return msg;
    }

    private static HubMessage ParseCancelInvocation(ref SpanReader r)
    {
        var invocationId = r.ReadString();
        var headers = ReadHeaders(ref r);
        return new CancelInvocationMessage(invocationId) { Headers = headers };
    }

    private static HubMessage ParseClose(ref SpanReader r)
    {
        var error = r.ReadNullableString();
        // The reconnect flag is optional so shorter Close frames are still accepted.
        var allowReconnect = r.Remaining > 0 && r.ReadByte() == 1;
        return new CloseMessage(error, allowReconnect);
    }

    /// <summary>
    /// Asks the binder for the expected result type, returning null when it throws.
    /// </summary>
    private static Type? TryGetReturnType(IInvocationBinder binder, string invocationId)
    {
        try
        {
            return binder.GetReturnType(invocationId);
        }
        catch (Exception)
        {
            // NOTE(review): binders are expected to throw for an invocation id they no longer track
            // (e.g. a client result arriving after cancellation); the late result is dropped rather
            // than closing the connection. Confirm against the binder implementations in use.
            return null;
        }
    }

    #endregion

    #region Argument Serialization (AcBinary payload per argument)

    private static void WriteArguments(BufferWriterBinaryOutput w, object?[] arguments, AcBinarySerializerOptions options)
    {
        w.WriteVarUInt((uint)arguments.Length);
        foreach (var argument in arguments)
            WriteArgument(w, argument, options);
    }

    private static void WriteArgument(BufferWriterBinaryOutput w, object? value, AcBinarySerializerOptions options)
    {
        if (value == null)
        {
            // One-byte payload holding BinaryTypeCode.Null; DeserializeArgument maps it back to null.
            w.WriteVarUInt(1);
            w.WriteByte(0);
            return;
        }

        // AcBinarySerializer needs the full payload size upfront (2-pass), so the argument is
        // serialized on its own first and then copied behind its VarUInt length prefix.
        var serialized = AcBinarySerializer.Serialize(value, options);
        w.WriteVarUInt((uint)serialized.Length);
        w.WriteBytes(serialized);
    }

    private static object?[] ReadArguments(ref SpanReader r, IReadOnlyList<Type> paramTypes, AcBinarySerializerOptions options)
    {
        // ReadLength bounds the count by the bytes left, so a forged count cannot force a huge allocation.
        var count = r.ReadLength();
        var args = new object?[count];
        for (var i = 0; i < args.Length; i++)
        {
            // Arguments beyond the binder's parameter list are deserialized as object.
            var targetType = i < paramTypes.Count ? paramTypes[i] : typeof(object);
            args[i] = DeserializeArgument(r.ReadArgumentBytes(), targetType, options);
        }
        return args;
    }

    /// <summary>Deserializes one argument payload; an empty payload or a lone Null type code yields null.</summary>
    private static object? DeserializeArgument(ReadOnlySpan<byte> bytes, Type targetType, AcBinarySerializerOptions options)
    {
        if (bytes.IsEmpty || (bytes.Length == 1 && bytes[0] == 0)) // BinaryTypeCode.Null
            return null;

        return AcBinaryDeserializer.Deserialize(bytes, targetType, options);
    }

    #endregion

    #region Framing Helpers (string, nullable string, string array, headers)

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static void WriteString(BufferWriterBinaryOutput w, string value)
    {
        w.WriteStringUtf8(value);
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static void WriteNullableString(BufferWriterBinaryOutput w, string? value)
    {
        if (value == null)
        {
            w.WriteByte(0); // null marker
            return;
        }

        w.WriteByte(1); // present marker
        w.WriteStringUtf8(value);
    }

    private static void WriteStringArray(BufferWriterBinaryOutput w, string[]? array)
    {
        if (array == null || array.Length == 0)
        {
            w.WriteVarUInt(0);
            return;
        }

        w.WriteVarUInt((uint)array.Length);
        foreach (var item in array)
            w.WriteStringUtf8(item);
    }

    private static void WriteHeaders(BufferWriterBinaryOutput w, IDictionary<string, string>? headers)
    {
        if (headers == null || headers.Count == 0)
        {
            w.WriteVarUInt(0);
            return;
        }

        w.WriteVarUInt((uint)headers.Count);
        foreach (var kv in headers)
        {
            w.WriteStringUtf8(kv.Key);
            w.WriteStringUtf8(kv.Value);
        }
    }

    private static Dictionary<string, string>? ReadHeaders(ref SpanReader r)
    {
        // Headers are optional at the end of a payload.
        if (r.Remaining == 0)
            return null;

        var count = r.ReadLength();
        if (count == 0)
            return null;

        var headers = new Dictionary<string, string>(count, StringComparer.Ordinal);
        for (var i = 0; i < count; i++)
        {
            var key = r.ReadString();
            var value = r.ReadString();
            headers[key] = value;
        }
        return headers;
    }

    #endregion

    #region SpanReader

    /// <summary>
    /// Lightweight ref struct for sequential, bounds-checked reading from a ReadOnlySpan.
    /// Malformed framing surfaces as <see cref="System.IO.InvalidDataException"/> or an
    /// index/range exception; it never reads outside the span.
    /// </summary>
    private ref struct SpanReader
    {
        private readonly ReadOnlySpan<byte> _span;
        private int _pos;

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public SpanReader(ReadOnlySpan<byte> span)
        {
            _span = span;
            _pos = 0;
        }

        public int Remaining
        {
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            get => _span.Length - _pos;
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public byte ReadByte() => _span[_pos++];

        /// <summary>Reads a native-endian Int64 (matching the writer's WriteRaw).</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public long ReadInt64()
        {
            // Slice first: a truncated frame fails the bounds check instead of reading past the span.
            var value = MemoryMarshal.Read<long>(_span.Slice(_pos, sizeof(long)));
            _pos += sizeof(long);
            return value;
        }

        /// <summary>Reads a 7-bit-per-byte VarUInt (at most 5 bytes for a 32-bit value).</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public uint ReadVarUInt()
        {
            uint value = 0;
            var shift = 0;
            while (true)
            {
                var b = _span[_pos++];
                value |= (uint)(b & 0x7F) << shift;
                if ((b & 0x80) == 0)
                    return value;

                shift += 7;
                if (shift > 28)
                    throw new System.IO.InvalidDataException("Malformed VarUInt: longer than 5 bytes.");
            }
        }

        /// <summary>
        /// Reads a VarUInt used as a byte length or element count and validates it against the
        /// remaining bytes (every element occupies at least one byte), preventing oversized allocations.
        /// </summary>
        public int ReadLength()
        {
            var length = ReadVarUInt();
            if (length > (uint)Remaining)
                throw new System.IO.InvalidDataException($"Length {length} exceeds the {Remaining} bytes remaining in the message.");
            return (int)length;
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public ReadOnlySpan<byte> ReadSpan(int length)
        {
            var result = _span.Slice(_pos, length);
            _pos += length;
            return result;
        }

        /// <summary>Reads one VarUInt-length-prefixed argument and returns its raw bytes.</summary>
        public ReadOnlySpan<byte> ReadArgumentBytes() => ReadSpan(ReadLength());

        public string ReadString()
        {
            var byteCount = ReadLength();
            if (byteCount == 0)
                return string.Empty;

            return System.Text.Encoding.UTF8.GetString(ReadSpan(byteCount));
        }

        public string? ReadNullableString()
        {
            var marker = ReadByte();
            return marker == 0 ? null : ReadString();
        }

        public string[]? ReadStringArray()
        {
            var count = ReadLength();
            if (count == 0)
                return null;

            var array = new string[count];
            for (var i = 0; i < array.Length; i++)
                array[i] = ReadString();
            return array;
        }
    }

    #endregion
}