// AyCode.Core/AyCode.Services/SignalRs/AcSignalRClientBase.cs
//
// 541 lines
// 26 KiB
// C#

using System.Collections.Concurrent;
using AyCode.Core;
using AyCode.Core.Extensions;
using AyCode.Core.Helpers;
using AyCode.Core.Loggers;
using AyCode.Interfaces.Entities;
using MessagePack.Resolvers;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.Extensions.Logging;
using static AyCode.Core.Extensions.JsonUtilities;
namespace AyCode.Services.SignalRs
{
/// <summary>
/// Base class for SignalR hub clients. Wraps a <c>HubConnection</c> and adds request/response correlation on top of
/// SignalR's one-way messages: outgoing requests carry a request id and are registered in a pending-request map, and
/// incoming messages whose request id matches a pending entry are routed back to the waiting caller or callback.
/// Incoming messages without a matching pending entry are treated as broadcasts and passed to <see cref="MessageReceived"/>.
/// </summary>
public abstract class AcSignalRClientBase : IAcSignalRHubClient
{
    /// <summary>
    /// Pending requests keyed by request id. An entry's <c>ResponseByRequestId</c> is either <c>null</c> (a caller is
    /// polling for the response), the received response, or a response callback. Models come from and are returned to
    /// <c>SignalRRequestModelPool</c>.
    /// </summary>
    private readonly ConcurrentDictionary<int, SignalRRequestModel> _responseByRequestId = new();

    /// <summary>
    /// Serializes connection start-up. <c>HubConnection.StartAsync</c> throws when called while the connection is not
    /// Disconnected, which concurrent senders (every send calls <see cref="StartConnection"/>) could otherwise trigger.
    /// </summary>
    private readonly SemaphoreSlim _connectionStartGate = new(1, 1);

    /// <summary>The underlying connection; <c>null</c> when the instance was created through the test constructor.</summary>
    protected readonly HubConnection? HubConnection;

    protected readonly AcLoggerBase Logger;

    /// <summary>
    /// Handles a server message that is not a response to a pending request (broadcast/notification).
    /// </summary>
    protected abstract Task MessageReceived(int messageTag, byte[] messageBytes);

    /// <summary>Delay passed to <c>TaskHelper.WaitToAsync</c> while polling for a response (presumably ms between polls).</summary>
    public int MsDelay = 25;

    /// <summary>Initial delay passed to <c>TaskHelper.WaitToAsync</c> while polling for a response (presumably ms before the first poll).</summary>
    public int MsFirstDelay = 50;

    /// <summary>Timeout passed to <c>TaskHelper.WaitToAsync</c> while waiting for the connection to become Connected (presumably ms).</summary>
    public int ConnectionTimeout = 10000;

    /// <summary>Timeout passed to <c>TaskHelper.WaitToAsync</c> while waiting for a response (presumably ms).</summary>
    public int TransportSendTimeout = 60000;

    /// <summary>Group name handed to <c>ConstHelper.NameByValue</c>, presumably to resolve message tags to readable names in logs.</summary>
    private const string TagsName = "SignalRTags";

    /// <summary>
    /// Production constructor - creates and configures the <c>HubConnection</c> (WebSockets only, automatic and stateful
    /// reconnect). The connection is NOT started here; it is started lazily by <see cref="StartConnection"/>.
    /// </summary>
    /// <param name="fullHubName">URL of the hub endpoint.</param>
    /// <param name="logger">Logger used for all client diagnostics.</param>
    protected AcSignalRClientBase(string fullHubName, AcLoggerBase logger)
    {
        Logger = logger;
        Logger.Detail(fullHubName);

        //TODO: HubConnectionBuilder constructor!!! - J.
        HubConnection = new HubConnectionBuilder()
            .WithUrl(fullHubName, HttpTransportType.WebSockets, options =>
            {
                // Backpressure buffer thresholds in bytes (0 = unlimited), raised so large payloads are not throttled.
                // NOTE(review): earlier notes quoted a 65KB default, which matches the server-side defaults; confirm the
                // client-side defaults and which direction (send/receive) each option governs in the HttpConnectionOptions docs.
                options.TransportMaxBufferSize = 30_000_000;
                options.ApplicationMaxBufferSize = 30_000_000;
                options.CloseTimeout = TimeSpan.FromSeconds(10); // default: 5 sec.
                // Negotiation can be skipped only because the transport is restricted to WebSockets above.
                options.SkipNegotiation = true;
            })
            .WithAutomaticReconnect()
            .WithStatefulReconnect()
            .WithKeepAliveInterval(TimeSpan.FromSeconds(60))
            .WithServerTimeout(TimeSpan.FromSeconds(180))
            // The MessagePack hub protocol (AddMessagePackProtocol) is not configured; payloads are sent as
            // MessagePack-encoded byte[] arguments instead (see SendMessageToServerAsync).
            .Build();

        HubConnection.Closed += HubConnection_Closed;
        _ = HubConnection.On<int, byte[], int?>(nameof(IAcSignalRHubClient.OnReceiveMessage), OnReceiveMessage);
    }

