using Serilog;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;

namespace CommonSocketLibrary.Abstract
{
    /// <summary>
    /// Base class for a JSON-over-WebSocket client. Owns a single <see cref="ClientWebSocket"/>,
    /// runs a background receive loop that reassembles each complete incoming message and
    /// deserializes it into a <c>Message</c>, and delegates protocol specifics (building outgoing
    /// messages, handling incoming ones, connection/send hooks) to derived classes.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <see cref="ClientWebSocket"/> allows only one outstanding send at a time;
    /// this class does not serialize sends, so <see cref="Send{T}"/> and <see cref="SendRaw"/>
    /// must not be invoked concurrently by callers.
    /// </remarks>
    public abstract class SocketClient : IDisposable
    {
        private ClientWebSocket? _socket;
        private CancellationTokenSource? _cts;
        protected readonly ILogger _logger;
        protected readonly JsonSerializerOptions _options;

        /// <summary>
        /// Whether the client considers itself connected. The base class only ever sets this to
        /// <c>false</c> (on disconnect, on a failed <see cref="Send{T}"/>, and when the receive
        /// loop ends); derived classes are responsible for setting it to <c>true</c>.
        /// <see cref="SendRaw"/> is a no-op while it is <c>false</c>.
        /// </summary>
        public bool Connected { get; set; }

        /// <summary>Size in bytes of the buffer used for each receive call.</summary>
        public int ReceiveBufferSize { get; } = 8192;

        /// <summary>
        /// When <c>true</c> (the default, preserving the historical behavior of this class) the
        /// server's TLS certificate is accepted without any validation. That disables protection
        /// against man-in-the-middle attacks; set it to <c>false</c> for production endpoints.
        /// Takes effect on the next <see cref="ConnectAsync"/>.
        /// </summary>
        public bool AcceptAnyServerCertificate { get; set; } = true;

        /// <summary>
        /// Maximum time <see cref="DisconnectAsync"/> waits for the WebSocket close handshake
        /// before abandoning it and tearing the connection down.
        /// </summary>
        public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromMilliseconds(500);

        /// <param name="logger">Serilog logger used for send/receive failures.</param>
        /// <param name="options">JSON options used both to serialize outgoing and deserialize incoming messages.</param>
        protected SocketClient(ILogger logger, JsonSerializerOptions options)
        {
            _logger = logger;
            _options = options;
            Connected = false;
        }

        /// <summary>
        /// Connects to <paramref name="url"/>, starts the background receive loop and then invokes
        /// <see cref="OnConnection"/>. Does nothing if the current socket is already open; a socket
        /// in any other state is discarded and replaced.
        /// </summary>
        /// <param name="url">Absolute WebSocket URL (ws:// or wss://).</param>
        /// <exception cref="UriFormatException"><paramref name="url"/> is not a valid URI.</exception>
        /// <exception cref="WebSocketException">The connection could not be established.</exception>
        public async Task ConnectAsync(string url)
        {
            if (_socket is not null)
            {
                if (_socket.State == WebSocketState.Open)
                {
                    return;
                }

                _socket.Dispose();
                _socket = null;
            }

            // Cancel (not just dispose) the previous token source so any receive loop still bound
            // to the old connection stops.
            if (_cts is not null)
            {
                _cts.Cancel();
                _cts.Dispose();
            }

            var cts = new CancellationTokenSource();
            _cts = cts;

            var socket = new ClientWebSocket();
            if (AcceptAnyServerCertificate)
            {
                // SECURITY: accepts any server certificate; see AcceptAnyServerCertificate.
                socket.Options.RemoteCertificateValidationCallback = (_, _, _, _) => true;
            }
            socket.Options.UseDefaultCredentials = false;
            _socket = socket;

            try
            {
                await socket.ConnectAsync(new Uri(url), cts.Token);
            }
            catch
            {
                // Do not leave a half-open socket / orphaned token source behind.
                socket.Dispose();
                _socket = null;
                cts.Dispose();
                _cts = null;
                throw;
            }

            // Capture the token now: reading cts.Token later on the pool thread could race with a
            // DisconnectAsync that has already disposed the source.
            var token = cts.Token;

            // Task.Run unwraps the async delegate (unlike Task.Factory.StartNew, which returns a
            // Task<Task> whose outer task completes at the loop's first await). The loop handles and
            // logs its own failures, so it is intentionally not awaited here.
            _ = Task.Run(() => ReceiveLoop(socket, token));

            await OnConnection();
        }

