Compare commits

...

3 Commits

Author SHA1 Message Date
Tom
61151bef0c Changed line endings for Startup 2024-10-29 12:14:55 +00:00
Tom
aac73563d0 Added delete policy request 2024-10-29 12:14:27 +00:00
Tom
e7b06f1634 Added & using prepared statements for stores 2024-10-29 12:13:38 +00:00
10 changed files with 414 additions and 220 deletions

27
Requests/DeletePolicy.cs Normal file
View File

@ -0,0 +1,27 @@
using HermesSocketServer.Services;
using ILogger = Serilog.ILogger;
namespace HermesSocketServer.Requests
{
public class DeletePolicy : IRequest
{
public string Name => "delete_policy";
public string[] RequiredKeys => ["id"];
private ChannelManager _channels;
private ILogger _logger;
public DeletePolicy(ChannelManager channels, ILogger logger)
{
_channels = channels;
_logger = logger;
}
public async Task<RequestResult> Grant(string sender, IDictionary<string, object>? data)
{
var channel = _channels.Get(sender);
channel.Policies.Remove(data!["id"].ToString());
_logger.Information($"Deleted a policy by id [policy id: {data["id"]}]");
return RequestResult.Successful(null);
}
}
}

View File

@ -1,160 +1,161 @@
using System.Net; using System.Net;
using System.Text.Json; using System.Text.Json;
using HermesSocketLibrary; using HermesSocketLibrary;
using HermesSocketLibrary.db; using HermesSocketLibrary.db;
using HermesSocketLibrary.Requests; using HermesSocketLibrary.Requests;
using HermesSocketServer; using HermesSocketServer;
using HermesSocketServer.Requests; using HermesSocketServer.Requests;
using HermesSocketServer.Socket; using HermesSocketServer.Socket;
using HermesSocketServer.Socket.Handlers; using HermesSocketServer.Socket.Handlers;
using Microsoft.AspNetCore.HttpOverrides; using Microsoft.AspNetCore.HttpOverrides;
using Serilog; using Serilog;
using Serilog.Events; using Serilog.Events;
using YamlDotNet.Serialization; using YamlDotNet.Serialization;
using YamlDotNet.Serialization.NamingConventions; using YamlDotNet.Serialization.NamingConventions;
using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections;
using HermesSocketServer.Validators; using HermesSocketServer.Validators;
using HermesSocketServer.Store; using HermesSocketServer.Store;
using HermesSocketServer.Services; using HermesSocketServer.Services;
var yamlDeserializer = new DeserializerBuilder() var yamlDeserializer = new DeserializerBuilder()
.WithNamingConvention(HyphenatedNamingConvention.Instance) .WithNamingConvention(HyphenatedNamingConvention.Instance)
.Build(); .Build();
var configFileName = "server.config.yml"; var configFileName = "server.config.yml";
var environment = Environment.GetEnvironmentVariable("TTS_ENV")!.ToLower(); var environment = Environment.GetEnvironmentVariable("TTS_ENV")!.ToLower();
if (File.Exists("server.config." + environment + ".yml")) if (File.Exists("server.config." + environment + ".yml"))
configFileName = "server.config." + environment + ".yml"; configFileName = "server.config." + environment + ".yml";
var configContent = File.ReadAllText(configFileName); var configContent = File.ReadAllText(configFileName);
var configuration = yamlDeserializer.Deserialize<ServerConfiguration>(configContent); var configuration = yamlDeserializer.Deserialize<ServerConfiguration>(configContent);
if (configuration.Environment.ToUpper() != "QA" && configuration.Environment.ToUpper() != "PROD") if (configuration.Environment.ToUpper() != "QA" && configuration.Environment.ToUpper() != "PROD")
throw new Exception("Invalid environment set."); throw new Exception("Invalid environment set.");
var builder = WebApplication.CreateBuilder(); var builder = WebApplication.CreateBuilder();
builder.Logging.ClearProviders(); builder.Logging.ClearProviders();
builder.Services.Configure<ForwardedHeadersOptions>(options => builder.Services.Configure<ForwardedHeadersOptions>(options =>
{ {
options.ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto; options.ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto;
}); });
builder.WebHost.UseUrls($"http://{configuration.WebsocketServer.Host}:{configuration.WebsocketServer.Port}"); builder.WebHost.UseUrls($"http://{configuration.WebsocketServer.Host}:{configuration.WebsocketServer.Port}");
var loggerConfiguration = new LoggerConfiguration(); var loggerConfiguration = new LoggerConfiguration();
if (configuration.Environment.ToUpper() == "QA") if (configuration.Environment.ToUpper() == "QA")
loggerConfiguration.MinimumLevel.Verbose(); loggerConfiguration.MinimumLevel.Verbose();
else else
loggerConfiguration.MinimumLevel.Debug(); loggerConfiguration.MinimumLevel.Debug();
loggerConfiguration.Enrich.FromLogContext() loggerConfiguration.Enrich.FromLogContext()
.WriteTo.File($"logs/{configuration.Environment.ToUpper()}/serverlog-.log", rollingInterval: RollingInterval.Day, retainedFileCountLimit: 7); .WriteTo.File($"logs/{configuration.Environment.ToUpper()}/serverlog-.log", rollingInterval: RollingInterval.Day, retainedFileCountLimit: 7);
if (configuration.Environment.ToUpper() == "QA") if (configuration.Environment.ToUpper() == "QA")
loggerConfiguration.WriteTo.Console(restrictedToMinimumLevel: LogEventLevel.Debug); loggerConfiguration.WriteTo.Console(restrictedToMinimumLevel: LogEventLevel.Debug);
else else
loggerConfiguration.WriteTo.Console(restrictedToMinimumLevel: LogEventLevel.Information); loggerConfiguration.WriteTo.Console(restrictedToMinimumLevel: LogEventLevel.Information);
var logger = loggerConfiguration.CreateLogger(); var logger = loggerConfiguration.CreateLogger();
builder.Host.UseSerilog(logger); builder.Host.UseSerilog(logger);
builder.Logging.AddSerilog(logger); builder.Logging.AddSerilog(logger);
var s = builder.Services; var s = builder.Services;
s.AddSerilog(logger); s.AddSerilog(logger);
s.AddSingleton<ServerConfiguration>(configuration); s.AddSingleton<ServerConfiguration>(configuration);
s.AddSingleton<Database>(); s.AddSingleton<Database>();
// Socket message handlers // Socket message handlers
s.AddSingleton<Serilog.ILogger>(logger); s.AddSingleton<Serilog.ILogger>(logger);
s.AddSingleton<ISocketHandler, HeartbeatHandler>(); s.AddSingleton<ISocketHandler, HeartbeatHandler>();
s.AddSingleton<ISocketHandler, HermesLoginHandler>(); s.AddSingleton<ISocketHandler, HermesLoginHandler>();
s.AddSingleton<ISocketHandler, RequestHandler>(); s.AddSingleton<ISocketHandler, RequestHandler>();
s.AddSingleton<ISocketHandler, LoggingHandler>(); s.AddSingleton<ISocketHandler, LoggingHandler>();
s.AddSingleton<ISocketHandler, ChatterHandler>(); s.AddSingleton<ISocketHandler, ChatterHandler>();
s.AddSingleton<ISocketHandler, EmoteDetailsHandler>(); s.AddSingleton<ISocketHandler, EmoteDetailsHandler>();
s.AddSingleton<ISocketHandler, EmoteUsageHandler>(); s.AddSingleton<ISocketHandler, EmoteUsageHandler>();
// Validators // Validators
s.AddSingleton<VoiceIdValidator>(); s.AddSingleton<VoiceIdValidator>();
s.AddSingleton<VoiceNameValidator>(); s.AddSingleton<VoiceNameValidator>();
// Stores // Stores
s.AddSingleton<VoiceStore>(); s.AddSingleton<VoiceStore>();
s.AddSingleton<UserStore>(); s.AddSingleton<UserStore>();
// Request handlers // Request handlers
s.AddSingleton<IRequest, GetChatterIds>(); s.AddSingleton<IRequest, CreatePolicy>();
s.AddSingleton<IRequest, GetConnections>(); s.AddSingleton<IRequest, CreateTTSUser>();
s.AddSingleton<IRequest, GetDefaultTTSVoice>(); s.AddSingleton<IRequest, CreateTTSVoice>();
s.AddSingleton<IRequest, GetEmotes>(); s.AddSingleton<IRequest, DeletePolicy>();
s.AddSingleton<IRequest, GetEnabledTTSVoices>(); s.AddSingleton<IRequest, DeleteTTSVoice>();
s.AddSingleton<IRequest, GetPermissions>(); s.AddSingleton<IRequest, GetChatterIds>();
s.AddSingleton<IRequest, GetRedemptions>(); s.AddSingleton<IRequest, GetConnections>();
s.AddSingleton<IRequest, GetRedeemableActions>(); s.AddSingleton<IRequest, GetDefaultTTSVoice>();
s.AddSingleton<IRequest, GetPolicies>(); s.AddSingleton<IRequest, GetEmotes>();
s.AddSingleton<IRequest, GetTTSUsers>(); s.AddSingleton<IRequest, GetEnabledTTSVoices>();
s.AddSingleton<IRequest, GetTTSVoices>(); s.AddSingleton<IRequest, GetPermissions>();
s.AddSingleton<IRequest, GetTTSWordFilters>(); s.AddSingleton<IRequest, GetRedemptions>();
s.AddSingleton<IRequest, CreatePolicy>(); s.AddSingleton<IRequest, GetRedeemableActions>();
s.AddSingleton<IRequest, CreateTTSUser>(); s.AddSingleton<IRequest, GetPolicies>();
s.AddSingleton<IRequest, CreateTTSVoice>(); s.AddSingleton<IRequest, GetTTSUsers>();
s.AddSingleton<IRequest, DeleteTTSVoice>(); s.AddSingleton<IRequest, GetTTSVoices>();
s.AddSingleton<IRequest, UpdateTTSUser>(); s.AddSingleton<IRequest, GetTTSWordFilters>();
s.AddSingleton<IRequest, UpdateTTSVoice>(); s.AddSingleton<IRequest, UpdateTTSUser>();
s.AddSingleton<IRequest, UpdateDefaultTTSVoice>(); s.AddSingleton<IRequest, UpdateTTSVoice>();
s.AddSingleton<IRequest, UpdateTTSVoiceState>(); s.AddSingleton<IRequest, UpdateDefaultTTSVoice>();
s.AddSingleton<IRequest, UpdatePolicy>(); s.AddSingleton<IRequest, UpdateTTSVoiceState>();
s.AddSingleton<IRequest, UpdatePolicy>();
// Managers
s.AddSingleton<ChannelManager>(); // Managers
s.AddSingleton<HermesSocketManager>(); s.AddSingleton<ChannelManager>();
s.AddSingleton<SocketHandlerManager>(); s.AddSingleton<HermesSocketManager>();
s.AddSingleton<IRequestManager, RequestManager>(); s.AddSingleton<SocketHandlerManager>();
s.AddSingleton(new JsonSerializerOptions() s.AddSingleton<IRequestManager, RequestManager>();
{ s.AddSingleton(new JsonSerializerOptions()
PropertyNameCaseInsensitive = false, {
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower PropertyNameCaseInsensitive = false,
}); PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower
s.AddSingleton<Server>(); });
s.AddSingleton<Server>();
// Background services
s.AddHostedService<DatabaseService>(); // Background services
s.AddHostedService<DatabaseService>();
var app = builder.Build();
app.UseForwardedHeaders(); var app = builder.Build();
app.UseSerilogRequestLogging(); app.UseForwardedHeaders();
app.UseSerilogRequestLogging();
var wsOptions = new WebSocketOptions()
{ var wsOptions = new WebSocketOptions()
KeepAliveInterval = TimeSpan.FromSeconds(30) {
}; KeepAliveInterval = TimeSpan.FromSeconds(30)
// wsOptions.AllowedOrigins.Add("wss://tomtospeech.com"); };
//wsOptions.AllowedOrigins.Add("ws.tomtospeech.com"); // wsOptions.AllowedOrigins.Add("wss://tomtospeech.com");
//wsOptions.AllowedOrigins.Add("hermes-ws.goblincaves.com"); //wsOptions.AllowedOrigins.Add("ws.tomtospeech.com");
app.UseWebSockets(wsOptions); //wsOptions.AllowedOrigins.Add("hermes-ws.goblincaves.com");
app.UseWebSockets(wsOptions);
var options = app.Services.GetRequiredService<JsonSerializerOptions>();
var server = app.Services.GetRequiredService<Server>(); var options = app.Services.GetRequiredService<JsonSerializerOptions>();
var server = app.Services.GetRequiredService<Server>();
app.Use(async (HttpContext context, RequestDelegate next) =>
{ app.Use(async (HttpContext context, RequestDelegate next) =>
if (context.Request.Path != "/") {
{ if (context.Request.Path != "/")
context.Response.StatusCode = StatusCodes.Status401Unauthorized; {
return; context.Response.StatusCode = StatusCodes.Status401Unauthorized;
} return;
}
if (context.WebSockets.IsWebSocketRequest)
{ if (context.WebSockets.IsWebSocketRequest)
using var webSocket = await context.WebSockets.AcceptWebSocketAsync(); {
await server.Handle(new WebSocketUser(webSocket, IPAddress.Parse(context.Request.Headers["X-Forwarded-For"].ToString()), options, logger), context); using var webSocket = await context.WebSockets.AcceptWebSocketAsync();
} await server.Handle(new WebSocketUser(webSocket, IPAddress.Parse(context.Request.Headers["X-Forwarded-For"].ToString()), options, logger), context);
else }
{ else
context.Response.StatusCode = StatusCodes.Status400BadRequest; {
} context.Response.StatusCode = StatusCodes.Status400BadRequest;
await next(context); }
}); await next(context);
});
await app.RunAsync(); await app.RunAsync();

View File

@ -1,3 +1,4 @@
using System.Collections.Immutable;
using HermesSocketLibrary.db; using HermesSocketLibrary.db;
using HermesSocketServer.Models; using HermesSocketServer.Models;
@ -23,7 +24,7 @@ namespace HermesSocketServer.Store
{ "ttsVoiceId", "VoiceId" }, { "ttsVoiceId", "VoiceId" },
{ "userId", "UserId" }, { "userId", "UserId" },
}; };
_generator = new GroupSaveSqlGenerator<ChatterVoice>(ctp); _generator = new GroupSaveSqlGenerator<ChatterVoice>(ctp, _logger);
} }
public override async Task Load() public override async Task Load()
@ -55,46 +56,53 @@ namespace HermesSocketServer.Store
{ {
} }
public override async Task<bool> Save() public override async Task Save()
{ {
int count = 0; int count = 0;
string sql = string.Empty; string sql = string.Empty;
ImmutableList<string>? list = null;
if (_added.Any()) if (_added.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _added.Count; list = _added.ToImmutableList();
sql = _generator.GenerateInsertSql("TtsChatVoice", _added.Select(a => _store[a]), ["userId", "chatterId", "ttsVoiceId"]);
_added.Clear(); _added.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedInsertSql("TtsChatVoice", count, ["userId", "chatterId", "ttsVoiceId"]);
_logger.Debug($"TtsChatVoice - Adding {count} rows to database: {sql}"); _logger.Debug($"User - Adding {count} rows to database: {sql}");
await _database.ExecuteScalarTransaction(sql); var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "name", "email", "role", "ttsDefaultVoice"]);
} }
if (_modified.Any()) if (_modified.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _modified.Count; list = _modified.ToImmutableList();
sql = _generator.GenerateUpdateSql("TtsChatVoice", _modified.Select(m => _store[m]), ["userId", "chatterId"], ["ttsVoiceId"]);
_modified.Clear(); _modified.Clear();
} }
_logger.Debug($"TtsChatVoice - Modifying {count} rows in database: {sql}"); count = list.Count;
await _database.ExecuteScalarTransaction(sql); sql = _generator.GeneratePreparedUpdateSql("TtsChatVoice", count, ["userId", "chatterId"], ["ttsVoiceId"]);
_logger.Debug($"User - Modifying {count} rows in database: {sql}");
var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "name", "email", "role", "ttsDefaultVoice"]);
} }
if (_deleted.Any()) if (_deleted.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _deleted.Count; list = _deleted.ToImmutableList();
sql = _generator.GenerateDeleteSql("TtsChatVoice", _deleted, ["userId", "chatterId"]);
_deleted.Clear(); _deleted.Clear();
} }
_logger.Debug($"TtsChatVoice - Deleting {count} rows from database: {sql}"); count = list.Count;
await _database.ExecuteScalarTransaction(sql); sql = _generator.GeneratePreparedDeleteSql("TtsChatVoice", count, ["userId", "chatterId"]);
_logger.Debug($"User - Deleting {count} rows from database: {sql}");
await _generator.DoPreparedStatement(_database, sql, list, ["id"]);
} }
return true;
} }
} }
} }

