1
0
mirror of https://github.com/bitwarden/server synced 2025-12-19 17:53:44 +00:00

use DI to determine whether to track play Ids

This commit is contained in:
Matt Gibson
2025-11-18 05:07:03 -08:00
parent d6eaafb308
commit 439bf37b7f
9 changed files with 179 additions and 101 deletions

View File

@@ -54,6 +54,7 @@ public class GlobalSettings : IGlobalSettings
public virtual bool EnableCloudCommunication { get; set; } = false;
public virtual int OrganizationInviteExpirationHours { get; set; } = 120; // 5 days
public virtual string EventGridKey { get; set; }
public virtual bool TestPlayIdTrackingEnabled { get; set; } = false;
public virtual IInstallationSettings Installation { get; set; } = new InstallationSettings();
public virtual IBaseServiceUriSettings BaseServiceUri { get; set; }
public virtual string DatabaseProvider { get; set; }

View File

@@ -18,37 +18,16 @@ namespace Bit.Infrastructure.Dapper.Repositories;
public class OrganizationRepository : Repository<Organization, Guid>, IOrganizationRepository
{
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
private readonly ILogger<OrganizationRepository> _logger;
protected readonly ILogger<OrganizationRepository> _logger;
public OrganizationRepository(
IPlayIdService playIdService,
IPlayDataRepository playDataRepository,
GlobalSettings globalSettings,
ILogger<OrganizationRepository> logger)
: base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString)
{
_playIdService = playIdService;
_playDataRepository = playDataRepository;
_logger = logger;
}
public override async Task<Organization> CreateAsync(Organization obj)
{
await base.CreateAsync(obj);
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}",
obj.Id, playId);
await _playDataRepository.CreateAsync(PlayData.Create(obj, playId));
}
return obj;
}
public async Task<Organization?> GetByIdentifierAsync(string identifier)
{
using (var connection = new SqlConnection(ConnectionString))
@@ -274,3 +253,35 @@ public class OrganizationRepository : Repository<Organization, Guid>, IOrganizat
commandType: CommandType.StoredProcedure);
}
}
public class TestOrganizationTrackingOrganizationRepository : OrganizationRepository
{
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
public TestOrganizationTrackingOrganizationRepository(
IPlayIdService playIdService,
IPlayDataRepository playDataRepository,
GlobalSettings globalSettings,
ILogger<OrganizationRepository> logger)
: base(globalSettings, logger)
{
_playIdService = playIdService;
_playDataRepository = playDataRepository;
}
public override async Task<Organization> CreateAsync(Organization obj)
{
var createdOrganization = await base.CreateAsync(obj);
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}",
createdOrganization.Id, playId);
await _playDataRepository.CreateAsync(PlayData.Create(createdOrganization, playId));
}
return createdOrganization;
}
}

View File

