hermes-common-library/Abstract/SocketClient.cs

164 lines
5.7 KiB
C#
Raw Normal View History

2024-06-24 18:28:40 -04:00
using Serilog;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
namespace CommonSocketLibrary.Abstract
{
    /// <summary>
    /// Base class for a client that exchanges JSON messages over a <see cref="ClientWebSocket"/>.
    /// Subclasses build outgoing envelopes (<see cref="GenerateMessage{T}"/>) and react to
    /// connection, send and receive events.
    /// </summary>
    /// <typeparam name="Message">
    /// Envelope type: outgoing messages are produced as this type and incoming messages are
    /// deserialized into it.
    /// </typeparam>
    public abstract class SocketClient<Message> : IDisposable
    {
        private ClientWebSocket? _socket;
        private CancellationTokenSource? _cts;
        protected readonly ILogger _logger;
        protected readonly JsonSerializerOptions _options;

        /// <summary>
        /// Connection flag consulted by <see cref="SendRaw"/>. This base class only ever clears it
        /// (on disconnect, send failure, server close or receive failure).
        /// NOTE(review): presumably a subclass sets it to true once its sub-protocol handshake
        /// succeeds — confirm.
        /// </summary>
        public bool Connected { get; set; }

        /// <summary>Size in bytes of the buffer used for each websocket receive call.</summary>
        public int ReceiveBufferSize { get; } = 8192;

        /// <summary>
        /// When true (the default), the TLS certificate presented by the server is accepted without
        /// any validation. That leaves secure connections open to man-in-the-middle attacks; set it
        /// to false before calling <see cref="ConnectAsync"/> when the server has a trusted certificate.
        /// </summary>
        public bool AcceptAnyServerCertificate { get; set; } = true;

        /// <param name="logger">Logger used for send, receive and close failures.</param>
        /// <param name="options">Serializer options used for both outgoing and incoming messages.</param>
        public SocketClient(ILogger logger, JsonSerializerOptions options)
        {
            _logger = logger;
            _options = options;
            Connected = false;
        }

        /// <summary>
        /// Opens a websocket connection to <paramref name="url"/> (no-op if the current socket is
        /// already open), schedules the background receive loop and then invokes
        /// <see cref="OnConnection"/>.
        /// </summary>
        /// <param name="url">Absolute URL of the websocket endpoint.</param>
        public async Task ConnectAsync(string url)
        {
            if (_socket != null)
            {
                if (_socket.State == WebSocketState.Open) return;
                _socket.Dispose();
            }

            // Stop any receive loop still tied to the previous connection before replacing it.
            _cts?.Cancel();
            _cts?.Dispose();

            var socket = new ClientWebSocket();
            if (AcceptAnyServerCertificate)
            {
                // NOTE(review): this disables TLS server certificate validation entirely.
                socket.Options.RemoteCertificateValidationCallback = (_, _, _, _) => true;
            }
            socket.Options.UseDefaultCredentials = false;

            var cts = new CancellationTokenSource();
            _socket = socket;
            _cts = cts;
            var token = cts.Token;

            await socket.ConnectAsync(new Uri(url), token);

            // The loop gets its own socket and token so that a later reconnect (which replaces the
            // fields) can never make two loops read from the same socket. Fire-and-forget is safe
            // because ReceiveLoop handles and logs every exception itself.
            _ = Task.Run(() => ReceiveLoop(socket, token));

            await OnConnection();
        }

        /// <summary>
        /// Closes the current connection (if any) and releases the socket and its cancellation
        /// source. Close failures are logged rather than thrown, so this is safe to call from
        /// <see cref="Dispose"/>.
        /// </summary>
        public async Task DisconnectAsync()
        {
            var socket = _socket;
            var cts = _cts;
            if (socket == null || cts == null) return;

            // Detach first so a racing second call sees nothing to do, and sends fail fast instead
            // of touching a socket that is being torn down.
            _socket = null;
            _cts = null;
            Connected = false;

            try
            {
                // TODO: requests cleanup code, sub-protocol dependent.
                if (socket.State == WebSocketState.Open)
                {
                    // Bound the close handshake: if the server has not answered within 500 ms the
                    // token fires, which aborts the socket instead of waiting indefinitely.
                    cts.CancelAfter(TimeSpan.FromMilliseconds(500));
                    await socket.CloseOutputAsync(WebSocketCloseStatus.Empty, "", cts.Token).ConfigureAwait(false);
                    await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", cts.Token).ConfigureAwait(false);
                }
            }
            catch (Exception ex)
            {
                // Closing is best-effort: the resources below must be released regardless.
                _logger.Warning(ex, "The websocket did not close cleanly.");
            }
            finally
            {
                cts.Cancel();
                socket.Dispose();
                cts.Dispose();
            }
        }