View File

@ -1,15 +1,18 @@
using System.Reflection; using System.Reflection;
using System.Text; using System.Text;
using HermesSocketLibrary.db;
namespace HermesSocketServer.Store namespace HermesSocketServer.Store
{ {
public class GroupSaveSqlGenerator<T> public class GroupSaveSqlGenerator<T>
{ {
private readonly IDictionary<string, PropertyInfo?> columnPropertyRelations; private readonly IDictionary<string, PropertyInfo?> columnPropertyRelations;
private readonly Serilog.ILogger _logger;
public GroupSaveSqlGenerator(IDictionary<string, string> columnsToProperties) public GroupSaveSqlGenerator(IDictionary<string, string> columnsToProperties, Serilog.ILogger logger)
{ {
columnPropertyRelations = columnsToProperties.ToDictionary(p => p.Key, p => typeof(T).GetProperty(p.Value)); columnPropertyRelations = columnsToProperties.ToDictionary(p => p.Key, p => typeof(T).GetProperty(p.Value));
_logger = logger;
var nullProperties = columnPropertyRelations.Where(p => p.Value == null) var nullProperties = columnPropertyRelations.Where(p => p.Value == null)
.Select(p => columnsToProperties[p.Key]); .Select(p => columnsToProperties[p.Key]);
@ -17,6 +20,24 @@ namespace HermesSocketServer.Store
throw new ArgumentException("Some properties do not exist on the values given: " + string.Join(", ", nullProperties)); throw new ArgumentException("Some properties do not exist on the values given: " + string.Join(", ", nullProperties));
} }
public async Task DoPreparedStatement<V>(Database database, string sql, IEnumerable<V> values, string[] columns)
{
await database.Execute(sql, (c) =>
{
var valueCounter = 0;
foreach (var value in values)
{
foreach (var column in columns)
{
var propValue = columnPropertyRelations[column]!.GetValue(value);
var propType = columnPropertyRelations[column]!.PropertyType;
c.Parameters.AddWithValue(column.ToLower() + valueCounter, propValue ?? DBNull.Value);
}
valueCounter++;
}
});
}
public string GenerateInsertSql(string table, IEnumerable<T> values, IEnumerable<string> columns) public string GenerateInsertSql(string table, IEnumerable<T> values, IEnumerable<string> columns)
{ {
if (string.IsNullOrWhiteSpace(table)) if (string.IsNullOrWhiteSpace(table))
@ -40,11 +61,42 @@ namespace HermesSocketServer.Store
{ {
var propValue = columnPropertyRelations[column]!.GetValue(value); var propValue = columnPropertyRelations[column]!.GetValue(value);
var propType = columnPropertyRelations[column]!.PropertyType; var propType = columnPropertyRelations[column]!.PropertyType;
WriteValue(sb, propValue, propType); WriteValue(sb, propValue ?? DBNull.Value, propType);
sb.Append(","); sb.Append(",");
} }
sb.Remove(sb.Length - 1, 1) sb.Remove(sb.Length - 1, 1)
.Append("),"); .Append("),");
}
sb.Remove(sb.Length - 1, 1)
.Append(';');
return sb.ToString();
}
public string GeneratePreparedInsertSql(string table, int rows, IEnumerable<string> columns)
{
if (string.IsNullOrWhiteSpace(table))
throw new ArgumentException("Value is either null or whitespace-filled.", nameof(table));
if (columns == null)
throw new ArgumentNullException(nameof(columns));
if (!columns.Any())
throw new ArgumentException("Empty list given.", nameof(columns));
var ctp = columns.ToDictionary(c => c, c => columnPropertyRelations[c]);
var sb = new StringBuilder();
var columnsLower = columns.Select(c => c.ToLower());
sb.Append($"INSERT INTO \"{table}\" (\"{string.Join("\", \"", columns)}\") VALUES ");
for (var row = 0; row < rows; row++)
{
sb.Append("(");
foreach (var column in columnsLower)
{
sb.Append('@')
.Append(column)
.Append(row)
.Append(", ");
}
sb.Remove(sb.Length - 2, 2)
.Append("),");
} }
sb.Remove(sb.Length - 1, 1) sb.Remove(sb.Length - 1, 1)
.Append(';'); .Append(';');
@ -93,6 +145,44 @@ namespace HermesSocketServer.Store
return sb.ToString(); return sb.ToString();
} }
public string GeneratePreparedUpdateSql(string table, int rows, IEnumerable<string> keyColumns, IEnumerable<string> updateColumns)
{
if (string.IsNullOrWhiteSpace(table))
throw new ArgumentException("Value is either null or whitespace-filled.", nameof(table));
if (keyColumns == null)
throw new ArgumentNullException(nameof(keyColumns));
if (!keyColumns.Any())
throw new ArgumentException("Empty list given.", nameof(keyColumns));
if (updateColumns == null)
throw new ArgumentNullException(nameof(updateColumns));
if (!updateColumns.Any())
throw new ArgumentException("Empty list given.", nameof(updateColumns));
var columns = keyColumns.Union(updateColumns);
var ctp = columns.ToDictionary(c => c, c => columnPropertyRelations[c]);
var sb = new StringBuilder();
sb.Append($"UPDATE \"{table}\" as t SET {string.Join(", ", updateColumns.Select(c => "\"" + c + "\" = c.\"" + c + "\""))} FROM (VALUES ");
for (var row = 0; row < rows; row++)
{
sb.Append("(");
foreach (var column in columns)
{
sb.Append('@')
.Append(column)
.Append(row)
.Append(", ");
}
sb.Remove(sb.Length - 2, 2)
.Append("),");
}
sb.Remove(sb.Length - 1, 1)
.Append($") AS c(\"{string.Join("\", \"", columns)}\") WHERE ")
.Append(string.Join(" AND ", keyColumns.Select(c => "t.\"" + c + "\" = c.\"" + c + "\"")))
.Append(";");
return sb.ToString();
}
public string GenerateDeleteSql(string table, IEnumerable<string> keys, IEnumerable<string> keyColumns) public string GenerateDeleteSql(string table, IEnumerable<string> keys, IEnumerable<string> keyColumns)
{ {
if (string.IsNullOrWhiteSpace(table)) if (string.IsNullOrWhiteSpace(table))
@ -127,6 +217,37 @@ namespace HermesSocketServer.Store
return sb.ToString(); return sb.ToString();
} }
public string GeneratePreparedDeleteSql(string table, int rows, IEnumerable<string> keyColumns)
{
if (string.IsNullOrWhiteSpace(table))
throw new ArgumentException("Value is either null or whitespace-filled.", nameof(table));
if (keyColumns == null)
throw new ArgumentNullException(nameof(keyColumns));
if (!keyColumns.Any())
throw new ArgumentException("Empty list given.", nameof(keyColumns));
var ctp = keyColumns.ToDictionary(c => c, c => columnPropertyRelations[c]);
var sb = new StringBuilder();
sb.Append($"DELETE FROM \"{table}\" WHERE (\"{string.Join("\", \"", keyColumns)}\") IN (");
for (var row = 0; row < rows; row++)
{
sb.Append("(");
foreach (var column in keyColumns)
{
sb.Append('@')
.Append(column)
.Append(row)
.Append(", ");
}
sb.Remove(sb.Length - 2, 2)
.Append("),");
}
sb.Remove(sb.Length - 1, 1)
.Append(");");
return sb.ToString();
}
private void WriteValue(StringBuilder sb, object? value, Type type) private void WriteValue(StringBuilder sb, object? value, Type type)
{ {
if (type == typeof(string)) if (type == typeof(string))
@ -134,9 +255,9 @@ namespace HermesSocketServer.Store
.Append(value) .Append(value)
.Append("'"); .Append("'");
else if (type == typeof(Guid)) else if (type == typeof(Guid))
sb.Append("'") sb.Append("uuid('")
.Append(value?.ToString()) .Append(value?.ToString())
.Append("'"); .Append("')");
else if (type == typeof(TimeSpan)) else if (type == typeof(TimeSpan))
sb.Append(((TimeSpan)value).TotalMilliseconds); sb.Append(((TimeSpan)value).TotalMilliseconds);
else else

View File

@ -28,7 +28,7 @@ namespace HermesSocketServer.Store
protected abstract void OnInitialAdd(K key, V value); protected abstract void OnInitialAdd(K key, V value);
protected abstract void OnInitialModify(K key, V value); protected abstract void OnInitialModify(K key, V value);
protected abstract void OnInitialRemove(K key); protected abstract void OnInitialRemove(K key);
public abstract Task<bool> Save(); public abstract Task Save();
public V? Get(K key) public V? Get(K key)
{ {

View File

@ -7,7 +7,7 @@ namespace HermesSocketServer.Store
Task Load(); Task Load();
bool Modify(K? key, Action<V> action); bool Modify(K? key, Action<V> action);
void Remove(K? key); void Remove(K? key);
Task<bool> Save(); Task Save();
bool Set(K? key, V? value); bool Set(K? key, V? value);
} }
} }

