1
0
mirror of https://github.com/bitwarden/server synced 2025-12-28 06:03:29 +00:00

Resolve review feedback

This commit is contained in:
Hinton
2025-12-19 13:48:57 +01:00
parent afce475833
commit ba879fd872
12 changed files with 243 additions and 71 deletions

View File

@@ -29,6 +29,14 @@ public class NeverPlayIdServices : IPlayIdService
}
}
/// <summary>
/// Singleton wrapper service that bridges singleton-scoped service boundaries for PlayId tracking.
/// This allows singleton services to access the scoped PlayIdService via HttpContext.RequestServices.
///
/// Uses IHttpContextAccessor to retrieve the current request's scoped PlayIdService instance, enabling
/// singleton services to participate in Play session tracking without violating DI lifetime rules.
/// Falls back to NeverPlayIdServices when no HttpContext is available (e.g., background jobs).
/// </summary>
public class PlayIdSingletonService(IHttpContextAccessor httpContextAccessor, IHostEnvironment hostEnvironment) : IPlayIdService
{
private IPlayIdService Current

View File

@@ -9,7 +9,6 @@ using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Installations;
using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Settings;
using Bit.Core.Tools.Repositories;
using Bit.Core.Vault.Repositories;
using Bit.Infrastructure.Dapper.AdminConsole.Repositories;
@@ -29,19 +28,8 @@ namespace Bit.Infrastructure.Dapper;
public static class DapperServiceCollectionExtensions
{
public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings)
public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted)
{
if (globalSettings.TestPlayIdTrackingEnabled)
{
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
}
else
{
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
}
services.AddSingleton<IApiKeyRepository, ApiKeyRepository>();
services.AddSingleton<IAuthRequestRepository, AuthRequestRepository>();
services.AddSingleton<ICipherRepository, CipherRepository>();
@@ -59,6 +47,7 @@ public static class DapperServiceCollectionExtensions
services.AddSingleton<IOrganizationConnectionRepository, OrganizationConnectionRepository>();
services.AddSingleton<IOrganizationIntegrationConfigurationRepository, OrganizationIntegrationConfigurationRepository>();
services.AddSingleton<IOrganizationIntegrationRepository, OrganizationIntegrationRepository>();
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IOrganizationSponsorshipRepository, OrganizationSponsorshipRepository>();
services.AddSingleton<IOrganizationUserRepository, OrganizationUserRepository>();
services.AddSingleton<IPlayDataRepository, PlayDataRepository>();
@@ -70,6 +59,7 @@ public static class DapperServiceCollectionExtensions
services.AddSingleton<ISsoConfigRepository, SsoConfigRepository>();
services.AddSingleton<ISsoUserRepository, SsoUserRepository>();
services.AddSingleton<ITransactionRepository, TransactionRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
services.AddSingleton<IOrganizationDomainRepository, OrganizationDomainRepository>();
services.AddSingleton<IWebAuthnCredentialRepository, WebAuthnCredentialRepository>();
services.AddSingleton<IProviderPlanRepository, ProviderPlanRepository>();
@@ -92,4 +82,15 @@ public static class DapperServiceCollectionExtensions
services.AddSingleton<IEventRepository, EventRepository>();
}
}
/// <summary>
/// Adds PlayId tracking decorators for User and Organization repositories.
/// This replaces the standard repository implementations with tracking versions
/// that record created entities for test data cleanup. Only call when TestPlayIdTrackingEnabled is true.
/// </summary>
public static void AddPlayIdTrackingRepositories(this IServiceCollection services)
{
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
}
}

View File

@@ -68,19 +68,8 @@ public static class EntityFrameworkServiceCollectionExtensions
});
}
public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings)
public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted)
{
if (globalSettings.TestPlayIdTrackingEnabled)
{
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
}
else
{
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
}
services.AddSingleton<IApiKeyRepository, ApiKeyRepository>();
services.AddSingleton<IAuthRequestRepository, AuthRequestRepository>();
services.AddSingleton<ICipherRepository, CipherRepository>();
@@ -97,6 +86,7 @@ public static class EntityFrameworkServiceCollectionExtensions
services.AddSingleton<IOrganizationConnectionRepository, OrganizationConnectionRepository>();
services.AddSingleton<IOrganizationIntegrationRepository, OrganizationIntegrationRepository>();
services.AddSingleton<IOrganizationIntegrationConfigurationRepository, OrganizationIntegrationConfigurationRepository>();
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IOrganizationSponsorshipRepository, OrganizationSponsorshipRepository>();
services.AddSingleton<IOrganizationUserRepository, OrganizationUserRepository>();
services.AddSingleton<IPlayDataRepository, PlayDataRepository>();
@@ -108,6 +98,7 @@ public static class EntityFrameworkServiceCollectionExtensions
services.AddSingleton<ISsoConfigRepository, SsoConfigRepository>();
services.AddSingleton<ISsoUserRepository, SsoUserRepository>();
services.AddSingleton<ITransactionRepository, TransactionRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
services.AddSingleton<IOrganizationDomainRepository, OrganizationDomainRepository>();
services.AddSingleton<IWebAuthnCredentialRepository, WebAuthnCredentialRepository>();
services.AddSingleton<IProviderPlanRepository, ProviderPlanRepository>();
@@ -130,4 +121,15 @@ public static class EntityFrameworkServiceCollectionExtensions
services.AddSingleton<IEventRepository, EventRepository>();
}
}
/// <summary>
/// Adds PlayId tracking decorators for User and Organization repositories.
/// This replaces the standard repository implementations with tracking versions
/// that record created entities for test data cleanup. Only call when TestPlayIdTrackingEnabled is true.
/// </summary>
public static void AddPlayIdTrackingEFRepositories(this IServiceCollection services)
{
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
}
}

