using System.Buffers;
using System.Buffers.Binary;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using AyCode.Core.Serializers.Binaries;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace AyCode.Services.SignalRs;

/// <summary>
/// Custom SignalR hub protocol using AcBinarySerializer for wire format.
/// Eliminates JSON+Base64 overhead by serializing all HubMessages directly to binary.
///
/// Wire format per message:
///   [4 bytes: payload length (little-endian)] [payload bytes]
///
/// Payload structure:
///   [1 byte: message type] [message-specific fields serialized via AcBinary]
///
/// Message types map 1:1 to SignalR HubMessageType values.
/// Arguments are serialized individually with a VarUInt length prefix each,
/// enabling deferred deserialization via IHubProtocol's binder pattern.
///
/// The payload is staged in a reusable per-thread scratch buffer so its length is known
/// before anything is written to the caller's <see cref="IBufferWriter{T}"/>. A buffer obtained
/// from <c>GetMemory</c>/<c>GetSpan</c> must not be written to after <c>Advance</c>, so the
/// length prefix cannot be patched in place afterwards.
/// </summary>
public sealed class AcBinaryHubProtocol : IHubProtocol
{
    private const int LengthPrefixSize = 4;

    // Initial size of the scratch buffer used to stage a payload.
    private const int InitialScratchCapacity = 256;

    // Scratch buffers that grew beyond this size are dropped instead of being cached per thread.
    private const int MaxRetainedScratchCapacity = 64 * 1024;

    // Message type markers (matching HubMessageType enum values)
    private const byte MsgInvocation = 1;
    private const byte MsgStreamItem = 2;
    private const byte MsgCompletion = 3;
    private const byte MsgStreamInvocation = 4;
    private const byte MsgCancelInvocation = 5;
    private const byte MsgPing = 6;
    private const byte MsgClose = 7;
    private const byte MsgAck = 8;
    private const byte MsgSequence = 9;

    // Per-thread staging buffer; taken out of the slot while in use so re-entrant writes get their own.
    [ThreadStatic]
    private static ArrayBufferWriter<byte>? t_scratch;

    private volatile AcBinarySerializerOptions _options;

    /// <summary>Creates the protocol with <see cref="AcBinarySerializerOptions.Default"/>.</summary>
    public AcBinaryHubProtocol() : this(AcBinarySerializerOptions.Default)
    {
    }

    /// <summary>Creates the protocol with the given serializer options.</summary>
    /// <param name="options">Options passed to the serializer and deserializer for every argument.</param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
    public AcBinaryHubProtocol(AcBinarySerializerOptions options)
    {
        _options = options ?? throw new ArgumentNullException(nameof(options));
    }

    /// <summary>
    /// Runtime-replaceable serializer options.
    /// Stored in a volatile field; each argument read/write picks up the current value.
    /// </summary>
    /// <exception cref="ArgumentNullException">The assigned value is null.</exception>
    public AcBinarySerializerOptions Options
    {
        get => _options;
        set => _options = value ?? throw new ArgumentNullException(nameof(value));
    }

    /// <inheritdoc />
    public string Name => "acbinary";

    /// <inheritdoc />
    public int Version => 1;

    /// <inheritdoc />
    public TransferFormat TransferFormat => TransferFormat.Binary;

    /// <inheritdoc />
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public bool IsVersionSupported(int version) => version <= Version;

    #region WriteMessage

    /// <summary>
    /// Serializes <paramref name="message"/> into a newly allocated, exactly sized frame
    /// (length prefix followed by payload).
    /// </summary>
    /// <param name="message">The hub message to serialize.</param>
    /// <returns>The complete framed message.</returns>
    /// <exception cref="HubException">The message type is not one this protocol can write.</exception>
    public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
    {
        var scratch = RentScratch();
        try
        {
            WritePayload(message, scratch);

            var payload = scratch.WrittenSpan;
            var frame = new byte[LengthPrefixSize + payload.Length];
            BinaryPrimitives.WriteInt32LittleEndian(frame, payload.Length);
            payload.CopyTo(frame.AsSpan(LengthPrefixSize));
            return frame;
        }
        finally
        {
            ReturnScratch(scratch);
        }
    }