View File

@ -1,3 +1,4 @@
using System.Collections.Immutable;
using HermesSocketLibrary.db; using HermesSocketLibrary.db;
using HermesSocketServer.Models; using HermesSocketServer.Models;
@ -26,7 +27,7 @@ namespace HermesSocketServer.Store
{ "count", "Usage" }, { "count", "Usage" },
{ "timespan", "Span" }, { "timespan", "Span" },
}; };
_generator = new GroupSaveSqlGenerator<PolicyMessage>(ctp); _generator = new GroupSaveSqlGenerator<PolicyMessage>(ctp, _logger);
} }
public override async Task Load() public override async Task Load()
@ -61,46 +62,53 @@ namespace HermesSocketServer.Store
{ {
} }
public override async Task<bool> Save() public override async Task Save()
{ {
int count = 0; int count = 0;
string sql = string.Empty; string sql = string.Empty;
ImmutableList<string>? list = null;
if (_added.Any()) if (_added.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _added.Count; list = _added.ToImmutableList();
sql = _generator.GenerateInsertSql("GroupPermissionPolicy", _added.Select(a => _store[a]), ["id", "userId", "groupId", "path", "count", "timespan"]);
_added.Clear(); _added.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedInsertSql("GroupPermissionPolicy", count, ["id", "userId", "groupId", "path", "count", "timespan"]);
_logger.Debug($"GroupPermissionPolicy - Adding {count} rows to database: {sql}"); _logger.Debug($"GroupPermissionPolicy - Adding {count} rows to database: {sql}");
await _database.ExecuteScalarTransaction(sql); var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "userId", "groupId", "path", "count", "timespan"]);
} }
if (_modified.Any()) if (_modified.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _modified.Count; list = _modified.ToImmutableList();
sql = _generator.GenerateUpdateSql("GroupPermissionPolicy", _modified.Select(m => _store[m]), ["id"], ["userId", "groupId", "path", "count", "timespan"]);
_modified.Clear(); _modified.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedUpdateSql("GroupPermissionPolicy", count, ["id"], ["userId", "groupId", "path", "count", "timespan"]);
_logger.Debug($"GroupPermissionPolicy - Modifying {count} rows in database: {sql}"); _logger.Debug($"GroupPermissionPolicy - Modifying {count} rows in database: {sql}");
await _database.ExecuteScalarTransaction(sql); var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "userId", "groupId", "path", "count", "timespan"]);
} }
if (_deleted.Any()) if (_deleted.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _deleted.Count; list = _deleted.ToImmutableList();
sql = _generator.GenerateDeleteSql("GroupPermissionPolicy", _deleted, ["id"]);
_deleted.Clear(); _deleted.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedDeleteSql("GroupPermissionPolicy", count, ["id"]);
_logger.Debug($"GroupPermissionPolicy - Deleting {count} rows from database: {sql}"); _logger.Debug($"GroupPermissionPolicy - Deleting {count} rows from database: {sql}");
await _database.ExecuteScalarTransaction(sql); await _generator.DoPreparedStatement(_database, sql, list, ["id"]);
} }
return true;
} }
} }
} }

