// AyCode.Core/AyCode.Services/SignalRs/AcSignalRClientBase.cs

using System.Collections.Concurrent;
using AyCode.Core;
using AyCode.Core.Extensions;
using AyCode.Core.Helpers;
using AyCode.Core.Loggers;
using AyCode.Core.Serializers;
using AyCode.Core.Serializers.Binaries;
using AyCode.Interfaces.Entities;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace AyCode.Services.SignalRs
{
    /// <summary>
    /// Base class for SignalR hub clients. Sends tagged messages to the hub and correlates responses by request id.
    /// A request either waits for its response (the typed <c>SendMessageToServerAsync&lt;TResponse&gt;</c> overloads)
    /// or registers a callback that <see cref="OnReceiveMessage"/> invokes. Incoming messages that match no pending
    /// request are forwarded to <see cref="MessageReceived"/>.
    /// </summary>
    public abstract class AcSignalRClientBase : IAcSignalRHubClient
    {
        /// <summary>
        /// Pending requests keyed by request id. The model's <c>ResponseByRequestId</c> is either <c>null</c>
        /// (a typed request still waiting), the received <see cref="SignalResponseDataMessage"/>, or a response callback.
        /// </summary>
        private readonly ConcurrentDictionary<int, SignalRRequestModel> _responseByRequestId = new();

        /// <summary>The hub connection; <c>null</c> when the connection-less constructor was used.</summary>
        protected readonly HubConnection? HubConnection;

        protected readonly AcLoggerBase Logger;

        /// <summary>
        /// Enable diagnostic logging for binary serialization debugging.
        /// </summary>
        public static bool EnableBinaryDiagnostics { get; set; }

        /// <summary>
        /// Handles a server message that does not belong to a pending request (invoked fire-and-forget).
        /// </summary>
        protected abstract Task MessageReceived(int messageTag, SignalParams signalParams, object data);

        /// <summary>Delay (ms) passed to <c>TaskHelper.WaitToAsync</c> while polling for a response — presumably the poll interval.</summary>
        public int MsDelay = 25;

        /// <summary>First delay (ms) passed to <c>TaskHelper.WaitToAsync</c> while polling for a response.</summary>
        public int MsFirstDelay = 50;

        /// <summary>Timeout (ms) passed to <c>TaskHelper.WaitToAsync</c> when waiting for the connection to become connected.</summary>
        public int ConnectionTimeout = 10000;

        /// <summary>Timeout (ms) passed to <c>TaskHelper.WaitToAsync</c> when waiting for the response to a typed request.</summary>
        public int TransportSendTimeout = 60000;

        /// <summary>Constant-holder name used by <c>ConstHelper.NameByValue</c> to render message tags in log lines.</summary>
        private const string TagsName = "SignalRTags";

        /// <summary>
        /// Primary constructor. The <paramref name="hubBuilder"/> is expected to be fully configured
        /// (URL, transport, reconnect, keep-alive, protocol) — typically via a transient DI registration
        /// in the consuming project's <c>Program.cs</c>. This class only calls <c>Build()</c> and wires
        /// the dispatch callback; no connection parameters are hard-coded here.
        /// </summary>
        protected AcSignalRClientBase(IHubConnectionBuilder hubBuilder, AcLoggerBase logger)
        {
            Logger = logger;

            HubConnection = hubBuilder.Build();
            HubConnection.Closed += HubConnection_Closed;

            _ = HubConnection.On<int, int?, SignalParams, object>(nameof(IAcSignalRHubClient.OnReceiveMessage), OnReceiveMessage);
        }

        /// <summary>
        /// Connection-less constructor — used by derived classes that manage their own connection lifecycle
        /// or run in test / offline scenarios where <see cref="HubConnection"/> stays <c>null</c>.
        /// </summary>
        protected AcSignalRClientBase(AcLoggerBase logger)
        {
            Logger = logger;
            HubConnection = null;
        }

        /// <summary>
        /// Closed-event handler: logs (as a warning when requests were still pending) and drops all pending requests.
        /// </summary>
        private Task HubConnection_Closed(Exception? arg)
        {
            if (_responseByRequestId.IsEmpty) Logger.DebugConditional("Client HubConnection_Closed");
            else Logger.Warning($"Client HubConnection_Closed; {nameof(_responseByRequestId)} count: {_responseByRequestId.Count}");

            ClearPendingRequests();
            return Task.CompletedTask;
        }

        #region Connection State Methods

        protected virtual HubConnectionState GetConnectionState()
            => HubConnection?.State ?? HubConnectionState.Disconnected;

        protected virtual bool IsConnected()
            => GetConnectionState() == HubConnectionState.Connected;

        protected virtual Task StartConnectionInternal()
            => HubConnection?.StartAsync() ?? Task.CompletedTask;

        protected virtual Task StopConnectionInternal()
            => HubConnection?.StopAsync() ?? Task.CompletedTask;

        protected virtual ValueTask DisposeConnectionInternal()
            => HubConnection?.DisposeAsync() ?? ValueTask.CompletedTask;

        protected virtual Task SendToHubAsync(int messageTag, int? requestId, SignalParams signalParams, object? data)
            => HubConnection?.SendAsync(nameof(IAcSignalRHubClient.OnReceiveMessage), messageTag, requestId, signalParams, data) ?? Task.CompletedTask;

        #endregion

        #region Protected Test Helpers

        protected ConcurrentDictionary<int, SignalRRequestModel> GetPendingRequests() => _responseByRequestId;

        protected void ClearPendingRequests() => _responseByRequestId.Clear();

        protected void RegisterPendingRequest(int requestId, SignalRRequestModel model) => _responseByRequestId[requestId] = model;

        #endregion

        /// <summary>
        /// Removes the pending request with <paramref name="requestId"/> (if still present) and returns its model to the pool.
        /// </summary>
        private void RemovePendingRequest(int requestId)
        {
            if (_responseByRequestId.TryRemove(requestId, out var removedModel))
                SignalRRequestModelPool.Return(removedModel);
        }

        /// <summary>
        /// Starts the connection when it is disconnected, then waits (up to <see cref="ConnectionTimeout"/>) for it to become connected.
        /// </summary>
        public async Task StartConnection()
        {
            if (GetConnectionState() == HubConnectionState.Disconnected)
                await StartConnectionInternal();

            if (!IsConnected())
                await TaskHelper.WaitToAsync(IsConnected, ConnectionTimeout, 10, 25);
        }

        /// <summary>Stops and then disposes the connection.</summary>
        public async Task StopConnection()
        {
            await StopConnectionInternal();
            await DisposeConnectionInternal();
        }

        public virtual Task SendMessageToServerAsync(int messageTag)
            => SendMessageToServerAsync(messageTag, (object[]?)null, GetNextRequestId());

        public virtual Task SendMessageToServerAsync(int messageTag, object[]? parameters, int? requestId)
            => SendCoreAsync(messageTag, parameters, requestId, new SignalParams { Status = SignalResponseStatus.Success });

        /// <summary>
        /// Core send: takes a pre-built SignalParams (caller controls IsRawBytesData etc.).
        /// Ensures the connection is started; when it still is not connected the message is NOT sent — an error is logged and the method returns.
        /// </summary>
        protected async Task SendCoreAsync(int messageTag, object[]? parameters, int? requestId, SignalParams signalParams)
        {
            Logger.DebugConditional($"Client SendMessageToServerAsync sending; {nameof(requestId)}: {requestId}; ConnectionState: {GetConnectionState()}; {ConstHelper.NameByValue(TagsName, messageTag)}");

            await StartConnection();

            if (!IsConnected())
            {
                Logger.Error($"Client SendMessageToServerAsync error! ConnectionState: {GetConnectionState()};");
                return;
            }

            if (parameters is { Length: > 0 })
                signalParams.SetParameterValues(parameters);

            await SendToHubAsync(messageTag, requestId, signalParams, null);
        }

        #region CRUD

        public virtual Task<TResponseData?> PostAsync<TResponseData>(int messageTag, object parameter)
            => SendMessageToServerAsync<TResponseData>(messageTag, [parameter], GetNextRequestId());

        public virtual Task<TResponseData?> PostAsync<TResponseData>(int messageTag, object[] parameters)
            => SendMessageToServerAsync<TResponseData>(messageTag, parameters, GetNextRequestId());

        public virtual Task<TResponseData?> GetByIdAsync<TResponseData>(int messageTag, object id)
            => PostAsync<TResponseData?>(messageTag, id);

        public virtual Task<TResponseData?> GetByIdAsync<TResponseData>(int messageTag, object[] ids)
            => PostAsync<TResponseData?>(messageTag, ids);

        /// <summary>
        /// Gets data by ID with async callback response. Callback is second parameter.
        /// </summary>
        public virtual Task GetByIdAsync<TResponseData>(int messageTag, Func<SignalResponseDataMessage, Task> responseCallback, object id)
            => SendMessageToServerAsync(messageTag, [id], responseCallback);

        /// <summary>
        /// Gets data by IDs with async callback response. Callback is second parameter.
        /// </summary>
        public virtual Task GetByIdAsync<TResponseData>(int messageTag, Func<SignalResponseDataMessage, Task> responseCallback, object[] ids)
            => SendMessageToServerAsync(messageTag, ids, responseCallback);

        public virtual Task<TResponseData?> GetAllAsync<TResponseData>(int messageTag)
            => SendMessageToServerAsync<TResponseData>(messageTag);

        public virtual Task<TResponseData?> GetAllAsync<TResponseData>(int messageTag, object[]? contextParams)
            => SendMessageToServerAsync<TResponseData>(messageTag, contextParams is { Length: > 0 } ? contextParams : null, GetNextRequestId());

        /// <summary>
        /// Gets all data with async callback response. Callback is second parameter.
        /// </summary>
        public virtual Task GetAllAsync<TResponseData>(int messageTag, Func<SignalResponseDataMessage, Task> responseCallback)
            => SendMessageToServerAsync(messageTag, null, responseCallback);

        /// <summary>
        /// Gets all data with context params and async callback response.
        /// </summary>
        public virtual Task GetAllAsync<TResponseData>(int messageTag, Func<SignalResponseDataMessage, Task> responseCallback, object[]? contextParams)
            => SendMessageToServerAsync(messageTag, contextParams is { Length: > 0 } ? contextParams : null, responseCallback);

        /// <summary>
        /// Streams all items for <paramref name="messageTag"/> from the hub's <c>OnReceiveStreamMessage</c> stream.
        /// Yields nothing (after logging an error) when no connection is available.
        /// </summary>
        /// <exception cref="InvalidOperationException">A streamed message reports <see cref="SignalResponseStatus.Error"/>.</exception>
        public virtual async IAsyncEnumerable<TResponseData?> StreamAllAsync<TResponseData>(int messageTag, object[]? contextParams = null, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            await StartConnection();

            if (HubConnection == null || !IsConnected())
            {
                Logger.Error($"Client StreamAllAsync error! ConnectionState: {GetConnectionState()};");
                yield break;
            }

            // Each context parameter is serialized separately, then the array of blobs is serialized as one payload.
            var msgBytes = contextParams is { Length: > 0 }
                ? SignalRSerializationHelper.SerializeToBinary(
                    contextParams.Select(p => SignalRSerializationHelper.SerializeToBinary(p)).ToArray())
                : null;

            await foreach (var item in ReadResponseStreamAsync<TResponseData>(HubConnection, nameof(StreamAllAsync), messageTag, msgBytes, cancellationToken))
                yield return item;
        }

        public virtual Task<TPostData?> PostDataAsync<TPostData>(int messageTag, TPostData postData) where TPostData : class
            => SendMessageToServerAsync<TPostData>(messageTag, [postData!], GetNextRequestId());

        public virtual Task<TResponseData?> PostDataAsync<TPostData, TResponseData>(int messageTag, TPostData postData)
            => SendMessageToServerAsync<TResponseData>(messageTag, [postData!], GetNextRequestId());

        /// <summary>
        /// Posts data with async callback response.
        /// </summary>
        public virtual Task PostDataAsync<TPostData>(int messageTag, TPostData postData, Func<SignalResponseDataMessage, Task> responseCallback)
            => SendMessageToServerAsync(messageTag, [postData!], responseCallback);

        /// <summary>
        /// Posts data with typed async callback response.
        /// </summary>
        public virtual Task PostDataAsync<TPostData, TResponseData>(int messageTag, TPostData postData, Func<SignalResponseDataMessage, Task> responseCallback)
            => SendMessageToServerAsync(messageTag, [postData!], responseCallback);

        /// <summary>
        /// Posts data and invokes callback with response. Fire-and-forget friendly for background saves.
        /// If the send itself throws, the callback registration is removed before the exception propagates.
        /// </summary>
        public virtual Task PostDataAsync<TPostData>(int messageTag, TPostData postData, Action<SignalResponseDataMessage> responseCallback)
        {
            var requestId = GetNextRequestId();
            return SendWithPendingRequestAsync(messageTag, [postData!], requestId, SignalRRequestModelPool.Get(responseCallback));
        }

        /// <summary>
        /// Streams the responses produced for <paramref name="postData"/> from the hub's <c>OnReceiveStreamMessage</c> stream.
        /// Yields nothing (after logging an error) when no connection is available.
        /// </summary>
        /// <exception cref="InvalidOperationException">A streamed message reports <see cref="SignalResponseStatus.Error"/>.</exception>
        public virtual async IAsyncEnumerable<TResponseData?> StreamPostDataAsync<TPostData, TResponseData>(int messageTag, TPostData postData, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            await StartConnection();

            if (HubConnection == null || !IsConnected())
            {
                Logger.Error($"Client StreamPostDataAsync error! ConnectionState: {GetConnectionState()};");
                yield break;
            }

            // Same payload shape as StreamAllAsync: an array of individually serialized parameters.
            var msgBytes = SignalRSerializationHelper.SerializeToBinary(
                new[] { SignalRSerializationHelper.SerializeToBinary(postData!) });

            await foreach (var item in ReadResponseStreamAsync<TResponseData>(HubConnection, nameof(StreamPostDataAsync), messageTag, msgBytes, cancellationToken))
                yield return item;
        }

        /// <summary>
        /// Shared read loop of the streaming methods. Opens the hub's <c>OnReceiveStreamMessage</c> stream and yields each
        /// chunk either as the raw <c>byte[]</c> (when <typeparamref name="TResponseData"/> is <c>byte[]</c>) or as the
        /// payload of the deserialized <see cref="SignalResponseDataMessage"/>. Null chunks and chunks that deserialize
        /// to <c>null</c> are skipped.
        /// </summary>
        /// <param name="callerName">Public method name, used in the error log / exception text.</param>
        /// <exception cref="InvalidOperationException">A streamed message reports <see cref="SignalResponseStatus.Error"/>.</exception>
        private async IAsyncEnumerable<TResponseData?> ReadResponseStreamAsync<TResponseData>(HubConnection hubConnection, string callerName, int messageTag, object? streamArgument, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
        {
            var stream = hubConnection.StreamAsync<byte[]>(
                "OnReceiveStreamMessage",
                messageTag,
                streamArgument,
                cancellationToken);

            await foreach (var bytes in stream.WithCancellation(cancellationToken))
            {
                if (bytes == null) continue;

                if (typeof(TResponseData) == typeof(byte[]))
                {
                    yield return (TResponseData)(object)bytes;
                    continue;
                }

                var responseMessage = SignalRSerializationHelper.DeserializeFromBinary<SignalResponseDataMessage>(bytes);
                if (responseMessage == null) continue;

                if (responseMessage.Status == SignalResponseStatus.Error)
                {
                    var errorText = $"Client {callerName} error; tag: {messageTag}; Status: {responseMessage.Status}";
                    Logger.Error(errorText);
                    throw new InvalidOperationException(errorText);
                }

                yield return responseMessage.GetResponseData<TResponseData>();
            }
        }

        /// <summary>
        /// Fetches all items for <paramref name="messageTag"/> and replaces the content of <paramref name="intoList"/> with them.
        /// The list is always cleared once the request completes; a <c>null</c> result is logged as an error and leaves it empty.
        /// <paramref name="callback"/> is invoked after the list was updated. If the request throws, the exception propagates
        /// to the returned task and neither the list nor the callback is touched.
        /// </summary>
        public async Task GetAllIntoAsync<TResponseItem>(List<TResponseItem> intoList, int messageTag, object[]? contextParams = null, Action? callback = null) where TResponseItem : IEntityGuid
        {
            // ConfigureAwait(false) keeps the continuation off any captured context, like the former ContinueWith(TaskScheduler.Default).
            var result = await GetAllAsync<List<TResponseItem>>(messageTag, contextParams).ConfigureAwait(false);

            var logText = $"GetAllIntoAsync<{typeof(TResponseItem).Name}>(); dataCount: {result?.Count}; {ConstHelper.NameByValue(TagsName, messageTag)};";

            intoList.Clear();

            if (result != null)
            {
                Logger.Debug(logText);
                intoList.AddRange(result);
            }
            else Logger.Error(logText);

            callback?.Invoke();
        }

        #endregion

        public virtual Task<TResponse?> SendMessageToServerAsync<TResponse>(int messageTag)
            => SendMessageToServerAsync<TResponse>(messageTag, (object[]?)null, GetNextRequestId());

        public virtual Task<TResponse?> SendMessageToServerAsync<TResponse>(int messageTag, object[]? parameters)
            => SendMessageToServerAsync<TResponse>(messageTag, parameters, GetNextRequestId());

        /// <summary>
        /// Sends message to server with async callback response.
        /// If the send itself throws, the callback registration is removed before the exception propagates.
        /// </summary>
        public virtual async Task SendMessageToServerAsync(int messageTag, object[]? parameters, Func<SignalResponseDataMessage, Task> responseCallback)
        {
            var requestId = GetNextRequestId();
            await SendWithPendingRequestAsync(messageTag, parameters, requestId, SignalRRequestModelPool.Get(responseCallback));
        }

        /// <summary>
        /// Registers <paramref name="requestModel"/> as pending under <paramref name="requestId"/> and sends the message.
        /// When the send throws, the pending entry is removed (model returned to the pool) and the exception is rethrown,
        /// so a failed send does not leave an orphaned registration behind.
        /// </summary>
        private async Task SendWithPendingRequestAsync(int messageTag, object[]? parameters, int requestId, SignalRRequestModel requestModel)
        {
            _responseByRequestId[requestId] = requestModel;

            try
            {
                await SendMessageToServerAsync(messageTag, parameters, requestId);
            }
            catch
            {
                RemovePendingRequest(requestId);
                throw;
            }
        }

        /// <summary>
        /// Sends a request and waits (up to <see cref="TransportSendTimeout"/>) for its response.
        /// </summary>
        /// <returns>
        /// The deserialized response data (or the raw <see cref="SignalResponseDataMessage"/> when that is
        /// <typeparamref name="TResponse"/>). Returns <c>default</c> on timeout, when the request was dropped
        /// (e.g. the connection closed), on a server-reported error, or when processing the response throws — all logged.
        /// </returns>
        /// <remarks>Exceptions thrown while sending propagate to the caller after the pending entry is released.</remarks>
        protected virtual async Task<TResponse?> SendMessageToServerAsync<TResponse>(int messageTag, object[]? parameters, int requestId)
        {
            Logger.DebugConditional($"Client SendMessageToServerAsync<TResult>; {nameof(requestId)}: {requestId}; {ConstHelper.NameByValue(TagsName, messageTag)}");

            var startTime = DateTime.UtcNow;

            _responseByRequestId[requestId] = SignalRRequestModelPool.Get();

            try
            {
                await SendCoreAsync(messageTag, parameters, requestId, new SignalParams
                {
                    Status = SignalResponseStatus.Success,
                    IsRawBytesData = typeof(TResponse) == typeof(byte[])
                });
            }
            catch
            {
                // Without this the pending entry (and its pooled model) would stay registered forever.
                RemovePendingRequest(requestId);
                throw;
            }

            try
            {
                // Stop waiting as soon as a response arrived OR the entry disappeared (e.g. ClearPendingRequests on close);
                // TryGetValue avoids the KeyNotFoundException the indexer would throw in the latter case.
                var completed = await TaskHelper.WaitToAsync(
                    () => !_responseByRequestId.TryGetValue(requestId, out var pending) || pending.ResponseByRequestId != null,
                    TransportSendTimeout, MsDelay, MsFirstDelay);

                if (!completed)
                {
                    Logger.Error($"Client timeout after: {(DateTime.UtcNow - startTime).TotalSeconds} sec! ConnectionState: {GetConnectionState()}; requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
                }
                else if (!_responseByRequestId.TryRemove(requestId, out var obj))
                {
                    Logger.Error($"Client request dropped before a response arrived! ConnectionState: {GetConnectionState()}; requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
                }
                else
                {
                    // Read everything needed before the model goes back to the pool (it may be reset/reused afterwards).
                    var response = obj.ResponseByRequestId;
                    var requestTime = obj.RequestDateTime;
                    SignalRRequestModelPool.Return(obj);

                    if (response is SignalResponseDataMessage responseMessage)
                        return ConvertResponse<TResponse>(responseMessage, messageTag, requestId, requestTime);

                    Logger.Error($"Client SendMessageToServerAsync; unexpected response type: {response?.GetType().Name}; requestId: {requestId}; {ConstHelper.NameByValue(TagsName, messageTag)}");
                }
            }
            catch (Exception ex)
            {
                Logger.Error($"Client SendMessageToServerAsync; requestId: {requestId}; ConnectionState: {GetConnectionState()}; {ex.Message}; {ConstHelper.NameByValue(TagsName, messageTag)}", ex);
            }

            RemovePendingRequest(requestId);
            return default;
        }

        /// <summary>
        /// Turns a received response into the caller's result type, logging the round-trip time measured from <paramref name="requestTime"/>.
        /// A server-reported error is logged and yields <c>default</c>.
        /// </summary>
        private TResponse? ConvertResponse<TResponse>(SignalResponseDataMessage responseMessage, int messageTag, int requestId, DateTime requestTime)
        {
            if (responseMessage.Status == SignalResponseStatus.Error)
            {
                Logger.Error($"Client SendMessageToServerAsync<TResponseData> response error; await; tag: {messageTag}; Status: {responseMessage.Status}; ConnectionState: {GetConnectionState()}; requestId: {requestId}");
                return default;
            }

            // Special case: when TResponse is SignalResponseDataMessage, return the message itself
            // instead of trying to deserialize ResponseData (which would cause InvalidCastException)
            if (typeof(TResponse) == typeof(SignalResponseDataMessage))
            {
                var serializerType = responseMessage.DataSerializerType == AcSerializerType.Binary ? "Binary" : "JSON";
                Logger.Info($"Client returning raw SignalResponseDataMessage ({serializerType}). Total: {DateTime.UtcNow.Subtract(requestTime).TotalMilliseconds} ms! requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
                return (TResponse)(object)responseMessage;
            }

            var responseData = responseMessage.GetResponseData<TResponse>();

            if (responseData == null && responseMessage.Status == SignalResponseStatus.Success)
            {
                Logger.Info($"Client received null response. Total: {DateTime.UtcNow.Subtract(requestTime).TotalMilliseconds} ms! requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
                return default;
            }

            var serializerType2 = responseMessage.DataSerializerType == AcSerializerType.Binary ? "Binary" : "JSON";
            Logger.Info($"Client deserialized response ({serializerType2}). Total: {DateTime.UtcNow.Subtract(requestTime).TotalMilliseconds} ms! requestId: {requestId}; tag: {messageTag} [{ConstHelper.NameByValue(TagsName, messageTag)}]");
            return responseData;
        }

        protected virtual int GetNextRequestId() => AcDomain.NextUniqueInt32;

        /// <summary>
        /// Hub callback for every server message. When <paramref name="requestId"/> matches a pending request the response is
        /// stored for the waiting caller or handed to the registered callback (the entry is removed first in the callback cases);
        /// otherwise the message is forwarded fire-and-forget to <see cref="MessageReceived"/>.
        /// On an exception the pending entry is removed, the error is logged and the exception is rethrown.
        /// </summary>
        public virtual Task OnReceiveMessage(int messageTag, int? requestId, SignalParams signalParams, object data)
        {
            var logText = $"Client OnReceiveMessage; {nameof(requestId)}: {requestId}; {ConstHelper.NameByValue(TagsName, messageTag)}";

            try
            {
                if (requestId.HasValue && _responseByRequestId.TryGetValue(requestId.Value, out var requestModel))
                {
                    var reqId = requestId.Value;

                    requestModel.ResponseDateTime = DateTime.UtcNow;
                    Logger.Debug($"[{requestModel.ResponseDateTime.Subtract(requestModel.RequestDateTime).TotalMilliseconds:N0}ms]{logText}");

                    // Protocol already deserialized data to typed object or byte[]
                    var responseMessage = new SignalResponseDataMessage
                    {
                        Status = signalParams.Status,
                        DataSerializerType = signalParams.DataSerializerType,
                        RawResponseData = data
                    };

                    switch (requestModel.ResponseByRequestId)
                    {
                        case null:
                            // A typed request is polling for this value; it removes the entry itself.
                            requestModel.ResponseByRequestId = responseMessage;
                            return Task.CompletedTask;

                        case Action<SignalResponseDataMessage> actionCallback:
                            RemovePendingRequest(reqId);
                            actionCallback.Invoke(responseMessage);
                            return Task.CompletedTask;

                        case Func<SignalResponseDataMessage, Task> funcCallback:
                            RemovePendingRequest(reqId);
                            return funcCallback.Invoke(responseMessage);

                        default:
                            Logger.Error($"Client OnReceiveMessage switch; unknown message type: {requestModel.ResponseByRequestId?.GetType().Name}; {ConstHelper.NameByValue(TagsName, messageTag)}");
                            break;
                    }

                    RemovePendingRequest(reqId);
                    return Task.CompletedTask;
                }

                Logger.Info(logText);
                MessageReceived(messageTag, signalParams, data).Forget();
            }
            catch (Exception ex)
            {
                if (requestId.HasValue)
                    RemovePendingRequest(requestId.Value);

                Logger.Error($"Client OnReceiveMessage; requestId: {requestId}; ConnectionState: {GetConnectionState()}; {ex.Message}; {ConstHelper.NameByValue(TagsName, messageTag)}", ex);
                throw;
            }

            return Task.CompletedTask;
        }
    }
}