@@ -9,6 +9,7 @@ using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Installations;
using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Settings;
using Bit.Core.Tools.Repositories;
using Bit.Core.Vault.Repositories;
using Bit.Infrastructure.Dapper.AdminConsole.Repositories;
@@ -28,8 +29,19 @@ namespace Bit.Infrastructure.Dapper;
public static class DapperServiceCollectionExtensions
{
public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted)
public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings)
{
if (globalSettings.TestPlayIdTrackingEnabled)
{
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
}
else
{
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
}
services.AddSingleton<IApiKeyRepository, ApiKeyRepository>();
services.AddSingleton<IAuthRequestRepository, AuthRequestRepository>();
services.AddSingleton<ICipherRepository, CipherRepository>();
@@ -47,7 +59,6 @@ public static class DapperServiceCollectionExtensions
services.AddSingleton<IOrganizationConnectionRepository, OrganizationConnectionRepository>();
services.AddSingleton<IOrganizationIntegrationConfigurationRepository, OrganizationIntegrationConfigurationRepository>();
services.AddSingleton<IOrganizationIntegrationRepository, OrganizationIntegrationRepository>();
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IOrganizationSponsorshipRepository, OrganizationSponsorshipRepository>();
services.AddSingleton<IOrganizationUserRepository, OrganizationUserRepository>();
services.AddSingleton<IPlayDataRepository, PlayDataRepository>();
@@ -59,7 +70,6 @@ public static class DapperServiceCollectionExtensions
services.AddSingleton<ISsoConfigRepository, SsoConfigRepository>();
services.AddSingleton<ISsoUserRepository, SsoUserRepository>();
services.AddSingleton<ITransactionRepository, TransactionRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
services.AddSingleton<IOrganizationDomainRepository, OrganizationDomainRepository>();
services.AddSingleton<IWebAuthnCredentialRepository, WebAuthnCredentialRepository>();
services.AddSingleton<IProviderPlanRepository, ProviderPlanRepository>();

View File

@@ -18,21 +18,15 @@ namespace Bit.Infrastructure.Dapper.Repositories;
public class UserRepository : Repository<User, Guid>, IUserRepository
{
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
private readonly IDataProtector _dataProtector;
private readonly ILogger<UserRepository> _logger;
protected readonly ILogger<UserRepository> _logger;
public UserRepository(
IPlayIdService playIdService,
GlobalSettings globalSettings,
IPlayDataRepository playDataRepository,
IDataProtectionProvider dataProtectionProvider,
GlobalSettings globalSettings,
ILogger<UserRepository> logger)
: base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString)
{
_playIdService = playIdService;
_playDataRepository = playDataRepository;
_dataProtector = dataProtectionProvider.CreateProtector(Constants.DatabaseFieldProtectorPurpose);
_logger = logger;
}
@@ -165,14 +159,6 @@ public class UserRepository : Repository<User, Guid>, IUserRepository
{
await ProtectDataAndSaveAsync(user, async () => await base.CreateAsync(user));
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating user {UserId} with Play ID {PlayId}",
user.Id, playId);
await _playDataRepository.CreateAsync(PlayData.Create(user, playId));
}
return user;
}
@@ -415,3 +401,35 @@ public class UserRepository : Repository<User, Guid>, IUserRepository
}
}
}
public class TestUserTrackingUserRepository : UserRepository
{
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
public TestUserTrackingUserRepository(
IPlayIdService playIdService,
GlobalSettings globalSettings,
IPlayDataRepository playDataRepository,
IDataProtectionProvider dataProtectionProvider,
ILogger<UserRepository> logger)
: base(dataProtectionProvider, globalSettings, logger)
{
_playIdService = playIdService;
_playDataRepository = playDataRepository;
}
public override async Task<User> CreateAsync(User user)
{
var createdUser = await base.CreateAsync(user);
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating user {UserId} with Play ID {PlayId}",
user.Id, playId);
await _playDataRepository.CreateAsync(PlayData.Create(createdUser, playId));
}
return createdUser;
}
}

View File

@@ -21,9 +21,7 @@ namespace Bit.Infrastructure.EntityFramework.Repositories;
public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Organization, Organization, Guid>, IOrganizationRepository
{
private readonly ILogger<OrganizationRepository> _logger;
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
protected readonly ILogger<OrganizationRepository> _logger;
public OrganizationRepository(
IServiceScopeFactory serviceScopeFactory,
@@ -34,24 +32,9 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
: base(serviceScopeFactory, mapper, context => context.Organizations)
{
_logger = logger;
_playIdService = playIdService;
_playDataRepository = playDataRepository;
}
public override async Task<Core.AdminConsole.Entities.Organization> CreateAsync(Core.AdminConsole.Entities.Organization organization)
{
var createdOrganization = await base.CreateAsync(organization);
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}",
organization.Id, playId);
await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(organization, playId));
}
return createdOrganization;
}
public async Task<Core.AdminConsole.Entities.Organization> GetByIdentifierAsync(string identifier)
{
@@ -459,3 +442,37 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
.SetProperty(o => o.RevisionDate, requestDate));
}
}
public class TestOrganizationTrackingOrganizationRepository : OrganizationRepository
{
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
public TestOrganizationTrackingOrganizationRepository(
IServiceScopeFactory serviceScopeFactory,
IMapper mapper,
ILogger<OrganizationRepository> logger,
IPlayIdService playIdService,
IPlayDataRepository playDataRepository)
: base(serviceScopeFactory, mapper, logger, playIdService, playDataRepository)
{
_playIdService = playIdService;
_playDataRepository = playDataRepository;
}
public override async Task<Core.AdminConsole.Entities.Organization> CreateAsync(Core.AdminConsole.Entities.Organization organization)
{
var createdOrganization = await base.CreateAsync(organization);
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}",
organization.Id, playId);
await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(organization, playId));
}
return createdOrganization;
}
}