View File

@ -1,3 +1,4 @@
using System.Collections.Immutable;
using HermesSocketLibrary.db; using HermesSocketLibrary.db;
using HermesSocketServer.Models; using HermesSocketServer.Models;
@ -23,7 +24,7 @@ namespace HermesSocketServer.Store
{ "role", "Role" }, { "role", "Role" },
{ "ttsDefaultVoice", "DefaultVoice" } { "ttsDefaultVoice", "DefaultVoice" }
}; };
_generator = new GroupSaveSqlGenerator<User>(ctp); _generator = new GroupSaveSqlGenerator<User>(ctp, _logger);
} }
public override async Task Load() public override async Task Load()
@ -56,45 +57,53 @@ namespace HermesSocketServer.Store
{ {
} }
public override async Task<bool> Save() public override async Task Save()
{ {
int count = 0; int count = 0;
string sql = string.Empty; string sql = string.Empty;
ImmutableList<string>? list = null;
if (_added.Any()) if (_added.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _added.Count; list = _added.ToImmutableList();
sql = _generator.GenerateInsertSql("User", _added.Select(a => _store[a]), ["id", "name", "email", "role", "ttsDefaultVoice"]);
_added.Clear(); _added.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedInsertSql("User", count, ["id", "name", "email", "role", "ttsDefaultVoice"]);
_logger.Debug($"User - Adding {count} rows to database: {sql}"); _logger.Debug($"User - Adding {count} rows to database: {sql}");
await _database.ExecuteScalarTransaction(sql); var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "name", "email", "role", "ttsDefaultVoice"]);
} }
if (_modified.Any()) if (_modified.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _modified.Count; list = _modified.ToImmutableList();
sql = _generator.GenerateUpdateSql("User", _modified.Select(m => _store[m]), ["id"], ["name", "email", "role", "ttsDefaultVoice"]);
_modified.Clear(); _modified.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedUpdateSql("User", count, ["id"], ["name", "email", "role", "ttsDefaultVoice"]);
_logger.Debug($"User - Modifying {count} rows in database: {sql}"); _logger.Debug($"User - Modifying {count} rows in database: {sql}");
await _database.ExecuteScalarTransaction(sql); var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "name", "email", "role", "ttsDefaultVoice"]);
} }
if (_deleted.Any()) if (_deleted.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _deleted.Count; list = _deleted.ToImmutableList();
sql = _generator.GenerateDeleteSql("User", _deleted, ["id"]);
_deleted.Clear(); _deleted.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedDeleteSql("User", count, ["id"]);
_logger.Debug($"User - Deleting {count} rows from database: {sql}"); _logger.Debug($"User - Deleting {count} rows from database: {sql}");
await _database.ExecuteScalarTransaction(sql); await _generator.DoPreparedStatement(_database, sql, list, ["id"]);
} }
return true;
} }
} }
} }