    /// <summary>
    /// Test constructor - allows testing without real HubConnection.
    /// Override virtual methods to control behavior in tests.
    /// </summary>
    protected AcSignalRClientBase(AcLoggerBase logger)
    {
        Logger = logger;
        HubConnection = null;
    }

    /// <summary>
    /// Handler for <c>HubConnection.Closed</c>. Drops every pending request: callers polling in
    /// SendMessageToServerAsync&lt;TResponse&gt; notice the missing entry and return <c>default</c>;
    /// registered response callbacks are discarded without being invoked.
    /// </summary>
    private Task HubConnection_Closed(Exception? arg)
    {
        if (_responseByRequestId.IsEmpty) Logger.DebugConditional($"Client HubConnection_Closed");
        else Logger.Warning($"Client HubConnection_Closed; {nameof(_responseByRequestId)} count: {_responseByRequestId.Count}");

        ClearPendingRequests();
        return Task.CompletedTask;
    }

    #region Connection State Methods (virtual for testing)

    /// <summary>
    /// Gets the current connection state. Override in tests.
    /// </summary>
    protected virtual HubConnectionState GetConnectionState()
        => HubConnection?.State ?? HubConnectionState.Disconnected;

    /// <summary>
    /// Checks if the connection is connected. Override in tests.
    /// </summary>
    protected virtual bool IsConnected()
        => GetConnectionState() == HubConnectionState.Connected;

    /// <summary>
    /// Starts the connection. Override in tests to avoid real connection.
    /// </summary>
    protected virtual Task StartConnectionInternal()
    {
        if (HubConnection == null) return Task.CompletedTask;
        return HubConnection.StartAsync();
    }

    /// <summary>
    /// Stops the connection. Override in tests.
    /// </summary>
    protected virtual Task StopConnectionInternal()
    {
        if (HubConnection == null) return Task.CompletedTask;
        return HubConnection.StopAsync();
    }

    /// <summary>
    /// Disposes the connection. Override in tests.
    /// </summary>
    protected virtual ValueTask DisposeConnectionInternal()
    {
        if (HubConnection == null) return ValueTask.CompletedTask;
        return HubConnection.DisposeAsync();
    }

    /// <summary>
    /// Sends a message to the server via HubConnection. Override in tests.
    /// </summary>
    protected virtual Task SendToHubAsync(int messageTag, byte[]? messageBytes, int? requestId)
    {
        if (HubConnection == null) return Task.CompletedTask;
        return HubConnection.SendAsync(nameof(IAcSignalRHubClient.OnReceiveMessage), messageTag, messageBytes, requestId);
    }

    #endregion

    #region Protected Test Helpers

    /// <summary>
    /// Gets the pending requests dictionary for testing.
    /// </summary>
    protected ConcurrentDictionary<int, SignalRRequestModel> GetPendingRequests()
        => _responseByRequestId;

    /// <summary>
    /// Clears all pending requests (the pooled models are not returned to the pool).
    /// </summary>
    protected void ClearPendingRequests()
        => _responseByRequestId.Clear();

    /// <summary>
    /// Registers a pending request for testing.
    /// </summary>
    protected void RegisterPendingRequest(int requestId, SignalRRequestModel model)
        => _responseByRequestId[requestId] = model;