View File

@@ -10,6 +10,7 @@ using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Installations;
using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Settings;
using Bit.Core.Tools.Repositories;
using Bit.Core.Vault.Repositories;
using Bit.Infrastructure.EntityFramework.AdminConsole.Repositories;
@@ -67,8 +68,19 @@ public static class EntityFrameworkServiceCollectionExtensions
});
}
public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted)
public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings)
{
if (globalSettings.TestPlayIdTrackingEnabled)
{
services.AddSingleton<IOrganizationRepository, TestOrganizationTrackingOrganizationRepository>();
services.AddSingleton<IUserRepository, TestUserTrackingUserRepository>();
}
else
{
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
}
services.AddSingleton<IApiKeyRepository, ApiKeyRepository>();
services.AddSingleton<IAuthRequestRepository, AuthRequestRepository>();
services.AddSingleton<ICipherRepository, CipherRepository>();
@@ -85,7 +97,6 @@ public static class EntityFrameworkServiceCollectionExtensions
services.AddSingleton<IOrganizationConnectionRepository, OrganizationConnectionRepository>();
services.AddSingleton<IOrganizationIntegrationRepository, OrganizationIntegrationRepository>();
services.AddSingleton<IOrganizationIntegrationConfigurationRepository, OrganizationIntegrationConfigurationRepository>();
services.AddSingleton<IOrganizationRepository, OrganizationRepository>();
services.AddSingleton<IOrganizationSponsorshipRepository, OrganizationSponsorshipRepository>();
services.AddSingleton<IOrganizationUserRepository, OrganizationUserRepository>();
services.AddSingleton<IPlayDataRepository, PlayDataRepository>();
@@ -97,7 +108,6 @@ public static class EntityFrameworkServiceCollectionExtensions
services.AddSingleton<ISsoConfigRepository, SsoConfigRepository>();
services.AddSingleton<ISsoUserRepository, SsoUserRepository>();
services.AddSingleton<ITransactionRepository, TransactionRepository>();
services.AddSingleton<IUserRepository, UserRepository>();
services.AddSingleton<IOrganizationDomainRepository, OrganizationDomainRepository>();
services.AddSingleton<IWebAuthnCredentialRepository, WebAuthnCredentialRepository>();
services.AddSingleton<IProviderPlanRepository, ProviderPlanRepository>();

View File

@@ -14,38 +14,17 @@ namespace Bit.Infrastructure.EntityFramework.Repositories;
public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserRepository
{
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
private readonly ILogger<UserRepository> _logger;
protected readonly ILogger<UserRepository> _logger;
public UserRepository(
IServiceScopeFactory serviceScopeFactory,
IMapper mapper,
IPlayIdService playIdService,
IPlayDataRepository playDataRepository,
ILogger<UserRepository> logger)
: base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users)
{
_playIdService = playIdService;
_playDataRepository = playDataRepository;
_logger = logger;
}
public override async Task<Core.Entities.User> CreateAsync(Core.Entities.User user)
{
var createdUser = await base.CreateAsync(user);
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating user {UserId} with Play ID {PlayId}",
user.Id, playId);
await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(user, playId));
}
return createdUser;
}
public async Task<Core.Entities.User?> GetByEmailAsync(string email)
{
using (var scope = ServiceScopeFactory.CreateScope())
@@ -422,3 +401,36 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
}
}
}
public class TestUserTrackingUserRepository : UserRepository
{
private readonly IPlayIdService _playIdService;
private readonly IPlayDataRepository _playDataRepository;
public TestUserTrackingUserRepository(
IPlayIdService playIdService,
IPlayDataRepository playDataRepository,
IServiceScopeFactory serviceScopeFactory,
IMapper mapper,
ILogger<UserRepository> logger)
: base(serviceScopeFactory, mapper, logger)
{
_playIdService = playIdService;
_playDataRepository = playDataRepository;
}
public override async Task<Core.Entities.User> CreateAsync(Core.Entities.User user)
{
var createdUser = await base.CreateAsync(user);
if (_playIdService.InPlay(out var playId))
{
_logger.LogInformation("Associating user {UserId} with Play ID {PlayId}",
user.Id, playId);
await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(user, playId));
}
return createdUser;
}
}