        /// <summary>
        /// Performs a best-effort close handshake (bounded by <see cref="CloseTimeout"/>), then
        /// stops the receive loop and releases the socket. Never throws for close-handshake
        /// failures; safe to call when not connected.
        /// </summary>
        public async Task DisconnectAsync()
        {
            var socket = _socket;
            var cts = _cts;
            if (socket is null || cts is null)
            {
                return;
            }

            try
            {
                // TODO: requests cleanup code, sub-protocol dependent.

                // CloseAsync both sends our Close frame and waits for the peer's. It is also valid
                // when the server closed first (CloseReceived), which completes the handshake.
                // A NormalClosure status is used: the reserved 'Empty' (1005) status must not be
                // put on the wire (RFC 6455 §7.4.1).
                if (socket.State is WebSocketState.Open or WebSocketState.CloseReceived)
                {
                    using var closeTimeout = new CancellationTokenSource(CloseTimeout);
                    try
                    {
                        await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", closeTimeout.Token)
                            .ConfigureAwait(false);
                    }
                    catch (Exception ex) when (ex is WebSocketException or OperationCanceledException or ObjectDisposedException)
                    {
                        _logger.Warning(ex, "WebSocket close handshake did not complete cleanly.");
                    }
                }
            }
            finally
            {
                Connected = false;

                // Stops the receive loop if the close handshake did not already end it.
                cts.Cancel();
                socket.Dispose();
                cts.Dispose();

                // Only clear the fields if a concurrent ConnectAsync has not already replaced them.
                if (ReferenceEquals(_socket, socket))
                {
                    _socket = null;
                }
                if (ReferenceEquals(_cts, cts))
                {
                    _cts = null;
                }
            }
        }

        /// <summary>
        /// Synchronously disconnects. <see cref="DisconnectAsync"/> uses
        /// <c>ConfigureAwait(false)</c> internally, which keeps this blocking wait from
        /// deadlocking on a single-threaded synchronization context.
        /// </summary>
        public void Dispose()
        {
            DisconnectAsync().GetAwaiter().GetResult();
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Reads messages from <paramref name="socket"/> until a Close frame arrives, the token is
        /// cancelled, or the connection fails. Each complete data message is reassembled from its
        /// fragments and handed to <see cref="ResponseReceived"/>.
        /// </summary>
        /// <remarks>
        /// The socket and token are passed in rather than read from fields, because
        /// <see cref="DisconnectAsync"/> nulls and disposes those fields while this loop may still
        /// be running.
        /// </remarks>
        private async Task ReceiveLoop(ClientWebSocket socket, CancellationToken token)
        {
            var buffer = new byte[ReceiveBufferSize];
            try
            {
                while (!token.IsCancellationRequested)
                {
                    // ResponseReceived also disposes the stream; disposing a MemoryStream twice is harmless.
                    using var message = new MemoryStream(ReceiveBufferSize);
                    WebSocketReceiveResult result;
                    do
                    {
                        result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), token);
                        if (result.MessageType != WebSocketMessageType.Close)
                        {
                            message.Write(buffer, 0, result.Count);
                        }
                    }
                    while (!result.EndOfMessage);

                    if (result.MessageType == WebSocketMessageType.Close)
                    {
                        break;
                    }

                    message.Position = 0;
                    await ResponseReceived(message);
                }
            }
            catch (Exception ex) when (token.IsCancellationRequested && ex is OperationCanceledException or WebSocketException or ObjectDisposedException)
            {
                // Expected while the connection is being torn down deliberately.
            }
            catch (Exception ex)
            {
                // Previously such failures (e.g. the server dropping the connection) were unobserved.
                _logger.Error(ex, "The websocket receive loop terminated unexpectedly.");
            }
            finally
            {
                // Only report "disconnected" if this loop still belongs to the current connection.
                if (ReferenceEquals(_socket, socket))
                {
                    Connected = false;
                }
            }
        }

