diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index c467d1e652..34e6a6276a 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -54,6 +54,7 @@ public class GlobalSettings : IGlobalSettings public virtual bool EnableCloudCommunication { get; set; } = false; public virtual int OrganizationInviteExpirationHours { get; set; } = 120; // 5 days public virtual string EventGridKey { get; set; } + public virtual bool TestPlayIdTrackingEnabled { get; set; } = false; public virtual IInstallationSettings Installation { get; set; } = new InstallationSettings(); public virtual IBaseServiceUriSettings BaseServiceUri { get; set; } public virtual string DatabaseProvider { get; set; } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs index fae20bd8fb..4d8557194d 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationRepository.cs @@ -18,37 +18,16 @@ namespace Bit.Infrastructure.Dapper.Repositories; public class OrganizationRepository : Repository, IOrganizationRepository { - private readonly IPlayIdService _playIdService; - private readonly IPlayDataRepository _playDataRepository; - private readonly ILogger _logger; + protected readonly ILogger _logger; public OrganizationRepository( - IPlayIdService playIdService, - IPlayDataRepository playDataRepository, GlobalSettings globalSettings, ILogger logger) : base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) { - _playIdService = playIdService; - _playDataRepository = playDataRepository; _logger = logger; } - public override async Task CreateAsync(Organization obj) - { - await base.CreateAsync(obj); - - if (_playIdService.InPlay(out var playId)) - { - _logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}", - obj.Id, playId); - - await _playDataRepository.CreateAsync(PlayData.Create(obj, playId)); - } - - return obj; - } - public async Task GetByIdentifierAsync(string identifier) { using (var connection = new SqlConnection(ConnectionString)) @@ -274,3 +253,35 @@ public class OrganizationRepository : Repository, IOrganizat commandType: CommandType.StoredProcedure); } } + +public class TestOrganizationTrackingOrganizationRepository : OrganizationRepository +{ + private readonly IPlayIdService _playIdService; + private readonly IPlayDataRepository _playDataRepository; + + public TestOrganizationTrackingOrganizationRepository( + IPlayIdService playIdService, + IPlayDataRepository playDataRepository, + GlobalSettings globalSettings, + ILogger logger) + : base(globalSettings, logger) + { + _playIdService = playIdService; + _playDataRepository = playDataRepository; + } + + public override async Task CreateAsync(Organization obj) + { + var createdOrganization = await base.CreateAsync(obj); + + if (_playIdService.InPlay(out var playId)) + { + _logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}", + createdOrganization.Id, playId); + + await _playDataRepository.CreateAsync(PlayData.Create(createdOrganization, playId)); + } + + return createdOrganization; + } +} diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index cddf94caf9..c2e5d011c7 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -9,6 +9,7 @@ using Bit.Core.NotificationCenter.Repositories; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Settings; using Bit.Core.Tools.Repositories; using Bit.Core.Vault.Repositories; using Bit.Infrastructure.Dapper.AdminConsole.Repositories; @@ -28,8 +29,19 @@ namespace Bit.Infrastructure.Dapper; public static class DapperServiceCollectionExtensions { - public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted) + public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings) { + if (globalSettings.TestPlayIdTrackingEnabled) + { + services.AddSingleton(); + services.AddSingleton(); + } + else + { + services.AddSingleton(); + services.AddSingleton(); + } + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -47,7 +59,6 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -59,7 +70,6 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Infrastructure.Dapper/Repositories/UserRepository.cs b/src/Infrastructure.Dapper/Repositories/UserRepository.cs index 9e63f9e39c..fc25ef65f3 100644 --- a/src/Infrastructure.Dapper/Repositories/UserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/UserRepository.cs @@ -18,21 +18,15 @@ namespace Bit.Infrastructure.Dapper.Repositories; public class UserRepository : Repository, IUserRepository { - private readonly IPlayIdService _playIdService; - private readonly IPlayDataRepository _playDataRepository; private readonly IDataProtector _dataProtector; - private readonly ILogger _logger; + protected readonly ILogger _logger; public UserRepository( - IPlayIdService playIdService, - GlobalSettings globalSettings, - IPlayDataRepository playDataRepository, IDataProtectionProvider dataProtectionProvider, + GlobalSettings globalSettings, ILogger logger) : base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) { - _playIdService = playIdService; - _playDataRepository = playDataRepository; _dataProtector = dataProtectionProvider.CreateProtector(Constants.DatabaseFieldProtectorPurpose); _logger = logger; } @@ -165,14 +159,6 @@ public class UserRepository : Repository, IUserRepository { await ProtectDataAndSaveAsync(user, async () => await base.CreateAsync(user)); - if (_playIdService.InPlay(out var playId)) - { - _logger.LogInformation("Associating user {UserId} with Play ID {PlayId}", - user.Id, playId); - - await _playDataRepository.CreateAsync(PlayData.Create(user, playId)); - } - return user; } @@ -415,3 +401,35 @@ public class UserRepository : Repository, IUserRepository } } } + +public class TestUserTrackingUserRepository : UserRepository +{ + private readonly IPlayIdService _playIdService; + private readonly IPlayDataRepository _playDataRepository; + + public TestUserTrackingUserRepository( + IPlayIdService playIdService, + GlobalSettings globalSettings, + IPlayDataRepository playDataRepository, + IDataProtectionProvider dataProtectionProvider, + ILogger logger) + : base(dataProtectionProvider, globalSettings, logger) + { + _playIdService = playIdService; + _playDataRepository = playDataRepository; + } + + public override async Task CreateAsync(User user) + { + var createdUser = await base.CreateAsync(user); + + if (_playIdService.InPlay(out var playId)) + { + _logger.LogInformation("Associating user {UserId} with Play ID {PlayId}", + user.Id, playId); + + await _playDataRepository.CreateAsync(PlayData.Create(createdUser, playId)); + } + return createdUser; + } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index abdb941f4e..956282b15e 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -21,9 +21,7 @@ namespace Bit.Infrastructure.EntityFramework.Repositories; public class OrganizationRepository : Repository, IOrganizationRepository { - private readonly ILogger _logger; - private readonly IPlayIdService _playIdService; - private readonly IPlayDataRepository _playDataRepository; + protected readonly ILogger _logger; public OrganizationRepository( IServiceScopeFactory serviceScopeFactory, @@ -34,24 +32,9 @@ public class OrganizationRepository : Repository context.Organizations) { _logger = logger; - _playIdService = playIdService; - _playDataRepository = playDataRepository; } - public override async Task CreateAsync(Core.AdminConsole.Entities.Organization organization) - { - var createdOrganization = await base.CreateAsync(organization); - if (_playIdService.InPlay(out var playId)) - { - _logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}", - organization.Id, playId); - - await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(organization, playId)); - } - - return createdOrganization; - } public async Task GetByIdentifierAsync(string identifier) { @@ -459,3 +442,37 @@ public class OrganizationRepository : Repository o.RevisionDate, requestDate)); } } + +public class TestOrganizationTrackingOrganizationRepository : OrganizationRepository +{ + private readonly IPlayIdService _playIdService; + private readonly IPlayDataRepository _playDataRepository; + + public TestOrganizationTrackingOrganizationRepository( + IServiceScopeFactory serviceScopeFactory, + IMapper mapper, + ILogger logger, + IPlayIdService playIdService, + IPlayDataRepository playDataRepository) + : base(serviceScopeFactory, mapper, logger, playIdService, playDataRepository) + { + _playIdService = playIdService; + _playDataRepository = playDataRepository; + + } + + public override async Task CreateAsync(Core.AdminConsole.Entities.Organization organization) + { + var createdOrganization = await base.CreateAsync(organization); + + if (_playIdService.InPlay(out var playId)) + { + _logger.LogInformation("Associating organization {OrganizationId} with Play ID {PlayId}", + organization.Id, playId); + + await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(organization, playId)); + } + + return createdOrganization; + } +} diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index d92be9113b..74f7499417 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -10,6 +10,7 @@ using Bit.Core.NotificationCenter.Repositories; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Settings; using Bit.Core.Tools.Repositories; using Bit.Core.Vault.Repositories; using Bit.Infrastructure.EntityFramework.AdminConsole.Repositories; @@ -67,8 +68,19 @@ public static class EntityFrameworkServiceCollectionExtensions }); } - public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted) + public static void AddPasswordManagerEFRepositories(this IServiceCollection services, bool selfHosted, GlobalSettings globalSettings) { + if (globalSettings.TestPlayIdTrackingEnabled) + { + services.AddSingleton(); + services.AddSingleton(); + } + else + { + services.AddSingleton(); + services.AddSingleton(); + } + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -85,7 +97,6 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -97,7 +108,6 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs index ce30182346..b820814e3c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs @@ -14,38 +14,17 @@ namespace Bit.Infrastructure.EntityFramework.Repositories; public class UserRepository : Repository, IUserRepository { - private readonly IPlayIdService _playIdService; - private readonly IPlayDataRepository _playDataRepository; - private readonly ILogger _logger; + protected readonly ILogger _logger; public UserRepository( IServiceScopeFactory serviceScopeFactory, IMapper mapper, - IPlayIdService playIdService, - IPlayDataRepository playDataRepository, ILogger logger) : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users) { - _playIdService = playIdService; - _playDataRepository = playDataRepository; _logger = logger; } - public override async Task CreateAsync(Core.Entities.User user) - { - var createdUser = await base.CreateAsync(user); - - if (_playIdService.InPlay(out var playId)) - { - _logger.LogInformation("Associating user {UserId} with Play ID {PlayId}", - user.Id, playId); - - await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(user, playId)); - } - - return createdUser; - } - public async Task GetByEmailAsync(string email) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -422,3 +401,36 @@ public class UserRepository : Repository, IUserR } } } + +public class TestUserTrackingUserRepository : UserRepository +{ + private readonly IPlayIdService _playIdService; + private readonly IPlayDataRepository _playDataRepository; + + public TestUserTrackingUserRepository( + IPlayIdService playIdService, + IPlayDataRepository playDataRepository, + IServiceScopeFactory serviceScopeFactory, + IMapper mapper, + ILogger logger) + : base(serviceScopeFactory, mapper, logger) + { + _playIdService = playIdService; + _playDataRepository = playDataRepository; + } + + public override async Task CreateAsync(Core.Entities.User user) + { + var createdUser = await base.CreateAsync(user); + + if (_playIdService.InPlay(out var playId)) + { + _logger.LogInformation("Associating user {UserId} with Play ID {PlayId}", + user.Id, playId); + + await _playDataRepository.CreateAsync(Core.Entities.PlayData.Create(user, playId)); + } + + return createdUser; + } +} diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 2ee8d5b572..720477f073 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -100,11 +100,11 @@ public static class ServiceCollectionExtensions if (provider != SupportedDatabaseProviders.SqlServer && !forceEf) { - services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted); + services.AddPasswordManagerEFRepositories(globalSettings.SelfHosted, globalSettings); } else { - services.AddDapperRepositories(globalSettings.SelfHosted); + services.AddDapperRepositories(globalSettings.SelfHosted, globalSettings); } if (globalSettings.SelfHosted) @@ -118,12 +118,19 @@ public static class ServiceCollectionExtensions services.AddKeyedSingleton("cosmos"); } - // Include PlayIdService for tracking Play Ids in repositories - // We need the http context accessor to use the Singleton version, which pulls from the scoped version - services.AddHttpContextAccessor(); + if (globalSettings.TestPlayIdTrackingEnabled) + { + // Include PlayIdService for tracking Play Ids in repositories + // We need the http context accessor to use the Singleton version, which pulls from the scoped version + services.AddHttpContextAccessor(); - services.AddSingleton(); - services.AddScoped(); + services.AddSingleton(); + services.AddScoped(); + } + else + { + services.AddSingleton(); + } return provider; } diff --git a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs index 7dadf654dd..6aeae91019 100644 --- a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs +++ b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs @@ -1,6 +1,5 @@ using System.Reflection; using Bit.Core.Enums; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Infrastructure.Dapper; using Bit.Infrastructure.EntityFramework; @@ -125,18 +124,10 @@ public class DatabaseDataAttribute : DataAttribute { services.AddSingleton(); } - - // Include PlayIdService for tracking Play Ids in repositories - // We need the http context accessor to use the Singleton version, which pulls from the scoped version - services.AddHttpContextAccessor(); - - services.AddSingleton(); - services.AddScoped(); } private void AddDapperServices(IServiceCollection services, Database database) { - services.AddDapperRepositories(SelfHosted); var globalSettings = new GlobalSettings { DatabaseProvider = "sqlServer", @@ -149,6 +140,7 @@ public class DatabaseDataAttribute : DataAttribute UserRequestExpiration = TimeSpan.FromMinutes(15), } }; + services.AddDapperRepositories(SelfHosted, globalSettings); services.AddSingleton(globalSettings); services.AddSingleton(globalSettings); services.AddSingleton(database); @@ -168,7 +160,6 @@ public class DatabaseDataAttribute : DataAttribute private void AddEfServices(IServiceCollection services, Database database) { services.SetupEntityFramework(database.ConnectionString, database.Type); - services.AddPasswordManagerEFRepositories(SelfHosted); var globalSettings = new GlobalSettings { @@ -177,6 +168,7 @@ public class DatabaseDataAttribute : DataAttribute UserRequestExpiration = TimeSpan.FromMinutes(15), }, }; + services.AddPasswordManagerEFRepositories(SelfHosted, globalSettings); services.AddSingleton(globalSettings); services.AddSingleton(globalSettings);