View File

@ -1,3 +1,4 @@
using System.Collections.Immutable;
using HermesSocketLibrary.db; using HermesSocketLibrary.db;
using HermesSocketServer.Models; using HermesSocketServer.Models;
using HermesSocketServer.Validators; using HermesSocketServer.Validators;
@ -25,7 +26,7 @@ namespace HermesSocketServer.Store
{ "id", "Id" }, { "id", "Id" },
{ "name", "Name" } { "name", "Name" }
}; };
_generator = new GroupSaveSqlGenerator<Voice>(ctp); _generator = new GroupSaveSqlGenerator<Voice>(ctp, _logger);
} }
public override async Task Load() public override async Task Load()
@ -58,46 +59,53 @@ namespace HermesSocketServer.Store
{ {
} }
public override async Task<bool> Save() public override async Task Save()
{ {
int count = 0; int count = 0;
string sql = string.Empty; string sql = string.Empty;
ImmutableList<string>? list = null;
if (_added.Any()) if (_added.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _added.Count; list = _added.ToImmutableList();
sql = _generator.GenerateInsertSql("TtsVoice", _added.Select(a => _store[a]), ["id", "name"]);
_added.Clear(); _added.Clear();
} }
count = list.Count;
sql = _generator.GeneratePreparedInsertSql("TtsVoice", count, ["id", "name"]);
_logger.Debug($"TtsVoice - Adding {count} rows to database: {sql}"); _logger.Debug($"User - Adding {count} rows to database: {sql}");
await _database.ExecuteScalarTransaction(sql); var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "name", "email", "role", "ttsDefaultVoice"]);
} }
if (_modified.Any()) if (_modified.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _modified.Count; list = _modified.ToImmutableList();
sql = _generator.GenerateUpdateSql("TtsVoice", _modified.Select(m => _store[m]), ["id"], ["name"]);
_modified.Clear(); _modified.Clear();
} }
_logger.Debug($"TtsVoice - Modifying {count} rows in database: {sql}"); count = list.Count;
await _database.ExecuteScalarTransaction(sql); sql = _generator.GeneratePreparedUpdateSql("TtsVoice", count, ["id"], ["name"]);
_logger.Debug($"User - Modifying {count} rows in database: {sql}");
var values = list.Select(id => _store[id]).Where(v => v != null);
await _generator.DoPreparedStatement(_database, sql, values, ["id", "name", "email", "role", "ttsDefaultVoice"]);
} }
if (_deleted.Any()) if (_deleted.Any())
{ {
lock (_lock) lock (_lock)
{ {
count = _deleted.Count; list = _deleted.ToImmutableList();
sql = _generator.GenerateDeleteSql("TtsVoice", _deleted, ["id"]);
_deleted.Clear(); _deleted.Clear();
} }
_logger.Debug($"TtsVoice - Deleting {count} rows from database: {sql}"); count = list.Count;
await _database.ExecuteScalarTransaction(sql); sql = _generator.GeneratePreparedDeleteSql("TtsVoice", count, ["id"]);
_logger.Debug($"User - Deleting {count} rows from database: {sql}");
await _generator.DoPreparedStatement(_database, sql, list, ["id"]);
} }
return true;
} }
} }
} }

View File

@ -69,6 +69,18 @@ namespace HermesSocketLibrary.db
return await command.ExecuteNonQueryAsync(); return await command.ExecuteNonQueryAsync();
} }
public async Task<int> ExecuteTransaction(string sql, Action<NpgsqlCommand> prepare)
{
await using var connection = await _source.OpenConnectionAsync();
await using var transaction = await connection.BeginTransactionAsync();
await using var command = new NpgsqlCommand(sql, connection, transaction);
prepare(command);
await command.PrepareAsync();
var results = await command.ExecuteNonQueryAsync();
await transaction.CommitAsync();
return results;
}
public async Task<object?> ExecuteScalar(string sql, IDictionary<string, object>? values = null) public async Task<object?> ExecuteScalar(string sql, IDictionary<string, object>? values = null)
{ {
await using var connection = await _source.OpenConnectionAsync(); await using var connection = await _source.OpenConnectionAsync();