using System.Security.Claims;
using AyCode.Core;
using AyCode.Core.Extensions;
using AyCode.Core.Helpers;
using AyCode.Core.Loggers;
using AyCode.Models.Server.DynamicMethods;
using AyCode.Services.SignalRs;
using MessagePack;
using MessagePack.Resolvers;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Configuration;

namespace AyCode.Services.Server.SignalRs;

/// <summary>
/// Base SignalR hub that dispatches incoming tagged messages to methods registered in
/// <see cref="DynamicMethodCallModels"/> and sends responses back to clients as MessagePack payloads.
/// </summary>
public abstract class AcWebSignalRHubBase(IConfiguration configuration, TLogger logger)
    : Hub, IAcSignalRHubServer
    where TSignalRTags : AcSignalRTags
    where TLogger : AcLoggerBase
{
    /// <summary>
    /// Registered handler objects. Each entry maps message tags to methods on its instance object;
    /// the first entry (in list order) that contains a tag handles the message.
    /// </summary>
    protected readonly List> DynamicMethodCallModels = [];

    protected TLogger Logger = logger;
    protected IConfiguration Configuration = configuration;

    // Not referenced within this class; presumably consumed by derived hubs — TODO confirm.
    protected AcSerializerOptions SerializerOptions = new AcBinarySerializerOptions();

    #region Connection Lifecycle

    /// <summary>
    /// Logs the new connection and the current user, then runs the base handler.
    /// </summary>
    public override async Task OnConnectedAsync()
    {
        Logger.Debug($"Server OnConnectedAsync; ConnectionId: {GetConnectionId()}; UserIdentifier: {GetUserIdentifier()}");
        LogContextUserNameAndId();

        await base.OnConnectedAsync();
    }

    /// <summary>
    /// Logs the disconnect (as an error when <paramref name="exception"/> is set), then runs the base handler.
    /// </summary>
    /// <param name="exception">The error that closed the connection; SignalR passes null for a clean close.</param>
    public override async Task OnDisconnectedAsync(Exception? exception)
    {
        var connectionId = GetConnectionId();
        var userIdentifier = GetUserIdentifier();

        if (exception == null)
        {
            Logger.Debug($"Server OnDisconnectedAsync; ConnectionId: {connectionId}; UserIdentifier: {userIdentifier}");
        }
        else
        {
            Logger.Error($"Server OnDisconnectedAsync; ConnectionId: {connectionId}; UserIdentifier: {userIdentifier}", exception);
        }

        LogContextUserNameAndId();

        await base.OnDisconnectedAsync(exception);
    }

    #endregion

    #region Message Processing

    /// <summary>
    /// Public hub method called by clients. Dispatches the message to the handler registered for
    /// <paramref name="messageTag"/>.
    /// </summary>
    /// <param name="messageTag">Tag value identifying the target handler.</param>
    /// <param name="messageBytes">MessagePack-encoded payload; may be null or empty for handlers without parameters.</param>
    /// <param name="requestId">Client correlation id, passed back with the response.</param>
    public virtual Task OnReceiveMessage(int messageTag, byte[]? messageBytes, int? requestId)
    {
        return ProcessOnReceiveMessage(messageTag, messageBytes, requestId, null);
    }

    /// <summary>
    /// Core dispatch: logs the request, invokes the matching registered method and replies to the caller.
    /// </summary>
    /// <remarks>
    /// On success the caller receives a <c>Success</c> response carrying the method's return value.
    /// When no handler is registered for the tag, or when deserialization/invocation throws, the caller
    /// receives an <c>Error</c> response with the same tag and request id.
    /// </remarks>
    /// <param name="messageTag">Tag value identifying the target handler.</param>
    /// <param name="message">MessagePack-encoded payload.</param>
    /// <param name="requestId">Client correlation id, passed back with the response.</param>
    /// <param name="notFoundCallback">Optional hook invoked with the tag name when no handler matches.</param>
    protected virtual async Task ProcessOnReceiveMessage(int messageTag, byte[]? message, int? requestId, Func? notFoundCallback)
    {
        var tagName = ConstHelper.NameByValue(messageTag);

        if (message is { Length: 0 })
        {
            Logger.Warning($"message.Length == 0! Server OnReceiveMessage; requestId: {requestId}; ConnectionId: {GetConnectionId()}; {tagName}");
        }
        else
        {
            Logger.Debug($"[{message?.Length:N0}b] Server OnReceiveMessage; requestId: {requestId}; ConnectionId: {GetConnectionId()}; {tagName}");
        }

        try
        {
            if (AcDomain.IsDeveloperVersion)
                LogContextUserNameAndId();

            if (TryFindAndInvokeMethod(messageTag, message, tagName, out var responseData))
            {
                var responseDataJson = new SignalResponseJsonMessage(messageTag, SignalResponseStatus.Success, responseData);

                // Skip the size computation unless the logger level is Debug or lower.
                // Note: this is the UTF-16 byte count of the JSON text, not the MessagePack wire size.
                if (Logger.LogLevel <= LogLevel.Debug)
                {
                    var responseDataJsonKiloBytes = System.Text.Encoding.Unicode.GetByteCount(responseDataJson.ResponseData ?? "") / 1024;
                    Logger.Debug($"[{responseDataJsonKiloBytes}kb] responseData serialized to json");
                }

                await ResponseToCaller(messageTag, responseDataJson, requestId);
                return;
            }

            Logger.Debug($"Not found dynamic method for the tag! {tagName}");

            // NOTE(review): the callback's return value is discarded. If it is asynchronous it is not awaited,
            // so the Error response below may be sent before it completes — confirm this is intended.
            notFoundCallback?.Invoke(tagName);
        }
        catch (Exception ex)
        {
            Logger.Error($"Server OnReceiveMessage; {ex.Message}; {tagName}", ex);
        }

        // Reached when no handler matched or when handling threw.
        await ResponseToCaller(messageTag, new SignalResponseJsonMessage(messageTag, SignalResponseStatus.Error), requestId);
    }

    /// <summary>
    /// Finds the first registered method for <paramref name="messageTag"/>, deserializes its arguments
    /// from <paramref name="message"/> and invokes it.
    /// </summary>
    /// <remarks>
    /// When the method's attribute has <c>SendToOtherClientType != None</c>, the return value is also passed to
    /// <see cref="SendMessageToOthers"/> under <c>SendToOtherClientTag</c>; that send is handed to
    /// <c>Forget()</c> rather than awaited.
    /// </remarks>
    /// <param name="responseData">The invoked method's return value; null when no method was found.</param>
    /// <returns>True when a registered method was found and invoked; otherwise false.</returns>
    private bool TryFindAndInvokeMethod(int messageTag, byte[]? message, string tagName, out object? responseData)
    {
        responseData = null;

        foreach (var methodsByDeclaringObject in DynamicMethodCallModels)
        {
            if (!methodsByDeclaringObject.MethodsByMessageTag.TryGetValue(messageTag, out var methodInfoModel))
                continue;

            var methodName = $"{methodsByDeclaringObject.InstanceObject.GetType().Name}.{methodInfoModel.MethodInfo.Name}";
            var paramValues = DeserializeParameters(message, methodInfoModel, tagName, methodName);

            if (paramValues == null)
            {
                Logger.Debug($"Found dynamic method for the tag! method: {methodName}(); {tagName}");
            }
            else
            {
                Logger.Debug($"Found dynamic method for the tag! method: {methodName}({string.Join(", ", methodInfoModel.ParamInfos.Select(x => x.Name))}); {tagName}");
            }

            responseData = methodInfoModel.MethodInfo.InvokeMethod(methodsByDeclaringObject.InstanceObject, paramValues);

            if (methodInfoModel.Attribute.SendToOtherClientType != SendToClientType.None)
                SendMessageToOthers(methodInfoModel.Attribute.SendToOtherClientTag, responseData).Forget();

            return true;
        }

        return false;
    }

    /// <summary>
    /// Deserializes the invocation arguments for <paramref name="methodInfoModel"/> from <paramref name="message"/>.
    /// </summary>
    /// <returns>
    /// Null when the method takes no parameters; otherwise an array with one slot per parameter.
    /// Slots with no corresponding id in the message are left null.
    /// </returns>
    /// <exception cref="ArgumentException">
    /// The method has parameters but <paramref name="message"/> is null or empty, or the message carries
    /// more arguments than the method declares.
    /// </exception>
    private static object[]? DeserializeParameters(byte[]? message, AcMethodInfoModel methodInfoModel, string tagName, string methodName)
    {
        if (methodInfoModel.ParamInfos is not { Length: > 0 })
            return null;

        // A message is required as soon as the method declares any parameter.
        if (message is null or { Length: 0 })
            throw new ArgumentException($"Message is null or empty but method '{methodName}' requires {methodInfoModel.ParamInfos.Length} parameter(s); {tagName}");

        var paramCount = methodInfoModel.ParamInfos.Length;
        var paramValues = new object[paramCount];
        var firstParamType = methodInfoModel.ParamInfos[0].ParameterType;

        // IdMessage format is used for multiple parameters, or for a single string/value-type parameter.
        if (paramCount > 1 || IsPrimitiveOrStringOrEnum(firstParamType))
        {
            // ContractlessStandardResolver matches the resolver the client serializes with.
            var msg = message.MessagePackTo>(ContractlessStandardResolver.Options);
            var ids = msg.PostData.Ids;

            // Guard explicitly: a surplus id would otherwise index past ParamInfos and throw an
            // IndexOutOfRangeException that says nothing about the method or tag involved.
            if (ids.Count > paramCount)
                throw new ArgumentException($"Message contains {ids.Count} argument(s) but method '{methodName}' accepts {paramCount} parameter(s); {tagName}");

            for (var i = 0; i < ids.Count; i++)
            {
                // Each id is JSON text; AcJsonDeserializer also handles primitives.
                paramValues[i] = AcJsonDeserializer.Deserialize(ids[i], methodInfoModel.ParamInfos[i].ParameterType)!;
            }
        }
        else
        {
            // Single complex parameter: the payload is either the object's own JSON or an IdMessage wrapping it.
            var msgJson = message.MessagePackTo>(ContractlessStandardResolver.Options);
            var json = msgJson.PostDataJson;

            // NOTE(review): this is a substring test, so a complex argument whose JSON has an "Ids" property,
            // or any string value equal to "Ids", also takes this branch.
            if (json.Contains("\"Ids\""))
            {
                var idMsg = message.MessagePackTo>(ContractlessStandardResolver.Options);
                if (idMsg.PostData.Ids.Count > 0)
                {
                    paramValues[0] = AcJsonDeserializer.Deserialize(idMsg.PostData.Ids[0], firstParamType)!;
                    return paramValues;
                }

                // An empty Ids list falls through to direct deserialization below.
            }

            paramValues[0] = json.JsonTo(firstParamType)!;
        }

        return paramValues;
    }

    /// <summary>
    /// Determines whether a single parameter of <paramref name="type"/> is sent in IdMessage format:
    /// strings and all value types (which include primitives, enums and <see cref="DateTime"/>).
    /// </summary>
    /// <remarks>
    /// Arrays and collections are reference types and are therefore excluded; as a single parameter
    /// they are sent in PostDataJson format.
    /// </remarks>
    private static bool IsPrimitiveOrStringOrEnum(Type type)
    {
        // Enums and DateTime are value types, so IsValueType already covers them.
        return type == typeof(string) || type.IsValueType;
    }

    #endregion

    #region Response Methods

    /// <summary>Sends a <c>Success</c> response with <paramref name="content"/> to the calling client (no request id).</summary>
    protected virtual Task ResponseToCallerWithContent(int messageTag, object? content)
        => ResponseToCaller(messageTag, new SignalResponseJsonMessage(messageTag, SignalResponseStatus.Success, content), null);

    /// <summary>Sends <paramref name="message"/> to the calling client.</summary>
    protected virtual Task ResponseToCaller(int messageTag, ISignalRMessage message, int? requestId)
        => SendMessageToClient(Clients.Caller, messageTag, message, requestId);

    /// <summary>Sends a <c>Success</c> response with <paramref name="content"/> to all connections of a user.</summary>
    protected virtual Task SendMessageToUserIdWithContent(string userId, int messageTag, object? content)
        => SendMessageToUserIdInternal(userId, messageTag, new SignalResponseJsonMessage(messageTag, SignalResponseStatus.Success, content), null);

    /// <summary>Sends <paramref name="message"/> to all connections of the user identified by <paramref name="userId"/>.</summary>
    protected virtual Task SendMessageToUserIdInternal(string userId, int messageTag, ISignalRMessage message, int? requestId)
        => SendMessageToClient(Clients.User(userId), messageTag, message, requestId);

    /// <summary>Sends a <c>Success</c> response with <paramref name="content"/> to a single connection.</summary>
    protected virtual Task SendMessageToConnectionIdWithContent(string connectionId, int messageTag, object? content)
        => SendMessageToConnectionIdInternal(connectionId, messageTag, new SignalResponseJsonMessage(messageTag, SignalResponseStatus.Success, content), null);

    /// <summary>Sends <paramref name="message"/> to the connection identified by <paramref name="connectionId"/>.</summary>
    protected virtual Task SendMessageToConnectionIdInternal(string connectionId, int messageTag, ISignalRMessage message, int? requestId)
        => SendMessageToClient(Clients.Client(connectionId), messageTag, message, requestId);

    /// <summary>Sends a <c>Success</c> response with <paramref name="content"/> to every client except the caller.</summary>
    protected virtual Task SendMessageToOthers(int messageTag, object? content)
        => SendMessageToClient(Clients.Others, messageTag, new SignalResponseJsonMessage(messageTag, SignalResponseStatus.Success, content), null);

    /// <summary>Sends a <c>Success</c> response with <paramref name="content"/> to every connected client.</summary>
    protected virtual Task SendMessageToAll(int messageTag, object? content)
        => SendMessageToClient(Clients.All, messageTag, new SignalResponseJsonMessage(messageTag, SignalResponseStatus.Success, content), null);

    /// <summary>
    /// Serializes <paramref name="message"/> with MessagePack (contractless resolver) and delivers it through
    /// the target's <c>OnReceiveMessage</c>.
    /// </summary>
    protected virtual async Task SendMessageToClient(IAcSignalRHubItemServer sendTo, int messageTag, ISignalRMessage message, int? requestId = null)
    {
        var responseDataMessagePack = message.ToMessagePack(ContractlessStandardResolver.Options);
        var tagName = ConstHelper.NameByValue(messageTag);

        Logger.Debug($"[{responseDataMessagePack.Length / 1024}kb] Server sending message to client; requestId: {requestId}; Aborted: {IsConnectionAborted()}; ConnectionId: {GetConnectionId()}; {tagName}");

        await sendTo.OnReceiveMessage(messageTag, responseDataMessagePack, requestId);

        Logger.Debug($"Server sent message to client; requestId: {requestId}; ConnectionId: {GetConnectionId()}; {tagName}");
    }

    #endregion

    #region Context Accessor Methods (virtual for testing)

    /// <summary>
    /// Gets the connection ID. Override in tests to avoid Context dependency.
    /// </summary>
    protected virtual string GetConnectionId() => Context.ConnectionId;

    /// <summary>
    /// Gets whether the connection is aborted. Override in tests to avoid Context dependency.
    /// </summary>
    protected virtual bool IsConnectionAborted() => Context.ConnectionAborted.IsCancellationRequested;

    /// <summary>
    /// Gets the user identifier. Override in tests to avoid Context dependency.
    /// </summary>
    protected virtual string? GetUserIdentifier() => Context.UserIdentifier;

    /// <summary>
    /// Gets the ClaimsPrincipal user. Override in tests to avoid Context dependency.
    /// </summary>
    protected virtual ClaimsPrincipal? GetUser() => Context.User;

    #endregion

    #region Logging

    /// <summary>
    /// Logs the current user's name and NameIdentifier claim; does nothing when there is no user.
    /// Developer builds log via <c>WarningConditional</c>, other builds via <c>Debug</c>.
    /// </summary>
    protected virtual void LogContextUserNameAndId()
    {
        var user = GetUser();
        if (user == null) return;

        var userName = user.Identity?.Name;

        // userId stays Guid.Empty when the claim is missing or is not a GUID.
        Guid.TryParse(user.FindFirstValue(ClaimTypes.NameIdentifier), out var userId);

        if (AcDomain.IsDeveloperVersion)
        {
            Logger.WarningConditional($"SignalR.Context; userName: {userName}; userId: {userId}");
        }
        else
        {
            Logger.Debug($"SignalR.Context; userName: {userName}; userId: {userId}");
        }
    }

    #endregion
}