        /// <summary>
        /// Sends <paramref name="content"/> verbatim as one UTF-8 text message, then invokes
        /// <see cref="OnMessageSend"/> with opcode <c>-1</c>. Does nothing while
        /// <see cref="Connected"/> is <c>false</c>. Unlike <see cref="Send{T}"/>, send failures
        /// propagate to the caller.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="content"/> is null.</exception>
        /// <exception cref="InvalidOperationException">No socket exists.</exception>
        public async Task SendRaw(string content)
        {
            if (!Connected)
            {
                return;
            }

            ArgumentNullException.ThrowIfNull(content);
            await SendTextAsync(content);
            await OnMessageSend(-1, content);
        }

        /// <summary>
        /// Builds a message for <paramref name="opcode"/> via <see cref="GenerateMessage{T}"/>,
        /// serializes it with the configured JSON options, sends it as one text message and invokes
        /// <see cref="OnMessageSend"/>. Any failure is logged and sets <see cref="Connected"/> to
        /// <c>false</c> instead of throwing.
        /// </summary>
        public async Task Send<T>(int opcode, T data)
        {
            try
            {
                var message = GenerateMessage(opcode, data);
                var content = JsonSerializer.Serialize(message, _options);
                await SendTextAsync(content);
                await OnMessageSend(opcode, content);
            }
            catch (Exception e)
            {
                Connected = false;
                _logger.Error(e, "Failed to send a message: {Opcode}", opcode);
            }
        }

        /// <summary>
        /// Encodes <paramref name="content"/> as UTF-8 and sends it in a single
        /// <c>SendAsync</c> call with <c>endOfMessage: true</c>, i.e. exactly one complete text message.
        /// </summary>
        /// <exception cref="InvalidOperationException">No socket exists.</exception>
        private async Task SendTextAsync(string content)
        {
            var socket = _socket ?? throw new InvalidOperationException("The socket is not connected.");
            var token = _cts?.Token ?? CancellationToken.None;
            var bytes = Encoding.UTF8.GetBytes(content);
            await socket.SendAsync(new ArraySegment<byte>(bytes), WebSocketMessageType.Text, true, token);
        }

        /// <summary>
        /// Deserializes one complete incoming message (using the same JSON options as outgoing
        /// messages) and passes it to <see cref="OnResponseReceived"/>. Failures are logged, not
        /// rethrown, so one bad message does not stop the receive loop. Always disposes
        /// <paramref name="stream"/>.
        /// </summary>
        private async Task ResponseReceived(Stream stream)
        {
            try
            {
                var data = await JsonSerializer.DeserializeAsync<Message>(stream, _options);
                await OnResponseReceived(data);
            }
            catch (Exception ex)
            {
                _logger.Error(ex, "Failed to read or execute a websocket message.");
            }
            finally
            {
                stream.Dispose();
            }
        }

        /// <summary>Builds the outgoing message object for <paramref name="opcode"/> carrying <paramref name="data"/>.</summary>
        protected abstract Message GenerateMessage<T>(int opcode, T data);

        /// <summary>Handles one deserialized incoming message (null if the JSON payload was <c>null</c>).</summary>
        protected abstract Task OnResponseReceived(Message? content);

        /// <summary>Invoked after a message was sent; <paramref name="opcode"/> is <c>-1</c> for <see cref="SendRaw"/>.</summary>
        protected abstract Task OnMessageSend(int opcode, string? content);

        /// <summary>Invoked once the socket is open and the receive loop has been started.</summary>
        protected abstract Task OnConnection();
    }
}