mirror of
https://github.com/bitwarden/server
synced 2025-12-28 06:03:29 +00:00
Resolve review feedback
This commit is contained in:
@@ -29,6 +29,14 @@ public class NeverPlayIdServices : IPlayIdService
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Singleton wrapper service that bridges singleton-scoped service boundaries for PlayId tracking.
|
||||
/// This allows singleton services to access the scoped PlayIdService via HttpContext.RequestServices.
|
||||
///
|
||||
/// Uses IHttpContextAccessor to retrieve the current request's scoped PlayIdService instance, enabling
|
||||
/// singleton services to participate in Play session tracking without violating DI lifetime rules.
|
||||
/// Falls back to NeverPlayIdServices when no HttpContext is available (e.g., background jobs).
|
||||
/// </summary>
|
||||
public class PlayIdSingletonService(IHttpContextAccessor httpContextAccessor, IHostEnvironment hostEnvironment) : IPlayIdService
|
||||
{
|
||||
private IPlayIdService Current
|
||||
|
||||
@@ -9,7 +9,6 @@ using Bit.Core.NotificationCenter.Repositories;
|
||||
using Bit.Core.Platform.Installations;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.SecretsManager.Repositories;
|
||||
using Bit.Core.Settings;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Vault.Repositories;
|
||||
using Bit.Infrastructure.Dapper.AdminConsole.Repositories;
|
||||
@@ -29,19 +28,8 @@ namespace Bit.Infrastructure.Dapper;
|
||||
|
||||
public static class DapperServiceCollectionExtensions
|
||||
{
|
||||
public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings)
|
||||
public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted)
|
||||
{
|
||||
if (globalSettings.TestPlayIdTrackingEnabled)
|
||||
{
|
||||
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
|
||||
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
|
||||
}
|
||||
else
|
||||
{
|
||||
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
|
||||
services.AddSingleton<IUserRepository, UserRepository>();
|
||||
}
|
||||
|
||||
services.AddSingleton<IApiKeyRepository, ApiKeyRepository>();
|
||||
services.AddSingleton<IAuthRequestRepository, AuthRequestRepository>();
|
||||
services.AddSingleton<ICipherRepository, CipherRepository>();
|
||||
@@ -59,6 +47,7 @@ public static class DapperServiceCollectionExtensions
|
||||
services.AddSingleton<IOrganizationConnectionRepository, OrganizationConnectionRepository>();
|
||||
services.AddSingleton<IOrganizationIntegrationConfigurationRepository, OrganizationIntegrationConfigurationRepository>();
|
||||
services.AddSingleton<IOrganizationIntegrationRepository, OrganizationIntegrationRepository>();
|
||||
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
|
||||
services.AddSingleton<IOrganizationSponsorshipRepository, OrganizationSponsorshipRepository>();
|
||||
services.AddSingleton<IOrganizationUserRepository, OrganizationUserRepository>();
|
||||
services.AddSingleton<IPlayDataRepository, PlayDataRepository>();
|
||||
@@ -70,6 +59,7 @@ public static class DapperServiceCollectionExtensions
|
||||
services.AddSingleton<ISsoConfigRepository, SsoConfigRepository>();
|
||||
services.AddSingleton<ISsoUserRepository, SsoUserRepository>();
|
||||
services.AddSingleton<ITransactionRepository, TransactionRepository>();
|
||||
services.AddSingleton<IUserRepository, UserRepository>();
|
||||
services.AddSingleton<IOrganizationDomainRepository, OrganizationDomainRepository>();
|
||||
services.AddSingleton<IWebAuthnCredentialRepository, WebAuthnCredentialRepository>();
|
||||
services.AddSingleton<IProviderPlanRepository, ProviderPlanRepository>();
|
||||
@@ -92,4 +82,15 @@ public static class DapperServiceCollectionExtensions
|
||||
services.AddSingleton<IEventRepository, EventRepository>();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds PlayId tracking decorators for User and Organization repositories.
|
||||
/// This replaces the standard repository implementations with tracking versions
|
||||
/// that record created entities for test data cleanup. Only call when TestPlayIdTrackingEnabled is true.
|
||||
/// </summary>
|
||||
public static void AddPlayIdTrackingRepositories(this IServiceCollection services)
|
||||
{
|
||||
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
|
||||
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,19 +68,8 @@ public static class EntityFrameworkServiceCollectionExtensions
|
||||
});
|
||||
}
|
||||
|
||||
public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings)
|
||||
public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted)
|
||||
{
|
||||
if (globalSettings.TestPlayIdTrackingEnabled)
|
||||
{
|
||||
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
|
||||
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
|
||||
}
|
||||
else
|
||||
{
|
||||
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
|
||||
services.AddSingleton<IUserRepository, UserRepository>();
|
||||
}
|
||||
|
||||
services.AddSingleton<IApiKeyRepository, ApiKeyRepository>();
|
||||
services.AddSingleton<IAuthRequestRepository, AuthRequestRepository>();
|
||||
services.AddSingleton<ICipherRepository, CipherRepository>();
|
||||
@@ -97,6 +86,7 @@ public static class EntityFrameworkServiceCollectionExtensions
|
||||
services.AddSingleton<IOrganizationConnectionRepository, OrganizationConnectionRepository>();
|
||||
services.AddSingleton<IOrganizationIntegrationRepository, OrganizationIntegrationRepository>();
|
||||
services.AddSingleton<IOrganizationIntegrationConfigurationRepository, OrganizationIntegrationConfigurationRepository>();
|
||||
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
|
||||
services.AddSingleton<IOrganizationSponsorshipRepository, OrganizationSponsorshipRepository>();
|
||||
services.AddSingleton<IOrganizationUserRepository, OrganizationUserRepository>();
|
||||
services.AddSingleton<IPlayDataRepository, PlayDataRepository>();
|
||||
@@ -108,6 +98,7 @@ public static class EntityFrameworkServiceCollectionExtensions
|
||||
services.AddSingleton<ISsoConfigRepository, SsoConfigRepository>();
|
||||
services.AddSingleton<ISsoUserRepository, SsoUserRepository>();
|
||||
services.AddSingleton<ITransactionRepository, TransactionRepository>();
|
||||
services.AddSingleton<IUserRepository, UserRepository>();
|
||||
services.AddSingleton<IOrganizationDomainRepository, OrganizationDomainRepository>();
|
||||
services.AddSingleton<IWebAuthnCredentialRepository, WebAuthnCredentialRepository>();
|
||||
services.AddSingleton<IProviderPlanRepository, ProviderPlanRepository>();
|
||||
@@ -130,4 +121,15 @@ public static class EntityFrameworkServiceCollectionExtensions
|
||||
services.AddSingleton<IEventRepository, EventRepository>();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds PlayId tracking decorators for User and Organization repositories.
|
||||
/// This replaces the standard repository implementations with tracking versions
|
||||
/// that record created entities for test data cleanup. Only call when TestPlayIdTrackingEnabled is true.
|
||||
/// </summary>
|
||||
public static void AddPlayIdTrackingEFRepositories(this IServiceCollection services)
|
||||
{
|
||||
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
|
||||
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,13 +11,31 @@ namespace Bit.SharedWeb.Utilities;
|
||||
/// <param name="next"></param>
|
||||
public sealed class PlayIdMiddleware(RequestDelegate next)
|
||||
{
|
||||
public Task Invoke(HttpContext context, PlayIdService playIdService)
|
||||
private const int MaxPlayIdLength = 256;
|
||||
|
||||
public async Task Invoke(HttpContext context, PlayIdService playIdService)
|
||||
{
|
||||
if (context.Request.Headers.TryGetValue("x-play-id", out var playId))
|
||||
{
|
||||
playIdService.PlayId = playId;
|
||||
var playIdValue = playId.ToString();
|
||||
|
||||
if (string.IsNullOrWhiteSpace(playIdValue))
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status400BadRequest;
|
||||
await context.Response.WriteAsJsonAsync(new { Error = "x-play-id header cannot be empty or whitespace" });
|
||||
return;
|
||||
}
|
||||
|
||||
if (playIdValue.Length > MaxPlayIdLength)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status400BadRequest;
|
||||
await context.Response.WriteAsJsonAsync(new { Error = $"x-play-id header cannot exceed {MaxPlayIdLength} characters" });
|
||||
return;
|
||||
}
|
||||
|
||||
playIdService.PlayId = playIdValue;
|
||||
}
|
||||
|
||||
return next(context);
|
||||
await next(context);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,11 +96,11 @@ public static class ServiceCollectionExtensions
|
||||
|
||||
if (provider != SupportedDatabaseProviders.SqlServer)
|
||||
{
|
||||
services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted, globalSettings);
|
||||
services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted);
|
||||
}
|
||||
else
|
||||
{
|
||||
services.AddDapperRepositories(globalSettings.SelfHosted, globalSettings);
|
||||
services.AddDapperRepositories(globalSettings.SelfHosted);
|
||||
}
|
||||
|
||||
if (globalSettings.SelfHosted)
|
||||
@@ -123,6 +123,16 @@ public static class ServiceCollectionExtensions
|
||||
services.AddSingleton<IPlayDataService, PlayDataService>();
|
||||
services.AddSingleton<IPlayIdService, PlayIdSingletonService>();
|
||||
services.AddScoped<PlayIdService>();
|
||||
|
||||
// Replace standard repositories with PlayId tracking decorators
|
||||
if (provider == SupportedDatabaseProviders.SqlServer)
|
||||
{
|
||||
services.AddPlayIdTrackingRepositories();
|
||||
}
|
||||
else
|
||||
{
|
||||
services.AddPlayIdTrackingEFRepositories();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -5,7 +5,11 @@ BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
SELECT
|
||||
*
|
||||
[Id],
|
||||
[PlayId],
|
||||
[UserId],
|
||||
[OrganizationId],
|
||||
[CreationDate]
|
||||
FROM
|
||||
[dbo].[PlayData]
|
||||
WHERE
|
||||
|
||||
102
test/SharedWeb.Test/PlayIdMiddlewareTests.cs
Normal file
102
test/SharedWeb.Test/PlayIdMiddlewareTests.cs
Normal file
@@ -0,0 +1,102 @@
|
||||
using Bit.Core.Services;
|
||||
using Bit.SharedWeb.Utilities;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using NSubstitute;
|
||||
|
||||
namespace SharedWeb.Test;
|
||||
|
||||
public class PlayIdMiddlewareTests
|
||||
{
|
||||
private readonly PlayIdService _playIdService;
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly PlayIdMiddleware _middleware;
|
||||
|
||||
public PlayIdMiddlewareTests()
|
||||
{
|
||||
var hostEnvironment = Substitute.For<IHostEnvironment>();
|
||||
hostEnvironment.EnvironmentName.Returns(Environments.Development);
|
||||
|
||||
_playIdService = new PlayIdService(hostEnvironment);
|
||||
_next = Substitute.For<RequestDelegate>();
|
||||
_middleware = new PlayIdMiddleware(_next);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Invoke_WithValidPlayId_SetsPlayIdAndCallsNext()
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Request.Headers["x-play-id"] = "test-play-id";
|
||||
|
||||
await _middleware.Invoke(context, _playIdService);
|
||||
|
||||
Assert.Equal("test-play-id", _playIdService.PlayId);
|
||||
await _next.Received(1).Invoke(context);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Invoke_WithoutPlayIdHeader_CallsNext()
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
|
||||
await _middleware.Invoke(context, _playIdService);
|
||||
|
||||
Assert.Null(_playIdService.PlayId);
|
||||
await _next.Received(1).Invoke(context);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData("")]
|
||||
[InlineData(" ")]
|
||||
[InlineData("\t")]
|
||||
public async Task Invoke_WithEmptyOrWhitespacePlayId_Returns400(string playId)
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Response.Body = new MemoryStream();
|
||||
context.Request.Headers["x-play-id"] = playId;
|
||||
|
||||
await _middleware.Invoke(context, _playIdService);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await _next.DidNotReceive().Invoke(context);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Invoke_WithPlayIdExceedingMaxLength_Returns400()
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Response.Body = new MemoryStream();
|
||||
var longPlayId = new string('a', 257); // Exceeds 256 character limit
|
||||
context.Request.Headers["x-play-id"] = longPlayId;
|
||||
|
||||
await _middleware.Invoke(context, _playIdService);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await _next.DidNotReceive().Invoke(context);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Invoke_WithPlayIdAtMaxLength_SetsPlayIdAndCallsNext()
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
var maxLengthPlayId = new string('a', 256); // Exactly 256 characters
|
||||
context.Request.Headers["x-play-id"] = maxLengthPlayId;
|
||||
|
||||
await _middleware.Invoke(context, _playIdService);
|
||||
|
||||
Assert.Equal(maxLengthPlayId, _playIdService.PlayId);
|
||||
await _next.Received(1).Invoke(context);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Invoke_WithSpecialCharactersInPlayId_SetsPlayIdAndCallsNext()
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Request.Headers["x-play-id"] = "test-play_id.123";
|
||||
|
||||
await _middleware.Invoke(context, _playIdService);
|
||||
|
||||
Assert.Equal("test-play_id.123", _playIdService.PlayId);
|
||||
await _next.Received(1).Invoke(context);
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,11 @@ BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
SELECT
|
||||
*
|
||||
[Id],
|
||||
[PlayId],
|
||||
[UserId],
|
||||
[OrganizationId],
|
||||
[CreationDate]
|
||||
FROM
|
||||
[dbo].[PlayData]
|
||||
WHERE
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
namespace Bit.Seeder;
|
||||
|
||||
/// <summary>
|
||||
/// Helper for mangling IDs
|
||||
/// Helper for generating unique identifier suffixes to prevent collisions in test data.
|
||||
/// "Mangling" adds a random suffix to test data identifiers (usernames, emails, org names, etc.)
|
||||
/// to ensure uniqueness across multiple test runs and parallel test executions.
|
||||
/// </summary>
|
||||
public class MangleId
|
||||
{
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
namespace Bit.Seeder;
|
||||
|
||||
/// <summary>
|
||||
/// Helper for exposing a <see cref="IScene" /> interface with a SeedAsync method.
|
||||
/// </summary>
|
||||
public class SceneResult(Dictionary<string, string?> mangleMap)
|
||||
: SceneResult<object?>(result: null, mangleMap: mangleMap);
|
||||
|
||||
/// <summary>
|
||||
/// Generic result from executing a Scene.
|
||||
/// Contains custom scene-specific data and a mangle map that maps magic strings from the
|
||||
/// request to their mangled (collision-free) values inserted into the database.
|
||||
/// </summary>
|
||||
/// <typeparam name="TResult">The type of custom result data returned by the scene.</typeparam>
|
||||
public class SceneResult<TResult>(TResult result, Dictionary<string, string?> mangleMap)
|
||||
{
|
||||
public TResult Result { get; init; } = result;
|
||||
|
||||
@@ -23,44 +23,7 @@ public class QueryService(
|
||||
|
||||
var requestType = query.GetRequestType();
|
||||
|
||||
// Deserialize the arguments into the request model
|
||||
object? requestModel;
|
||||
if (arguments == null)
|
||||
{
|
||||
// Try to create an instance with default values
|
||||
try
|
||||
{
|
||||
requestModel = Activator.CreateInstance(requestType);
|
||||
if (requestModel == null)
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Arguments are required for query '{queryName}'");
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Arguments are required for query '{queryName}'");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
requestModel = JsonSerializer.Deserialize(arguments.Value.GetRawText(), requestType, _jsonOptions);
|
||||
if (requestModel == null)
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Failed to deserialize request model for query '{queryName}'");
|
||||
}
|
||||
}
|
||||
catch (JsonException ex)
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Failed to deserialize request model for query '{queryName}': {ex.Message}", ex);
|
||||
}
|
||||
}
|
||||
|
||||
var requestModel = DeserializeRequestModel(queryName, requestType, arguments);
|
||||
var result = query.Execute(requestModel);
|
||||
|
||||
logger.LogInformation("Successfully executed query: {QueryName}", queryName);
|
||||
@@ -74,4 +37,47 @@ public class QueryService(
|
||||
ex.InnerException ?? ex);
|
||||
}
|
||||
}
|
||||
|
||||
private object DeserializeRequestModel(string queryName, Type requestType, JsonElement? arguments)
|
||||
{
|
||||
if (arguments == null)
|
||||
{
|
||||
return CreateDefaultRequestModel(queryName, requestType);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var requestModel = JsonSerializer.Deserialize(arguments.Value.GetRawText(), requestType, _jsonOptions);
|
||||
if (requestModel == null)
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Failed to deserialize request model for query '{queryName}'");
|
||||
}
|
||||
return requestModel;
|
||||
}
|
||||
catch (JsonException ex)
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Failed to deserialize request model for query '{queryName}': {ex.Message}", ex);
|
||||
}
|
||||
}
|
||||
|
||||
private object CreateDefaultRequestModel(string queryName, Type requestType)
|
||||
{
|
||||
try
|
||||
{
|
||||
var requestModel = Activator.CreateInstance(requestType);
|
||||
if (requestModel == null)
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Arguments are required for query '{queryName}'");
|
||||
}
|
||||
return requestModel;
|
||||
}
|
||||
catch
|
||||
{
|
||||
throw new QueryExecutionException(
|
||||
$"Arguments are required for query '{queryName}'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,12 @@ public class Startup
|
||||
IHostApplicationLifetime appLifetime,
|
||||
GlobalSettings globalSettings)
|
||||
{
|
||||
if (env.IsProduction())
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
"SeederApi cannot be run in production environments. This service is intended for test data generation only.");
|
||||
}
|
||||
|
||||
if (globalSettings.TestPlayIdTrackingEnabled)
|
||||
{
|
||||
app.UseMiddleware<PlayIdMiddleware>();
|
||||
|
||||
Reference in New Issue
Block a user