View File

@@ -100,11 +100,11 @@ public static class ServiceCollectionExtensions
if (provider != SupportedDatabaseProviders.SqlServer && !forceEf)
{
services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted);
services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted, globalSettings);
}
else
{
services.AddDapperRepositories(globalSettings.SelfHosted);
services.AddDapperRepositories(globalSettings.SelfHosted, globalSettings);
}
if (globalSettings.SelfHosted)
@@ -118,12 +118,19 @@ public static class ServiceCollectionExtensions
services.AddKeyedSingleton<IGrantRepository, Core.Auth.Repositories.Cosmos.GrantRepository>("cosmos");
}
if (globalSettings.TestPlayIdTrackingEnabled)
{
// Include PlayIdService for tracking Play Ids in repositories
// We need the http context accessor to use the Singleton version, which pulls from the scoped version
services.AddHttpContextAccessor();
services.AddSingleton<IPlayIdService, PlayIdSingletonService>();
services.AddScoped<PlayIdService>();
}
else
{
services.AddSingleton<IPlayIdService, NeverPlayIdServices>();
}
return provider;
}

View File

@@ -1,6 +1,5 @@
using System.Reflection;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Infrastructure.Dapper;
using Bit.Infrastructure.EntityFramework;
@@ -125,18 +124,10 @@ public class DatabaseDataAttribute : DataAttribute
{
services.AddSingleton<TimeProvider, FakeTimeProvider>();
}
// Include PlayIdService for tracking Play Ids in repositories
// We need the http context accessor to use the Singleton version, which pulls from the scoped version
services.AddHttpContextAccessor();
services.AddSingleton<IPlayIdService, PlayIdSingletonService>();
services.AddScoped<PlayIdService>();
}
private void AddDapperServices(IServiceCollection services, Database database)
{
services.AddDapperRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
DatabaseProvider = "sqlServer",
@@ -149,6 +140,7 @@ public class DatabaseDataAttribute : DataAttribute
UserRequestExpiration = TimeSpan.FromMinutes(15),
}
};
services.AddDapperRepositories(SelfHosted, globalSettings);
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);
services.AddSingleton(database);
@@ -168,7 +160,6 @@ public class DatabaseDataAttribute : DataAttribute
private void AddEfServices(IServiceCollection services, Database database)
{
services.SetupEntityFramework(database.ConnectionString, database.Type);
services.AddPasswordManagerEFRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
@@ -177,6 +168,7 @@ public class DatabaseDataAttribute : DataAttribute
UserRequestExpiration = TimeSpan.FromMinutes(15),
},
};
services.AddPasswordManagerEFRepositories(SelfHosted, globalSettings);
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);