    /// <summary>
    /// Simulates receiving a response for testing.
    /// </summary>
    protected void SimulateResponse(int requestId, ISignalResponseMessage<string> response)
    {
        if (_responseByRequestId.TryGetValue(requestId, out var model))
        {
            model.ResponseByRequestId = response;
            model.ResponseDateTime = DateTime.UtcNow;
        }
    }

    #endregion

    /// <summary>
    /// Starts the connection if it is Disconnected, then waits (via <c>TaskHelper.WaitToAsync</c>, up to
    /// <see cref="ConnectionTimeout"/>) until it is Connected. The wait result is not checked here; callers check
    /// <see cref="IsConnected"/> afterwards. Exceptions from starting the connection propagate.
    /// </summary>
    public async Task StartConnection()
    {
        if (GetConnectionState() == HubConnectionState.Disconnected)
        {
            await _connectionStartGate.WaitAsync();
            try
            {
                // Re-check under the gate: a concurrent caller may already have started the connection.
                if (GetConnectionState() == HubConnectionState.Disconnected)
                    await StartConnectionInternal();
            }
            finally
            {
                _connectionStartGate.Release();
            }
        }

        if (!IsConnected())
            await TaskHelper.WaitToAsync(IsConnected, ConnectionTimeout, 10, 25);
    }

    /// <summary>
    /// Stops and then disposes the connection.
    /// NOTE(review): the HubConnection is disposed here, so calling <see cref="StartConnection"/> again on the same
    /// instance is presumably unsupported - confirm before reusing a stopped client.
    /// </summary>
    public async Task StopConnection()
    {
        await StopConnectionInternal();
        await DisposeConnectionInternal();
    }

    /// <summary>
    /// Sends a tag without payload. A fresh request id is sent, but no pending request is registered, so a reply from
    /// the server is routed to <see cref="MessageReceived"/> like a broadcast.
    /// </summary>
    public virtual Task SendMessageToServerAsync(int messageTag)
        => SendMessageToServerAsync(messageTag, null, GetNextRequestId());

    /// <summary>
    /// Ensures the connection is up, MessagePack-serializes <paramref name="message"/> (if any) and sends it to the hub.
    /// If the connection is not Connected after <see cref="StartConnection"/>, an error is logged and nothing is sent.
    /// </summary>
    /// <param name="messageTag">Tag identifying the operation on the server.</param>
    /// <param name="message">Optional payload.</param>
    /// <param name="requestId">Correlation id echoed back by the server with the response, if any.</param>
    public virtual async Task SendMessageToServerAsync(int messageTag, ISignalRMessage? message, int? requestId)
    {
        Logger.DebugConditional($"Client SendMessageToServerAsync sending; {nameof(requestId)}: {requestId}; ConnectionState: {GetConnectionState()}; {ConstHelper.NameByValue(TagsName, messageTag)}");

        await StartConnection();

        if (!IsConnected())
        {
            Logger.Error($"Client SendMessageToServerAsync error! ConnectionState: {GetConnectionState()};");
            return;
        }

        // Serialize only once we know the message can actually be sent.
        var msgp = message?.ToMessagePack(ContractlessStandardResolver.Options);
        await SendToHubAsync(messageTag, msgp, requestId);
    }

    #region CRUD

    /// <summary>Posts a single parameter (wrapped in an <c>IdMessage</c>) and awaits the typed response.</summary>
    public virtual Task<TResponseData?> PostAsync<TResponseData>(int messageTag, object parameter) //where TResponseData : class
        => SendMessageToServerAsync<TResponseData>(messageTag, new SignalPostJsonDataMessage<IdMessage>(new IdMessage(parameter)), GetNextRequestId());

    /// <summary>Posts several parameters (wrapped in an <c>IdMessage</c>) and awaits the typed response.</summary>
    public virtual Task<TResponseData?> PostAsync<TResponseData>(int messageTag, object[] parameters) //where TResponseData : class
        => SendMessageToServerAsync<TResponseData>(messageTag, new SignalPostJsonDataMessage<IdMessage>(new IdMessage(parameters)), GetNextRequestId());

