From ba879fd872945659816abefb2539e30affd12a37 Mon Sep 17 00:00:00 2001 From: Hinton Date: Fri, 19 Dec 2025 13:48:57 +0100 Subject: [PATCH] Resolve review feedback --- .../Services/Implementations/PlayIdService.cs | 8 ++ .../DapperServiceCollectionExtensions.cs | 27 ++--- ...ityFrameworkServiceCollectionExtensions.cs | 26 ++--- src/SharedWeb/Utilities/PlayIdMiddleware.cs | 24 ++++- .../Utilities/ServiceCollectionExtensions.cs | 14 ++- .../PlayData_ReadByPlayId.sql | 6 +- test/SharedWeb.Test/PlayIdMiddlewareTests.cs | 102 ++++++++++++++++++ .../DbScripts/2025-11-04_00_PlayData.sql | 6 +- util/Seeder/MangleId.cs | 4 +- util/Seeder/SceneResult.cs | 9 ++ util/SeederApi/Services/QueryService.cs | 82 +++++++------- util/SeederApi/Startup.cs | 6 ++ 12 files changed, 243 insertions(+), 71 deletions(-) create mode 100644 test/SharedWeb.Test/PlayIdMiddlewareTests.cs diff --git a/src/Core/Services/Implementations/PlayIdService.cs b/src/Core/Services/Implementations/PlayIdService.cs index 0c2a25291b..7a5a046142 100644 --- a/src/Core/Services/Implementations/PlayIdService.cs +++ b/src/Core/Services/Implementations/PlayIdService.cs @@ -29,6 +29,14 @@ public class NeverPlayIdServices : IPlayIdService } } +/// +/// Singleton wrapper service that bridges singleton-scoped service boundaries for PlayId tracking. +/// This allows singleton services to access the scoped PlayIdService via HttpContext.RequestServices. +/// +/// Uses IHttpContextAccessor to retrieve the current request's scoped PlayIdService instance, enabling +/// singleton services to participate in Play session tracking without violating DI lifetime rules. +/// Falls back to NeverPlayIdServices when no HttpContext is available (e.g., background jobs). +/// public class PlayIdSingletonService(IHttpContextAccessor httpContextAccessor, IHostEnvironment hostEnvironment) : IPlayIdService { private IPlayIdService Current diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index c2e5d011c7..e22e203ac5 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -9,7 +9,6 @@ using Bit.Core.NotificationCenter.Repositories; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; -using Bit.Core.Settings; using Bit.Core.Tools.Repositories; using Bit.Core.Vault.Repositories; using Bit.Infrastructure.Dapper.AdminConsole.Repositories; @@ -29,19 +28,8 @@ namespace Bit.Infrastructure.Dapper; public static class DapperServiceCollectionExtensions { - public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings) + public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted) { - if (globalSettings.TestPlayIdTrackingEnabled) - { - services.AddSingleton(); - services.AddSingleton(); - } - else - { - services.AddSingleton(); - services.AddSingleton(); - } - services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -59,6 +47,7 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -70,6 +59,7 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -92,4 +82,15 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); } } + + /// + /// Adds PlayId tracking decorators for User and Organization repositories. + /// This replaces the standard repository implementations with tracking versions + /// that record created entities for test data cleanup. Only call when TestPlayIdTrackingEnabled is true. + /// + public static void AddPlayIdTrackingRepositories(this IServiceCollection services) + { + services.AddSingleton(); + services.AddSingleton(); + } } diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index 74f7499417..14cfbab79c 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -68,19 +68,8 @@ public static class EntityFrameworkServiceCollectionExtensions }); } - public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings) + public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted) { - if (globalSettings.TestPlayIdTrackingEnabled) - { - services.AddSingleton(); - services.AddSingleton(); - } - else - { - services.AddSingleton(); - services.AddSingleton(); - } - services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -97,6 +86,7 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -108,6 +98,7 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -130,4 +121,15 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); } } + + /// + /// Adds PlayId tracking decorators for User and Organization repositories. + /// This replaces the standard repository implementations with tracking versions + /// that record created entities for test data cleanup. Only call when TestPlayIdTrackingEnabled is true. + /// + public static void AddPlayIdTrackingEFRepositories(this IServiceCollection services) + { + services.AddSingleton(); + services.AddSingleton(); + } } diff --git a/src/SharedWeb/Utilities/PlayIdMiddleware.cs b/src/SharedWeb/Utilities/PlayIdMiddleware.cs index 3f692e6ae9..c00ab2b657 100644 --- a/src/SharedWeb/Utilities/PlayIdMiddleware.cs +++ b/src/SharedWeb/Utilities/PlayIdMiddleware.cs @@ -11,13 +11,31 @@ namespace Bit.SharedWeb.Utilities; /// public sealed class PlayIdMiddleware(RequestDelegate next) { - public Task Invoke(HttpContext context, PlayIdService playIdService) + private const int MaxPlayIdLength = 256; + + public async Task Invoke(HttpContext context, PlayIdService playIdService) { if (context.Request.Headers.TryGetValue("x-play-id", out var playId)) { - playIdService.PlayId = playId; + var playIdValue = playId.ToString(); + + if (string.IsNullOrWhiteSpace(playIdValue)) + { + context.Response.StatusCode = StatusCodes.Status400BadRequest; + await context.Response.WriteAsJsonAsync(new { Error = "x-play-id header cannot be empty or whitespace" }); + return; + } + + if (playIdValue.Length > MaxPlayIdLength) + { + context.Response.StatusCode = StatusCodes.Status400BadRequest; + await context.Response.WriteAsJsonAsync(new { Error = $"x-play-id header cannot exceed {MaxPlayIdLength} characters" }); + return; + } + + playIdService.PlayId = playIdValue; } - return next(context); + await next(context); } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index e40a849b82..d354c3a408 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -96,11 +96,11 @@ public static class ServiceCollectionExtensions if (provider != SupportedDatabaseProviders.SqlServer) { - services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted, globalSettings); + services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted); } else { - services.AddDapperRepositories(globalSettings.SelfHosted, globalSettings); + services.AddDapperRepositories(globalSettings.SelfHosted); } if (globalSettings.SelfHosted) @@ -123,6 +123,16 @@ public static class ServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddScoped(); + + // Replace standard repositories with PlayId tracking decorators + if (provider == SupportedDatabaseProviders.SqlServer) + { + services.AddPlayIdTrackingRepositories(); + } + else + { + services.AddPlayIdTrackingEFRepositories(); + } } else { diff --git a/src/Sql/dbo/Stored Procedures/PlayData_ReadByPlayId.sql b/src/Sql/dbo/Stored Procedures/PlayData_ReadByPlayId.sql index af1e15701d..77d3f7df4f 100644 --- a/src/Sql/dbo/Stored Procedures/PlayData_ReadByPlayId.sql +++ b/src/Sql/dbo/Stored Procedures/PlayData_ReadByPlayId.sql @@ -5,7 +5,11 @@ BEGIN SET NOCOUNT ON SELECT - * + [Id], + [PlayId], + [UserId], + [OrganizationId], + [CreationDate] FROM [dbo].[PlayData] WHERE diff --git a/test/SharedWeb.Test/PlayIdMiddlewareTests.cs b/test/SharedWeb.Test/PlayIdMiddlewareTests.cs new file mode 100644 index 0000000000..c2f6e0522d --- /dev/null +++ b/test/SharedWeb.Test/PlayIdMiddlewareTests.cs @@ -0,0 +1,102 @@ +using Bit.Core.Services; +using Bit.SharedWeb.Utilities; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Hosting; +using NSubstitute; + +namespace SharedWeb.Test; + +public class PlayIdMiddlewareTests +{ + private readonly PlayIdService _playIdService; + private readonly RequestDelegate _next; + private readonly PlayIdMiddleware _middleware; + + public PlayIdMiddlewareTests() + { + var hostEnvironment = Substitute.For(); + hostEnvironment.EnvironmentName.Returns(Environments.Development); + + _playIdService = new PlayIdService(hostEnvironment); + _next = Substitute.For(); + _middleware = new PlayIdMiddleware(_next); + } + + [Fact] + public async Task Invoke_WithValidPlayId_SetsPlayIdAndCallsNext() + { + var context = new DefaultHttpContext(); + context.Request.Headers["x-play-id"] = "test-play-id"; + + await _middleware.Invoke(context, _playIdService); + + Assert.Equal("test-play-id", _playIdService.PlayId); + await _next.Received(1).Invoke(context); + } + + [Fact] + public async Task Invoke_WithoutPlayIdHeader_CallsNext() + { + var context = new DefaultHttpContext(); + + await _middleware.Invoke(context, _playIdService); + + Assert.Null(_playIdService.PlayId); + await _next.Received(1).Invoke(context); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + [InlineData("\t")] + public async Task Invoke_WithEmptyOrWhitespacePlayId_Returns400(string playId) + { + var context = new DefaultHttpContext(); + context.Response.Body = new MemoryStream(); + context.Request.Headers["x-play-id"] = playId; + + await _middleware.Invoke(context, _playIdService); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + await _next.DidNotReceive().Invoke(context); + } + + [Fact] + public async Task Invoke_WithPlayIdExceedingMaxLength_Returns400() + { + var context = new DefaultHttpContext(); + context.Response.Body = new MemoryStream(); + var longPlayId = new string('a', 257); // Exceeds 256 character limit + context.Request.Headers["x-play-id"] = longPlayId; + + await _middleware.Invoke(context, _playIdService); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + await _next.DidNotReceive().Invoke(context); + } + + [Fact] + public async Task Invoke_WithPlayIdAtMaxLength_SetsPlayIdAndCallsNext() + { + var context = new DefaultHttpContext(); + var maxLengthPlayId = new string('a', 256); // Exactly 256 characters + context.Request.Headers["x-play-id"] = maxLengthPlayId; + + await _middleware.Invoke(context, _playIdService); + + Assert.Equal(maxLengthPlayId, _playIdService.PlayId); + await _next.Received(1).Invoke(context); + } + + [Fact] + public async Task Invoke_WithSpecialCharactersInPlayId_SetsPlayIdAndCallsNext() + { + var context = new DefaultHttpContext(); + context.Request.Headers["x-play-id"] = "test-play_id.123"; + + await _middleware.Invoke(context, _playIdService); + + Assert.Equal("test-play_id.123", _playIdService.PlayId); + await _next.Received(1).Invoke(context); + } +} diff --git a/util/Migrator/DbScripts/2025-11-04_00_PlayData.sql b/util/Migrator/DbScripts/2025-11-04_00_PlayData.sql index 4c9468ed29..492e8a5143 100644 --- a/util/Migrator/DbScripts/2025-11-04_00_PlayData.sql +++ b/util/Migrator/DbScripts/2025-11-04_00_PlayData.sql @@ -62,7 +62,11 @@ BEGIN SET NOCOUNT ON SELECT - * + [Id], + [PlayId], + [UserId], + [OrganizationId], + [CreationDate] FROM [dbo].[PlayData] WHERE diff --git a/util/Seeder/MangleId.cs b/util/Seeder/MangleId.cs index 0f7c70b2bd..e1be47f4d2 100644 --- a/util/Seeder/MangleId.cs +++ b/util/Seeder/MangleId.cs @@ -1,7 +1,9 @@ namespace Bit.Seeder; /// -/// Helper for mangling IDs +/// Helper for generating unique identifier suffixes to prevent collisions in test data. +/// "Mangling" adds a random suffix to test data identifiers (usernames, emails, org names, etc.) +/// to ensure uniqueness across multiple test runs and parallel test executions. /// public class MangleId { diff --git a/util/Seeder/SceneResult.cs b/util/Seeder/SceneResult.cs index 5c543c9004..7ac004f55e 100644 --- a/util/Seeder/SceneResult.cs +++ b/util/Seeder/SceneResult.cs @@ -1,8 +1,17 @@ namespace Bit.Seeder; +/// +/// Helper for exposing a interface with a SeedAsync method. +/// public class SceneResult(Dictionary mangleMap) : SceneResult(result: null, mangleMap: mangleMap); +/// +/// Generic result from executing a Scene. +/// Contains custom scene-specific data and a mangle map that maps magic strings from the +/// request to their mangled (collision-free) values inserted into the database. +/// +/// The type of custom result data returned by the scene. public class SceneResult(TResult result, Dictionary mangleMap) { public TResult Result { get; init; } = result; diff --git a/util/SeederApi/Services/QueryService.cs b/util/SeederApi/Services/QueryService.cs index 11e4564349..7314f1066a 100644 --- a/util/SeederApi/Services/QueryService.cs +++ b/util/SeederApi/Services/QueryService.cs @@ -23,44 +23,7 @@ public class QueryService( var requestType = query.GetRequestType(); - // Deserialize the arguments into the request model - object? requestModel; - if (arguments == null) - { - // Try to create an instance with default values - try - { - requestModel = Activator.CreateInstance(requestType); - if (requestModel == null) - { - throw new QueryExecutionException( - $"Arguments are required for query '{queryName}'"); - } - } - catch - { - throw new QueryExecutionException( - $"Arguments are required for query '{queryName}'"); - } - } - else - { - try - { - requestModel = JsonSerializer.Deserialize(arguments.Value.GetRawText(), requestType, _jsonOptions); - if (requestModel == null) - { - throw new QueryExecutionException( - $"Failed to deserialize request model for query '{queryName}'"); - } - } - catch (JsonException ex) - { - throw new QueryExecutionException( - $"Failed to deserialize request model for query '{queryName}': {ex.Message}", ex); - } - } - + var requestModel = DeserializeRequestModel(queryName, requestType, arguments); var result = query.Execute(requestModel); logger.LogInformation("Successfully executed query: {QueryName}", queryName); @@ -74,4 +37,47 @@ public class QueryService( ex.InnerException ?? ex); } } + + private object DeserializeRequestModel(string queryName, Type requestType, JsonElement? arguments) + { + if (arguments == null) + { + return CreateDefaultRequestModel(queryName, requestType); + } + + try + { + var requestModel = JsonSerializer.Deserialize(arguments.Value.GetRawText(), requestType, _jsonOptions); + if (requestModel == null) + { + throw new QueryExecutionException( + $"Failed to deserialize request model for query '{queryName}'"); + } + return requestModel; + } + catch (JsonException ex) + { + throw new QueryExecutionException( + $"Failed to deserialize request model for query '{queryName}': {ex.Message}", ex); + } + } + + private object CreateDefaultRequestModel(string queryName, Type requestType) + { + try + { + var requestModel = Activator.CreateInstance(requestType); + if (requestModel == null) + { + throw new QueryExecutionException( + $"Arguments are required for query '{queryName}'"); + } + return requestModel; + } + catch + { + throw new QueryExecutionException( + $"Arguments are required for query '{queryName}'"); + } + } } diff --git a/util/SeederApi/Startup.cs b/util/SeederApi/Startup.cs index d19c7193c6..f27b3a006a 100644 --- a/util/SeederApi/Startup.cs +++ b/util/SeederApi/Startup.cs @@ -54,6 +54,12 @@ public class Startup IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { + if (env.IsProduction()) + { + throw new InvalidOperationException( + "SeederApi cannot be run in production environments. This service is intended for test data generation only."); + } + if (globalSettings.TestPlayIdTrackingEnabled) { app.UseMiddleware();