        /// <summary>Synchronously disconnects by blocking on <see cref="DisconnectAsync"/>.</summary>
        public void Dispose()
        {
            // DisconnectAsync uses ConfigureAwait(false) and does not throw on close failures,
            // so blocking here neither deadlocks on a sync context nor throws from Dispose.
            DisconnectAsync().GetAwaiter().GetResult();
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Reads complete websocket messages from <paramref name="socket"/> until
        /// <paramref name="token"/> is cancelled or a close frame arrives, passing each message to
        /// <see cref="ResponseReceived"/>.
        /// </summary>
        /// <param name="socket">The socket this loop was started for.</param>
        /// <param name="token">Cancellation token of the connection this loop belongs to.</param>
        private async Task ReceiveLoop(ClientWebSocket socket, CancellationToken token)
        {
            var buffer = new byte[ReceiveBufferSize];
            MemoryStream? outputStream = null;
            try
            {
                while (!token.IsCancellationRequested)
                {
                    outputStream = new MemoryStream(ReceiveBufferSize);
                    WebSocketReceiveResult receiveResult;
                    do
                    {
                        receiveResult = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), token);
                        if (receiveResult.MessageType != WebSocketMessageType.Close)
                        {
                            outputStream.Write(buffer, 0, receiveResult.Count);
                        }
                    }
                    while (!receiveResult.EndOfMessage);

                    if (receiveResult.MessageType == WebSocketMessageType.Close)
                    {
                        Connected = false;
                        // Server-initiated close: answer with our own close frame to complete the
                        // handshake. If we initiated the close, the state is already Closed here.
                        if (socket.State == WebSocketState.CloseReceived)
                        {
                            await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token);
                        }
                        break;
                    }

                    outputStream.Position = 0;
                    // ResponseReceived takes ownership of the stream and disposes it.
                    await ResponseReceived(outputStream);
                }
            }
            catch (OperationCanceledException)
            {
                // Expected when DisconnectAsync or a reconnect cancels this connection.
            }
            catch (Exception) when (token.IsCancellationRequested)
            {
                // The socket was torn down by an intentional disconnect; nothing to report.
            }
            catch (Exception ex)
            {
                Connected = false;
                _logger.Error(ex, "The websocket receive loop stopped unexpectedly.");
            }
            finally
            {
                outputStream?.Dispose();
            }
        }

        /// <summary>
        /// Sends <paramref name="content"/> verbatim as a single text message, then calls
        /// <see cref="OnMessageSend"/> with opcode -1. Does nothing while <see cref="Connected"/>
        /// is false. Send failures propagate to the caller.
        /// </summary>
        public async Task SendRaw(string content)
        {
            if (!Connected) return;
            await SendTextAsync(content);
            await OnMessageSend(-1, content);
        }

        /// <summary>
        /// Wraps <paramref name="data"/> with <see cref="GenerateMessage{T}"/>, serializes it with the
        /// configured options, sends it as one text message and calls <see cref="OnMessageSend"/>.
        /// Any failure (building, serializing, sending or the callback) is logged, clears
        /// <see cref="Connected"/> and is not rethrown.
        /// </summary>
        public async Task Send<T>(int opcode, T data)
        {
            try
            {
                var message = GenerateMessage(opcode, data);
                var content = JsonSerializer.Serialize(message, _options);
                await SendTextAsync(content);
                await OnMessageSend(opcode, content);
            }
            catch (Exception e)
            {
                Connected = false;
                _logger.Error(e, "Failed to send a message: {Opcode}", opcode);
            }
        }

        /// <summary>
        /// Encodes <paramref name="content"/> as UTF-8 and sends exactly those bytes as one complete
        /// text message.
        /// </summary>
        /// <exception cref="InvalidOperationException">No socket is currently connected.</exception>
        /// <remarks>
        /// NOTE(review): <see cref="ClientWebSocket"/> allows only one outstanding send at a time;
        /// callers that may send concurrently need to serialize their calls.
        /// </remarks>
        private async Task SendTextAsync(string content)
        {
            var socket = _socket ?? throw new InvalidOperationException("The websocket is not connected.");
            var token = _cts?.Token ?? CancellationToken.None;
            var bytes = Encoding.UTF8.GetBytes(content);
            // SendAsync accepts a buffer of any length, so no manual chunking is needed.
            await socket.SendAsync(new ArraySegment<byte>(bytes), WebSocketMessageType.Text, true, token);
        }

        /// <summary>
        /// Deserializes one received message (using the same options as outgoing messages) and
        /// forwards it to <see cref="OnResponseReceived"/>. Failures are logged, and
        /// <paramref name="stream"/> is always disposed.
        /// </summary>
        private async Task ResponseReceived(Stream stream)
        {
            try
            {
                var data = await JsonSerializer.DeserializeAsync<Message>(stream, _options);
                await OnResponseReceived(data);
            }
            catch (Exception ex)
            {
                _logger.Error(ex, "Failed to read or execute a websocket message.");
            }
            finally
            {
                stream.Dispose();
            }
        }

        /// <summary>Builds the envelope for <paramref name="data"/> that <see cref="Send{T}"/> serializes.</summary>
        protected abstract Message GenerateMessage<T>(int opcode, T data);

        /// <summary>Called for every incoming message; <paramref name="content"/> may be null (e.g. a JSON <c>null</c> payload).</summary>
        protected abstract Task OnResponseReceived(Message? content);

        /// <summary>Called after a message was sent; <paramref name="opcode"/> is -1 for <see cref="SendRaw"/>.</summary>
        protected abstract Task OnMessageSend(int opcode, string? content);

        /// <summary>Called at the end of <see cref="ConnectAsync"/>, after the socket opened and the receive loop was scheduled.</summary>
        protected abstract Task OnConnection();
    }
}