View File

@@ -11,13 +11,31 @@ namespace Bit.SharedWeb.Utilities;
/// <param name="next"></param>
public sealed class PlayIdMiddleware(RequestDelegate next)
{
public Task Invoke(HttpContext context, PlayIdService playIdService)
private const int MaxPlayIdLength = 256;
public async Task Invoke(HttpContext context, PlayIdService playIdService)
{
if (context.Request.Headers.TryGetValue("x-play-id", out var playId))
{
playIdService.PlayId = playId;
var playIdValue = playId.ToString();
if (string.IsNullOrWhiteSpace(playIdValue))
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.WriteAsJsonAsync(new { Error = "x-play-id header cannot be empty or whitespace" });
return;
}
if (playIdValue.Length > MaxPlayIdLength)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.WriteAsJsonAsync(new { Error = $"x-play-id header cannot exceed {MaxPlayIdLength} characters" });
return;
}
playIdService.PlayId = playIdValue;
}
return next(context);
await next(context);
}
}

View File

@@ -96,11 +96,11 @@ public static class ServiceCollectionExtensions
if (provider != SupportedDatabaseProviders.SqlServer)
{
services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted, globalSettings);
services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted);
}
else
{
services.AddDapperRepositories(globalSettings.SelfHosted, globalSettings);
services.AddDapperRepositories(globalSettings.SelfHosted);
}
if (globalSettings.SelfHosted)
@@ -123,6 +123,16 @@ public static class ServiceCollectionExtensions
services.AddSingleton<IPlayDataService, PlayDataService>();
services.AddSingleton<IPlayIdService, PlayIdSingletonService>();
services.AddScoped<PlayIdService>();
// Replace standard repositories with PlayId tracking decorators
if (provider == SupportedDatabaseProviders.SqlServer)
{
services.AddPlayIdTrackingRepositories();
}
else
{
services.AddPlayIdTrackingEFRepositories();
}
}
else
{

View File

@@ -5,7 +5,11 @@ BEGIN
SET NOCOUNT ON
SELECT
*
[Id],
[PlayId],
[UserId],
[OrganizationId],
[CreationDate]
FROM
[dbo].[PlayData]
WHERE

View File

@@ -0,0 +1,102 @@
using Bit.Core.Services;
using Bit.SharedWeb.Utilities;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Hosting;
using NSubstitute;
namespace SharedWeb.Test;
public class PlayIdMiddlewareTests
{
private readonly PlayIdService _playIdService;
private readonly RequestDelegate _next;
private readonly PlayIdMiddleware _middleware;
public PlayIdMiddlewareTests()
{
var hostEnvironment = Substitute.For<IHostEnvironment>();
hostEnvironment.EnvironmentName.Returns(Environments.Development);
_playIdService = new PlayIdService(hostEnvironment);
_next = Substitute.For<RequestDelegate>();
_middleware = new PlayIdMiddleware(_next);
}
[Fact]
public async Task Invoke_WithValidPlayId_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
context.Request.Headers["x-play-id"] = "test-play-id";
await _middleware.Invoke(context, _playIdService);
Assert.Equal("test-play-id", _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Fact]
public async Task Invoke_WithoutPlayIdHeader_CallsNext()
{
var context = new DefaultHttpContext();
await _middleware.Invoke(context, _playIdService);
Assert.Null(_playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Theory]
[InlineData("")]
[InlineData(" ")]
[InlineData("\t")]
public async Task Invoke_WithEmptyOrWhitespacePlayId_Returns400(string playId)
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
context.Request.Headers["x-play-id"] = playId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await _next.DidNotReceive().Invoke(context);
}
[Fact]
public async Task Invoke_WithPlayIdExceedingMaxLength_Returns400()
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
var longPlayId = new string('a', 257); // Exceeds 256 character limit
context.Request.Headers["x-play-id"] = longPlayId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await _next.DidNotReceive().Invoke(context);
}
[Fact]
public async Task Invoke_WithPlayIdAtMaxLength_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
var maxLengthPlayId = new string('a', 256); // Exactly 256 characters
context.Request.Headers["x-play-id"] = maxLengthPlayId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(maxLengthPlayId, _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Fact]
public async Task Invoke_WithSpecialCharactersInPlayId_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
context.Request.Headers["x-play-id"] = "test-play_id.123";
await _middleware.Invoke(context, _playIdService);
Assert.Equal("test-play_id.123", _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
}

View File

@@ -62,7 +62,11 @@ BEGIN
SET NOCOUNT ON
SELECT
*
[Id],
[PlayId],
[UserId],
[OrganizationId],
[CreationDate]
FROM
[dbo].[PlayData]
WHERE

View File

@@ -1,7 +1,9 @@
namespace Bit.Seeder;
/// <summary>
/// Helper for mangling IDs
/// Helper for generating unique identifier suffixes to prevent collisions in test data.
/// "Mangling" adds a random suffix to test data identifiers (usernames, emails, org names, etc.)
/// to ensure uniqueness across multiple test runs and parallel test executions.
/// </summary>
public class MangleId
{

View File

@@ -1,8 +1,17 @@
namespace Bit.Seeder;
/// <summary>
/// Helper for exposing a <see cref="IScene" /> interface with a SeedAsync method.
/// </summary>
public class SceneResult(Dictionary<string, string?> mangleMap)
: SceneResult<object?>(result: null, mangleMap: mangleMap);
/// <summary>
/// Generic result from executing a Scene.
/// Contains custom scene-specific data and a mangle map that maps magic strings from the
/// request to their mangled (collision-free) values inserted into the database.
/// </summary>
/// <typeparam name="TResult">The type of custom result data returned by the scene.</typeparam>
public class SceneResult<TResult>(TResult result, Dictionary<string, string?> mangleMap)
{
public TResult Result { get; init; } = result;

View File

@@ -23,44 +23,7 @@ public class QueryService(
var requestType = query.GetRequestType();
// Deserialize the arguments into the request model
object? requestModel;
if (arguments == null)
{
// Try to create an instance with default values
try
{
requestModel = Activator.CreateInstance(requestType);
if (requestModel == null)
{
throw new QueryExecutionException(
$"Arguments are required for query '{queryName}'");
}
}
catch
{
throw new QueryExecutionException(
$"Arguments are required for query '{queryName}'");
}
}
else
{
try
{
requestModel = JsonSerializer.Deserialize(arguments.Value.GetRawText(), requestType, _jsonOptions);
if (requestModel == null)
{
throw new QueryExecutionException(
$"Failed to deserialize request model for query '{queryName}'");
}
}
catch (JsonException ex)
{
throw new QueryExecutionException(
$"Failed to deserialize request model for query '{queryName}': {ex.Message}", ex);
}
}
var requestModel = DeserializeRequestModel(queryName, requestType, arguments);
var result = query.Execute(requestModel);
logger.LogInformation("Successfully executed query: {QueryName}", queryName);
@@ -74,4 +37,47 @@ public class QueryService(
ex.InnerException ?? ex);
}
}
private object DeserializeRequestModel(string queryName, Type requestType, JsonElement? arguments)
{
if (arguments == null)
{
return CreateDefaultRequestModel(queryName, requestType);
}
try
{
var requestModel = JsonSerializer.Deserialize(arguments.Value.GetRawText(), requestType, _jsonOptions);
if (requestModel == null)
{
throw new QueryExecutionException(
$"Failed to deserialize request model for query '{queryName}'");
}
return requestModel;
}
catch (JsonException ex)
{
throw new QueryExecutionException(
$"Failed to deserialize request model for query '{queryName}': {ex.Message}", ex);
}
}
private object CreateDefaultRequestModel(string queryName, Type requestType)
{
try
{
var requestModel = Activator.CreateInstance(requestType);
if (requestModel == null)
{
throw new QueryExecutionException(
$"Arguments are required for query '{queryName}'");
}
return requestModel;
}
catch
{
throw new QueryExecutionException(
$"Arguments are required for query '{queryName}'");
}
}
}

View File

@@ -54,6 +54,12 @@ public class Startup
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings)
{
if (env.IsProduction())
{
throw new InvalidOperationException(
"SeederApi cannot be run in production environments. This service is intended for test data generation only.");
}
if (globalSettings.TestPlayIdTrackingEnabled)
{
app.UseMiddleware<PlayIdMiddleware>();