    /// <summary>
    /// Serializes <paramref name="message"/> and appends the framed bytes to <paramref name="output"/>.
    /// Nothing is written to <paramref name="output"/> if serialization throws.
    /// </summary>
    /// <param name="message">The hub message to serialize.</param>
    /// <param name="output">The writer that receives the length prefix followed by the payload.</param>
    /// <exception cref="HubException">The message type is not one this protocol can write.</exception>
    public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
    {
        var scratch = RentScratch();
        try
        {
            WritePayload(message, scratch);

            var payload = scratch.WrittenSpan;

            // The prefix span is filled and committed before any further buffer is requested.
            var prefix = output.GetSpan(LengthPrefixSize);
            BinaryPrimitives.WriteInt32LittleEndian(prefix, payload.Length);
            output.Advance(LengthPrefixSize);

            output.Write(payload);
        }
        finally
        {
            ReturnScratch(scratch);
        }
    }

    /// <summary>Takes the cached per-thread scratch writer, or creates one if the slot is empty.</summary>
    private static ArrayBufferWriter<byte> RentScratch()
    {
        var scratch = t_scratch;
        if (scratch is null)
        {
            return new ArrayBufferWriter<byte>(InitialScratchCapacity);
        }

        t_scratch = null;
        return scratch;
    }

    /// <summary>Resets the scratch writer and caches it for the current thread unless it grew too large.</summary>
    private static void ReturnScratch(ArrayBufferWriter<byte> scratch)
    {
        if (scratch.Capacity > MaxRetainedScratchCapacity)
        {
            return;
        }

        scratch.Clear();
        t_scratch = scratch;
    }

    /// <summary>
    /// Writes the unframed payload (type byte plus message fields) to <paramref name="target"/>
    /// and flushes the binary output so every byte has reached <paramref name="target"/> on return.
    /// </summary>
    private void WritePayload(HubMessage message, IBufferWriter<byte> target)
    {
        var w = new BufferWriterBinaryOutput(target);

        switch (message)
        {
            case InvocationMessage m:
                WriteInvocation(w, m);
                break;
            case StreamInvocationMessage m:
                WriteStreamInvocation(w, m);
                break;
            case StreamItemMessage m:
                WriteStreamItem(w, m);
                break;
            case CompletionMessage m:
                WriteCompletion(w, m);
                break;
            case CancelInvocationMessage m:
                WriteCancelInvocation(w, m);
                break;
            case PingMessage:
                w.WriteByte(MsgPing);
                break;
            case CloseMessage m:
                WriteClose(w, m);
                break;
            case AckMessage m:
                WriteAck(w, m);
                break;
            case SequenceMessage m:
                WriteSequence(w, m);
                break;
            default:
                throw new HubException($"Unexpected message type: {message.GetType().Name}");
        }

        // Flush pending chunk bytes to the underlying IBufferWriter.
        w.Flush();
    }

    // Layout: type, nullable invocation id, target, arguments, stream ids, headers.
    private void WriteInvocation(BufferWriterBinaryOutput w, InvocationMessage m)
    {
        w.WriteByte(MsgInvocation);
        WriteNullableString(w, m.InvocationId);
        WriteString(w, m.Target);
        WriteArguments(w, m.Arguments);
        WriteStringArray(w, m.StreamIds);
        WriteHeaders(w, m.Headers);
    }

    // Layout: type, invocation id, target, arguments, stream ids, headers.
    private void WriteStreamInvocation(BufferWriterBinaryOutput w, StreamInvocationMessage m)
    {
        w.WriteByte(MsgStreamInvocation);
        WriteString(w, m.InvocationId!);
        WriteString(w, m.Target);
        WriteArguments(w, m.Arguments);
        WriteStringArray(w, m.StreamIds);
        WriteHeaders(w, m.Headers);
    }