    /// <summary>Requests an item by id and awaits the typed response.</summary>
    public virtual Task<TResponseData?> GetByIdAsync<TResponseData>(int messageTag, object id) //where TResponseData : class
        => PostAsync<TResponseData?>(messageTag, id);

    /// <summary>Requests an item by id; the response is delivered to <paramref name="responseCallback"/>.</summary>
    public virtual Task GetByIdAsync<TResponseData>(int messageTag, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback, object id)
        => SendMessageToServerAsync(messageTag, new SignalPostJsonDataMessage<IdMessage>(new IdMessage(id)), responseCallback);

    /// <summary>Requests an item by a composite id and awaits the typed response.</summary>
    public virtual Task<TResponseData?> GetByIdAsync<TResponseData>(int messageTag, object[] ids) //where TResponseData : class
        => PostAsync<TResponseData?>(messageTag, ids);

    /// <summary>Requests an item by a composite id; the response is delivered to <paramref name="responseCallback"/>.</summary>
    public virtual Task GetByIdAsync<TResponseData>(int messageTag, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback, object[] ids)
        => SendMessageToServerAsync(messageTag, new SignalPostJsonDataMessage<IdMessage>(new IdMessage(ids)), responseCallback);

    /// <summary>Requests all items for the tag and awaits the typed response.</summary>
    public virtual Task<TResponseData?> GetAllAsync<TResponseData>(int messageTag) //where TResponseData : class
        => SendMessageToServerAsync<TResponseData>(messageTag);

    /// <summary>Requests all items for the tag; the response is delivered to <paramref name="responseCallback"/>.</summary>
    public virtual Task GetAllAsync<TResponseData>(int messageTag, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback)
        => SendMessageToServerAsync(messageTag, null, responseCallback);

    /// <summary>
    /// Requests all items for the tag with optional context parameters (no payload when null/empty);
    /// the response is delivered to <paramref name="responseCallback"/>.
    /// </summary>
    public virtual Task GetAllAsync<TResponseData>(int messageTag, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback, object[]? contextParams)
        => SendMessageToServerAsync(messageTag, (contextParams == null || contextParams.Length == 0 ? null : new SignalPostJsonDataMessage<IdMessage>(new IdMessage(contextParams))), responseCallback);

    /// <summary>Requests all items for the tag with optional context parameters (no payload when null/empty) and awaits the typed response.</summary>
    public virtual Task<TResponseData?> GetAllAsync<TResponseData>(int messageTag, object[]? contextParams) //where TResponseData : class
        => SendMessageToServerAsync<TResponseData>(messageTag, contextParams == null || contextParams.Length == 0 ? null : new SignalPostJsonDataMessage<IdMessage>(new IdMessage(contextParams)), GetNextRequestId());

    /// <summary>Posts <paramref name="postData"/> and awaits a response of the same type.</summary>
    public virtual Task<TPostData?> PostDataAsync<TPostData>(int messageTag, TPostData postData) where TPostData : class
        => SendMessageToServerAsync<TPostData>(messageTag, CreatePostMessage(postData), GetNextRequestId());

    /// <summary>Posts <paramref name="postData"/> and awaits a response of type <typeparamref name="TResponseData"/>.</summary>
    public virtual Task<TResponseData?> PostDataAsync<TPostData, TResponseData>(int messageTag, TPostData postData) //where TPostData : class where TResponseData : class
        => SendMessageToServerAsync<TResponseData>(messageTag, CreatePostMessage(postData), GetNextRequestId());

    /// <summary>Posts <paramref name="postData"/>; the response (same type) is delivered to <paramref name="responseCallback"/>.</summary>
    public virtual Task PostDataAsync<TPostData>(int messageTag, TPostData postData, Func<ISignalResponseMessage<TPostData?>, Task> responseCallback) //where TPostData : class
        => SendMessageToServerAsync(messageTag, CreatePostMessage(postData), responseCallback);

    /// <summary>Posts <paramref name="postData"/>; the response is delivered to <paramref name="responseCallback"/>.</summary>
    public virtual Task PostDataAsync<TPostData, TResponseData>(int messageTag, TPostData postData, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback) //where TPostData : class where TResponseData : class
        => SendMessageToServerAsync(messageTag, CreatePostMessage(postData), responseCallback);