    // Layout: type, invocation id, item, headers.
    private void WriteStreamItem(BufferWriterBinaryOutput w, StreamItemMessage m)
    {
        w.WriteByte(MsgStreamItem);
        WriteString(w, m.InvocationId!);
        WriteArgument(w, m.Item);
        WriteHeaders(w, m.Headers);
    }

    // Layout: type, invocation id, nullable error, result flag, [result], headers.
    private void WriteCompletion(BufferWriterBinaryOutput w, CompletionMessage m)
    {
        w.WriteByte(MsgCompletion);
        WriteString(w, m.InvocationId!);
        WriteNullableString(w, m.Error);

        // Result presence flag: 0 = no result, 1 = has result
        var hasResult = m.HasResult;
        w.WriteByte(hasResult ? (byte)1 : (byte)0);
        if (hasResult)
        {
            WriteArgument(w, m.Result);
        }

        WriteHeaders(w, m.Headers);
    }

    // Layout: type, invocation id, headers.
    private static void WriteCancelInvocation(BufferWriterBinaryOutput w, CancelInvocationMessage m)
    {
        w.WriteByte(MsgCancelInvocation);
        WriteString(w, m.InvocationId!);
        WriteHeaders(w, m.Headers);
    }

    // Layout: type, nullable error, allow-reconnect flag.
    private static void WriteClose(BufferWriterBinaryOutput w, CloseMessage m)
    {
        w.WriteByte(MsgClose);
        WriteNullableString(w, m.Error);
        w.WriteByte(m.AllowReconnect ? (byte)1 : (byte)0);
    }

    // Layout: type, sequence id (raw 8 bytes).
    private static void WriteAck(BufferWriterBinaryOutput w, AckMessage m)
    {
        w.WriteByte(MsgAck);
        w.WriteRaw(m.SequenceId);
    }

    // Layout: type, sequence id (raw 8 bytes).
    private static void WriteSequence(BufferWriterBinaryOutput w, SequenceMessage m)
    {
        w.WriteByte(MsgSequence);
        w.WriteRaw(m.SequenceId);
    }

    #endregion

    #region TryParseMessage