    /// <summary>
    /// Creates the appropriate message wrapper for the post data.
    /// Primitives, strings, enums, and value types are wrapped in IdMessage.
    /// Complex objects are sent directly in SignalPostJsonDataMessage.
    /// The decision uses the static type <typeparamref name="TPostData"/>, not the runtime type of the value.
    /// </summary>
    private static ISignalRMessage CreatePostMessage<TPostData>(TPostData postData)
    {
        var type = typeof(TPostData);

        // Primitives, strings, enums, and value types should use IdMessage format
        if (IsPrimitiveOrStringOrEnum(type))
        {
            return new SignalPostJsonDataMessage<IdMessage>(new IdMessage(postData!));
        }

        // Complex objects use direct serialization
        return new SignalPostJsonDataMessage<TPostData>(postData);
    }

    /// <summary>
    /// Determines if a type should use IdMessage format (primitives, strings, enums, value types).
    /// Must match the logic in AcWebSignalRHubBase.IsPrimitiveOrStringOrEnum.
    /// NOTE: Arrays and collections are NOT included here - they are complex objects for PostDataAsync.
    /// The enum and DateTime checks are already covered by IsValueType; they are kept so this stays a literal mirror
    /// of the server-side check.
    /// </summary>
    private static bool IsPrimitiveOrStringOrEnum(Type type)
    {
        return type == typeof(string) ||
               type.IsEnum ||
               type.IsValueType ||
               type == typeof(DateTime);
    }

    /// <summary>
    /// Fetches all items for <paramref name="messageTag"/> into <paramref name="intoList"/>.
    /// The list is cleared when the response arrives, even if the response is an error; items are added only on success.
    /// <paramref name="callback"/> is invoked after the list has been updated.
    /// </summary>
    public Task GetAllIntoAsync<TResponseItem>(List<TResponseItem> intoList, int messageTag, object[]? contextParams = null, Action? callback = null) where TResponseItem : IEntityGuid
    {
        return GetAllAsync<List<TResponseItem>>(messageTag, response =>
        {
            var logText = $"GetAllIntoAsync<{typeof(TResponseItem).Name}>(); status: {response.Status}; dataCount: {response.ResponseData?.Count}; {ConstHelper.NameByValue(TagsName, messageTag)};";

            intoList.Clear();

            if (response.Status == SignalResponseStatus.Success && response.ResponseData != null)
            {
                Logger.Debug(logText);
                intoList.AddRange(response.ResponseData);
            }
            else Logger.Error(logText);

            callback?.Invoke();
            return Task.CompletedTask;
        }, contextParams);
    }

    #endregion CRUD

    /// <summary>Sends a tag without payload and awaits the typed response.</summary>
    public virtual Task<TResponse?> SendMessageToServerAsync<TResponse>(int messageTag) //where TResponse : class
        => SendMessageToServerAsync<TResponse>(messageTag, null, GetNextRequestId());

    /// <summary>Sends <paramref name="message"/> and awaits the typed response.</summary>
    public virtual Task<TResponse?> SendMessageToServerAsync<TResponse>(int messageTag, ISignalRMessage? message) //where TResponse : class
        => SendMessageToServerAsync<TResponse>(messageTag, message, GetNextRequestId());

    /// <summary>
    /// Registers a pending request under <paramref name="requestId"/>, sends the message and polls until the response
    /// arrives, the request is dropped (e.g. connection closed) or <see cref="TransportSendTimeout"/> elapses.
    /// </summary>
    /// <returns>
    /// The deserialized response data; <c>default</c> on timeout, on an Error-status response, on a dropped request,
    /// or when receiving/deserializing fails (those failures are logged, not thrown).
    /// </returns>
    /// <remarks>
    /// Exceptions thrown while sending propagate to the caller; the pending entry is released first.
    /// NOTE(review): if the connection is down, the send only logs an error and this method still waits for the full
    /// timeout.
    /// </remarks>
    protected virtual async Task<TResponse?> SendMessageToServerAsync<TResponse>(int messageTag, ISignalRMessage? message, int requestId) //where TResponse : class
    {
        Logger.DebugConditional($"Client SendMessageToServerAsync<TResult>; {nameof(requestId)}: {requestId}; {ConstHelper.NameByValue(TagsName, messageTag)}");

        var startTime = DateTime.UtcNow;
        _responseByRequestId[requestId] = SignalRRequestModelPool.Get();

        try
        {
            await SendMessageToServerAsync(messageTag, message, requestId);
        }
        catch
        {
            // Nobody will ever answer this request id, so do not leave the entry (and its pooled model) behind.
            ReleasePendingRequest(requestId);
            throw;
        }

        try
        {
            // Stop polling when the response arrives OR when the entry disappeared (e.g. HubConnection_Closed cleared all
            // pending requests). Using TryGetValue instead of the indexer keeps the predicate from throwing in that case.
            var finished = await TaskHelper.WaitToAsync(
                () => !_responseByRequestId.TryGetValue(requestId, out var pending) || pending.ResponseByRequestId != null,
                TransportSendTimeout, MsDelay, MsFirstDelay);

            if (finished && _responseByRequestId.TryRemove(requestId, out var completed))
            {
                // Read everything needed from the pooled model before handing it back to the pool.
                var requestTime = completed.RequestDateTime;
                var response = completed.ResponseByRequestId;
                SignalRRequestModelPool.Return(completed);

                if (response is ISignalResponseMessage responseMessage)
                    return ReadResponse<TResponse>(responseMessage, messageTag, requestId, requestTime);

                Logger.Error($"Client SendMessageToServerAsync<TResponseData> unexpected response type: {response?.GetType().Name}; requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
                return default;
            }

            if (finished)
                Logger.Error($"Client request removed before a response arrived (connection closed?); ConnectionState: {GetConnectionState()}; requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
            else
                Logger.Error($"Client timeout after: {(DateTime.UtcNow - startTime).TotalSeconds} sec! ConnectionState: {GetConnectionState()}; requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
        }
        catch (Exception ex)
        {
            Logger.Error($"Client SendMessageToServerAsync; requestId: {requestId}; ConnectionState: {GetConnectionState()}; {ex.Message}; {ConstHelper.NameByValue(TagsName, messageTag)}", ex);
        }

        ReleasePendingRequest(requestId);
        return default;
    }

    /// <summary>
    /// Converts a correlated response into the caller's result type and logs the outcome.
    /// Error-status responses are logged and reported to the caller as <c>default</c>, not as an exception.
    /// Deserialization exceptions propagate to the calling method's catch block.
    /// </summary>
    /// <param name="requestTime">The pooled model's <c>RequestDateTime</c>; assumed UTC (OnReceiveMessage compares it with <c>DateTime.UtcNow</c>).</param>
    private TResponse? ReadResponse<TResponse>(ISignalResponseMessage responseMessage, int messageTag, int requestId, DateTime requestTime)
    {
        if (responseMessage.Status == SignalResponseStatus.Error)
        {
            Logger.Error($"Client SendMessageToServerAsync<TResponseData> response error; await; tag: {messageTag}; Status: {responseMessage.Status}; ConnectionState: {GetConnectionState()}; requestId: {requestId}");
            return default;
        }

        var responseData = DeserializeResponseData<TResponse>(responseMessage);

        if (responseData == null && responseMessage.Status == SignalResponseStatus.Success)
        {
            // Null response is valid for Success status
            Logger.Info($"Client received null response. Total: {(DateTime.UtcNow.Subtract(requestTime)).TotalMilliseconds} ms! requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
            return default;
        }

        var serializerType = responseMessage is SignalResponseBinaryMessage ? "Binary" : "JSON";
        Logger.Info($"Client deserialized response ({serializerType}). Total: {(DateTime.UtcNow.Subtract(requestTime)).TotalMilliseconds} ms! requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
        return responseData;
    }

    /// <summary>
    /// Deserializes response data from either JSON or Binary format.
    /// Automatically detects the format based on the response message type.
    /// Returns <c>default</c> when the message carries no data.
    /// </summary>
    private static TResponse? DeserializeResponseData<TResponse>(ISignalResponseMessage responseMessage)
    {
        return responseMessage switch
        {
            SignalResponseBinaryMessage binaryMsg when binaryMsg.ResponseData != null
                => binaryMsg.ResponseData.BinaryTo<TResponse>(),
            SignalResponseJsonMessage jsonMsg when !string.IsNullOrEmpty(jsonMsg.ResponseData)
                => jsonMsg.ResponseData.JsonTo<TResponse>(),
            ISignalResponseMessage<string> stringMsg when !string.IsNullOrEmpty(stringMsg.ResponseData)
                => stringMsg.ResponseData.JsonTo<TResponse>(),
            _ => default
        };
    }

    /// <summary>Sends a tag without payload; the response is delivered to <paramref name="responseCallback"/>.</summary>
    public virtual Task SendMessageToServerAsync<TResponseData>(int messageTag, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback)
        => SendMessageToServerAsync(messageTag, null, responseCallback);

    /// <summary>
    /// Registers a response callback under a fresh request id and sends the message. When the response arrives,
    /// <see cref="OnReceiveMessage"/> invokes <paramref name="responseCallback"/> with the deserialized data.
    /// </summary>
    /// <remarks>
    /// Exceptions thrown while sending propagate to the caller; the pending callback is released first.
    /// </remarks>
    public virtual async Task SendMessageToServerAsync<TResponseData>(int messageTag, ISignalRMessage? message, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback)
    {
        if (messageTag == 0) Logger.Error($"SendMessageToServerAsync; messageTag == 0");

        var requestId = GetNextRequestId();
        _responseByRequestId[requestId] = SignalRRequestModelPool.Get(new Action<ISignalResponseMessage>(responseMessage =>
            InvokeResponseCallback(responseMessage, responseCallback, messageTag, requestId)));

        try
        {
            await SendMessageToServerAsync(messageTag, message, requestId);
        }
        catch
        {
            ReleasePendingRequest(requestId);
            throw;
        }
    }

    /// <summary>
    /// Deserializes <paramref name="responseMessage"/> (only on Success) and passes the result to
    /// <paramref name="responseCallback"/>. A deserialization failure is logged and reported to the callback with
    /// Error status, so the caller is still notified. The callback's returned task is not awaited.
    /// </summary>
    private void InvokeResponseCallback<TResponseData>(ISignalResponseMessage responseMessage, Func<ISignalResponseMessage<TResponseData?>, Task> responseCallback, int messageTag, int requestId)
    {
        var status = responseMessage.Status;
        TResponseData? responseData = default;

        if (status == SignalResponseStatus.Success)
        {
            try
            {
                responseData = DeserializeResponseData<TResponseData>(responseMessage);
            }
            catch (Exception ex)
            {
                Logger.Error($"Client SendMessageToServerAsync<TResponseData> deserialization error; callback; requestId: {requestId}; {ConstHelper.NameByValue(TagsName, messageTag)}", ex);
                status = SignalResponseStatus.Error;
            }
        }
        else Logger.Error($"Client SendMessageToServerAsync<TResponseData> response error; callback; Status: {responseMessage.Status}; ConnectionState: {GetConnectionState()}; requestId: {requestId}; {ConstHelper.NameByValue(TagsName, messageTag)}");

        _ = responseCallback(new SignalResponseMessage<TResponseData?>(messageTag, status, responseData));
    }

    /// <summary>
    /// Removes the pending request registered under <paramref name="requestId"/> (if still present) and returns its
    /// model to the pool.
    /// </summary>
    private void ReleasePendingRequest(int requestId)
    {
        if (_responseByRequestId.TryRemove(requestId, out var model))
            SignalRRequestModelPool.Return(model);
    }

    /// <summary>
    /// Gets the next unique request ID.
    /// </summary>
    protected virtual int GetNextRequestId() => AcDomain.NextUniqueInt32;

    /// <summary>
    /// Entry point for every server-to-client message (registered in the production constructor).
    /// If <paramref name="requestId"/> matches a pending request, the response is stored for the polling caller or
    /// handed to the registered callback. Otherwise the message is a broadcast and is forwarded to
    /// <see cref="MessageReceived"/> (fire-and-forget).
    /// </summary>
    /// <remarks>
    /// Exceptions are logged, the pending request (if any) is released, and the exception is rethrown.
    /// </remarks>
    public virtual Task OnReceiveMessage(int messageTag, byte[] messageBytes, int? requestId)
    {
        // Defensive: a null payload would otherwise throw a NullReferenceException at .Length below.
        messageBytes ??= Array.Empty<byte>();

        var logText = $"Client OnReceiveMessage; {nameof(requestId)}: {requestId}; {ConstHelper.NameByValue(TagsName, messageTag)}";
        if (messageBytes.Length == 0) Logger.Warning($"message.Length == 0! {logText}");

        try
        {
            if (requestId is { } reqId && _responseByRequestId.TryGetValue(reqId, out var pending))
            {
                pending.ResponseDateTime = DateTime.UtcNow;
                Logger.Debug($"[{pending.ResponseDateTime.Subtract(pending.RequestDateTime).TotalMilliseconds:N0}ms][{(messageBytes.Length / 1024)}kb]{logText}");

                var responseMessage = DeserializeResponseMessage(messageBytes);

                // Look the request up again: the waiting sender may have timed out (and recycled the pooled model)
                // while the message was being deserialized.
                if (!_responseByRequestId.TryGetValue(reqId, out pending))
                {
                    Logger.Warning($"Client OnReceiveMessage; late response, request is no longer pending; {logText}");
                    return Task.CompletedTask;
                }

                switch (pending.ResponseByRequestId)
                {
                    case null:
                        // A caller is polling for this response in SendMessageToServerAsync<TResponse>; it removes the entry.
                        pending.ResponseByRequestId = responseMessage;
                        return Task.CompletedTask;

                    case Action<ISignalResponseMessage> messageCallback:
                        ReleasePendingRequest(reqId);
                        messageCallback.Invoke(responseMessage);
                        return Task.CompletedTask;

                    // Legacy support for string-based callbacks
                    case Action<ISignalResponseMessage<string>> stringCallback when responseMessage is SignalResponseJsonMessage jsonMsg:
                        ReleasePendingRequest(reqId);
                        stringCallback.Invoke(jsonMsg);
                        return Task.CompletedTask;

                    default:
                        Logger.Error($"Client OnReceiveMessage switch; unknown message type: {pending.ResponseByRequestId?.GetType().Name}; {ConstHelper.NameByValue(TagsName, messageTag)}");
                        ReleasePendingRequest(reqId);
                        // Faulty request/response case - MessageReceived must not be called.
                        return Task.CompletedTask;
                }
            }

            // Only broadcast/notification messages (no matching pending request) are forwarded to MessageReceived.
            Logger.Info(logText);
            MessageReceived(messageTag, messageBytes).Forget();
        }
        catch (Exception ex)
        {
            if (requestId.HasValue) ReleasePendingRequest(requestId.Value);

            Logger.Error($"Client OnReceiveMessage; ConnectionState: {GetConnectionState()}; requestId: {requestId}; {ex.Message}; {ConstHelper.NameByValue(TagsName, messageTag)}", ex);
            throw;
        }

        return Task.CompletedTask;
    }

    /// <summary>
    /// Deserializes a MessagePack response to the appropriate message type (JSON or Binary).
    /// First tries to deserialize as Binary (accepted only when its payload is detected as binary-serialized data),
    /// then falls back to JSON. Exceptions from the JSON fallback propagate.
    /// </summary>
    protected virtual ISignalResponseMessage DeserializeResponseMessage(byte[] messageBytes)
    {
        // Try Binary format first (SignalResponseBinaryMessage)
        try
        {
            var binaryMsg = messageBytes.MessagePackTo<SignalResponseBinaryMessage>(ContractlessStandardResolver.Options);
            if (binaryMsg.ResponseData != null && binaryMsg.ResponseData.Length > 0)
            {
                // Verify it's actually binary data by checking the format
                if (DetectSerializerTypeFromBytes(binaryMsg.ResponseData) == AcSerializerType.Binary)
                {
                    return binaryMsg;
                }
            }
        }
        catch
        {
            // Deliberately ignored: a failed Binary attempt simply means the message is JSON-shaped; try JSON below.
        }

        // Fall back to JSON format
        return messageBytes.MessagePackTo<SignalResponseJsonMessage>(ContractlessStandardResolver.Options);
    }
}
}