    /// <summary>
    /// Tries to read one framed message from the start of <paramref name="input"/>.
    /// </summary>
    /// <param name="input">
    /// The received bytes. Left untouched when a complete frame is not yet available; otherwise
    /// advanced past the frame, including when the frame is skipped because its payload is empty
    /// or its message type is unknown.
    /// </param>
    /// <param name="binder">Supplies the target types used to deserialize arguments, items and results.</param>
    /// <param name="message">The parsed message, or null when the method returns false.</param>
    /// <returns>True when a message was produced; otherwise false.</returns>
    /// <exception cref="System.IO.InvalidDataException">The frame or its payload is malformed.</exception>
    public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, [NotNullWhen(true)] out HubMessage? message)
    {
        message = null;

        if (input.Length < LengthPrefixSize)
        {
            return false;
        }

        // Read length prefix (it may straddle two segments).
        int payloadLength;
        var firstSpan = input.FirstSpan;
        if (firstSpan.Length >= LengthPrefixSize)
        {
            payloadLength = BinaryPrimitives.ReadInt32LittleEndian(firstSpan);
        }
        else
        {
            Span<byte> lenBuf = stackalloc byte[LengthPrefixSize];
            input.Slice(0, LengthPrefixSize).CopyTo(lenBuf);
            payloadLength = BinaryPrimitives.ReadInt32LittleEndian(lenBuf);
        }

        if (payloadLength < 0)
        {
            throw new System.IO.InvalidDataException("Invalid message length prefix.");
        }

        // Computed as long: LengthPrefixSize + payloadLength can exceed int.MaxValue.
        var totalLength = (long)LengthPrefixSize + payloadLength;
        if (input.Length < totalLength)
        {
            return false;
        }

        var payload = input.Slice(LengthPrefixSize, payloadLength);

        // Linearize payload for span-based reading
        byte[]? rentedBuffer = null;
        try
        {
            ReadOnlySpan<byte> span;
            if (payload.IsSingleSegment)
            {
                span = payload.FirstSpan;
            }
            else
            {
                rentedBuffer = ArrayPool<byte>.Shared.Rent(payloadLength);
                payload.CopyTo(rentedBuffer);
                span = rentedBuffer.AsSpan(0, payloadLength);
            }

            message = ParseMessage(span, binder);
        }
        finally
        {
            if (rentedBuffer is not null)
            {
                ArrayPool<byte>.Shared.Return(rentedBuffer);
            }
        }

        input = input.Slice(totalLength);
        return message is not null;
    }

    /// <summary>
    /// Dispatches on the leading type byte. Returns null for an empty payload or an unknown type.
    /// </summary>
    private HubMessage? ParseMessage(ReadOnlySpan<byte> span, IInvocationBinder binder)
    {
        if (span.Length == 0)
        {
            return null;
        }

        var reader = new SpanReader(span);
        var msgType = reader.ReadByte();

        return msgType switch
        {
            MsgInvocation => ParseInvocation(ref reader, binder),
            MsgStreamInvocation => ParseStreamInvocation(ref reader, binder),
            MsgStreamItem => ParseStreamItem(ref reader, binder),
            MsgCompletion => ParseCompletion(ref reader, binder),
            MsgCancelInvocation => ParseCancelInvocation(ref reader),
            MsgPing => PingMessage.Instance,
            MsgClose => ParseClose(ref reader),
            MsgAck => new AckMessage(reader.ReadInt64()),
            MsgSequence => new SequenceMessage(reader.ReadInt64()),
            _ => null
        };
    }

    private HubMessage ParseInvocation(ref SpanReader r, IInvocationBinder binder)
    {
        var invocationId = r.ReadNullableString();
        var target = r.ReadString();
        var paramTypes = binder.GetParameterTypes(target);
        var args = ReadArguments(ref r, paramTypes);
        var streamIds = r.ReadStringArray();
        var headers = ReadHeaders(ref r);

        // streamIds is null when none were sent; the constructor accepts null id and null stream ids.
        return new InvocationMessage(invocationId, target, args, streamIds) { Headers = headers };
    }

    private HubMessage ParseStreamInvocation(ref SpanReader r, IInvocationBinder binder)
    {
        var invocationId = r.ReadString();
        var target = r.ReadString();
        var paramTypes = binder.GetParameterTypes(target);
        var args = ReadArguments(ref r, paramTypes);
        var streamIds = r.ReadStringArray();
        var headers = ReadHeaders(ref r);

        return new StreamInvocationMessage(invocationId, target, args, streamIds) { Headers = headers };
    }

    private HubMessage ParseStreamItem(ref SpanReader r, IInvocationBinder binder)
    {
        var invocationId = r.ReadString();
        var itemType = binder.GetStreamItemType(invocationId);
        var item = ReadSingleArgument(ref r, itemType);
        var headers = ReadHeaders(ref r);

        return new StreamItemMessage(invocationId, item) { Headers = headers };
    }

    private HubMessage ParseCompletion(ref SpanReader r, IInvocationBinder binder)
    {
        var invocationId = r.ReadString();
        var error = r.ReadNullableString();
        var hasResult = r.ReadByte() == 1;

        object? result = null;
        if (hasResult)
        {
            var resultType = binder.GetReturnType(invocationId);
            result = ReadSingleArgument(ref r, resultType);
        }

        var headers = ReadHeaders(ref r);

        // An error takes precedence over a result when both are present on the wire.
        CompletionMessage msg;
        if (error is not null)
        {
            msg = CompletionMessage.WithError(invocationId, error);
        }
        else if (hasResult)
        {
            msg = CompletionMessage.WithResult(invocationId, result);
        }
        else
        {
            msg = CompletionMessage.Empty(invocationId);
        }

        msg.Headers = headers;
        return msg;
    }

    private static HubMessage ParseCancelInvocation(ref SpanReader r)
    {
        var invocationId = r.ReadString();
        var headers = ReadHeaders(ref r);

        return new CancelInvocationMessage(invocationId) { Headers = headers };
    }

    private static HubMessage ParseClose(ref SpanReader r)
    {
        var error = r.ReadNullableString();

        // The reconnect flag is treated as optional: a payload that ends here means "false".
        var allowReconnect = r.Remaining > 0 && r.ReadByte() == 1;
        return new CloseMessage(error, allowReconnect);
    }

    #endregion

    #region Argument Serialization (AcBinary payload per argument)

    // Writes the argument count followed by each length-prefixed argument.
    private void WriteArguments(BufferWriterBinaryOutput w, object?[] arguments)
    {
        w.WriteVarUInt((uint)arguments.Length);
        for (var i = 0; i < arguments.Length; i++)
        {
            WriteArgument(w, arguments[i]);
        }
    }

    // Writes one argument as [VarUInt byte length][serialized bytes]; null is a single 0 byte.
    private void WriteArgument(BufferWriterBinaryOutput w, object? value)
    {
        if (value is null)
        {
            w.WriteVarUInt(1);
            w.WriteByte(0); // BinaryTypeCode.Null
            return;
        }

        // The argument's byte length must precede it on the wire, so it is serialized
        // to its own buffer first and then copied length-prefixed.
        var serialized = AcBinarySerializer.Serialize(value, _options);
        w.WriteVarUInt((uint)serialized.Length);
        w.WriteBytes(serialized);
    }

    // Reads the argument list; arguments beyond the binder's parameter list are read as object.
    private object?[] ReadArguments(ref SpanReader r, IReadOnlyList<Type> paramTypes)
    {
        // Every argument occupies at least one byte (its length prefix).
        var count = r.ReadCount(1);
        var args = new object?[count];

        for (var i = 0; i < count; i++)
        {
            var targetType = i < paramTypes.Count ? paramTypes[i] : typeof(object);
            args[i] = ReadSingleArgument(ref r, targetType);
        }

        return args;
    }

    // Reads one length-prefixed argument and deserializes it as targetType.
    private object? ReadSingleArgument(ref SpanReader r, Type targetType)
    {
        var argLength = r.ReadLength();
        if (argLength == 0)
        {
            return null;
        }

        var argSpan = r.ReadSpan(argLength);

        if (argLength == 1 && argSpan[0] == 0) // BinaryTypeCode.Null
        {
            return null;
        }

        return AcBinaryDeserializer.Deserialize(argSpan, targetType, _options);
    }

    #endregion

    #region Framing Helpers (string, nullable string, string array, headers)

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static void WriteString(BufferWriterBinaryOutput w, string value)
    {
        w.WriteStringUtf8(value);
    }

    // Writes a presence marker (0 = null, 1 = present) followed by the string when present.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static void WriteNullableString(BufferWriterBinaryOutput w, string? value)
    {
        if (value is null)
        {
            w.WriteByte(0); // null marker
            return;
        }

        w.WriteByte(1); // present marker
        w.WriteStringUtf8(value);
    }

    // Writes the element count followed by each string; null and empty both encode as count 0.
    private static void WriteStringArray(BufferWriterBinaryOutput w, string[]? array)
    {
        if (array is null || array.Length == 0)
        {
            w.WriteVarUInt(0);
            return;
        }

        w.WriteVarUInt((uint)array.Length);
        for (var i = 0; i < array.Length; i++)
        {
            w.WriteStringUtf8(array[i]);
        }
    }

    // Writes the entry count followed by key/value string pairs; null and empty both encode as count 0.
    private static void WriteHeaders(BufferWriterBinaryOutput w, IDictionary<string, string>? headers)
    {
        if (headers is null || headers.Count == 0)
        {
            w.WriteVarUInt(0);
            return;
        }

        w.WriteVarUInt((uint)headers.Count);
        foreach (var kv in headers)
        {
            w.WriteStringUtf8(kv.Key);
            w.WriteStringUtf8(kv.Value);
        }
    }

    #endregion

    #region Helpers

    // Reads the header block; returns null when the payload ends here or the count is 0.
    private static Dictionary<string, string>? ReadHeaders(ref SpanReader r)
    {
        if (r.Remaining == 0)
        {
            return null;
        }

        // Every entry occupies at least two bytes (one length prefix each for key and value).
        var count = r.ReadCount(2);
        if (count == 0)
        {
            return null;
        }

        var headers = new Dictionary<string, string>(count, StringComparer.Ordinal);
        for (var i = 0; i < count; i++)
        {
            var key = r.ReadString();
            var value = r.ReadString();
            headers[key] = value;
        }

        return headers;
    }

    #endregion

    #region SpanReader

    /// <summary>
    /// Lightweight ref struct for sequential reading from a ReadOnlySpan.
    /// Lengths and counts read from the data are validated against the bytes remaining
    /// before they are used to slice or allocate.
    /// </summary>
    private ref struct SpanReader
    {
        private readonly ReadOnlySpan<byte> _span;
        private int _pos;

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public SpanReader(ReadOnlySpan<byte> span)
        {
            _span = span;
            _pos = 0;
        }

        /// <summary>Number of unread bytes.</summary>
        public int Remaining
        {
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            get => _span.Length - _pos;
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public byte ReadByte() => _span[_pos++];

        /// <summary>Reads 8 bytes as a 64-bit integer in the platform's byte order.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public long ReadInt64()
        {
            // ReadSpan verifies that all 8 bytes are present before they are read.
            return MemoryMarshal.Read<long>(ReadSpan(sizeof(long)));
        }

        /// <summary>Reads a 7-bits-per-byte variable-length unsigned integer (at most 5 bytes).</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public uint ReadVarUInt()
        {
            uint value = 0;

            // A 32-bit value needs at most 5 groups of 7 bits (shifts 0, 7, 14, 21, 28).
            for (var shift = 0; shift <= 28; shift += 7)
            {
                var b = _span[_pos++];
                value |= (uint)(b & 0x7F) << shift;
                if ((b & 0x80) == 0)
                {
                    return value;
                }
            }

            throw new System.IO.InvalidDataException("VarUInt is longer than 5 bytes.");
        }

        /// <summary>Reads a VarUInt byte length and verifies that many bytes remain.</summary>
        public int ReadLength()
        {
            var length = ReadVarUInt();
            if (length > (uint)Remaining)
            {
                throw new System.IO.InvalidDataException("Length exceeds the remaining payload.");
            }

            return (int)length;
        }

        /// <summary>
        /// Reads a VarUInt element count and verifies the remaining bytes could hold that many
        /// elements of at least <paramref name="minBytesPerItem"/> bytes each.
        /// </summary>
        public int ReadCount(int minBytesPerItem)
        {
            var count = ReadVarUInt();
            if (count > (uint)(Remaining / minBytesPerItem))
            {
                throw new System.IO.InvalidDataException("Element count exceeds the remaining payload.");
            }

            return (int)count;
        }

        /// <summary>Returns the next <paramref name="length"/> bytes and advances past them.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        public ReadOnlySpan<byte> ReadSpan(int length)
        {
            // The unsigned comparison also rejects negative lengths.
            if ((uint)length > (uint)Remaining)
            {
                throw new System.IO.InvalidDataException("Unexpected end of payload.");
            }

            var result = _span.Slice(_pos, length);
            _pos += length;
            return result;
        }

        /// <summary>Reads a VarUInt byte count followed by that many UTF-8 bytes.</summary>
        public string ReadString()
        {
            var byteCount = ReadLength();
            if (byteCount == 0)
            {
                return string.Empty;
            }

            return System.Text.Encoding.UTF8.GetString(ReadSpan(byteCount));
        }

        /// <summary>Reads a presence marker; 0 yields null, anything else is followed by a string.</summary>
        public string? ReadNullableString()
        {
            var marker = ReadByte();
            return marker == 0 ? null : ReadString();
        }

        /// <summary>Reads a counted list of strings; a count of 0 yields null.</summary>
        public string[]? ReadStringArray()
        {
            // Every string occupies at least one byte (its length prefix).
            var count = ReadCount(1);
            if (count == 0)
            {
                return null;
            }

            var array = new string[count];
            for (var i = 0; i < count; i++)
            {
                array[i] = ReadString();
            }

            return array;
        }
    }

    #endregion
}