From 101e29b3541248982ad18f60c56bd790dfaca949 Mon Sep 17 00:00:00 2001 From: Brandon Treston Date: Tue, 2 Sep 2025 10:52:23 -0400 Subject: [PATCH 01/13] [PM-15354] fix EF implementation to match dapper (missing null check) (#6261) * fix EF implementation to match dapper (missing null check) * cleanup --- .../Repositories/OrganizationDomainRepository.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs index e7bee0cdfd..0ddf80130e 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs @@ -152,7 +152,7 @@ public class OrganizationDomainRepository : Repository x.LastCheckedDate < DateTime.UtcNow.AddDays(-expirationPeriod)) + .Where(x => x.LastCheckedDate < DateTime.UtcNow.AddDays(-expirationPeriod) && x.VerifiedDate == null) .ToListAsync(); dbContext.OrganizationDomains.RemoveRange(expiredDomains); return await dbContext.SaveChangesAsync() > 0; From cb1db262cacc7a5c95e5ed50ba94e4b9c1ad3ae6 Mon Sep 17 00:00:00 2001 From: Todd Martin <106564991+trmartin4@users.noreply.github.com> Date: Tue, 2 Sep 2025 11:18:36 -0400 Subject: [PATCH 02/13] chore(feature-flag): [PM-18179] Remove pm-17128-recovery-code-login feature flag * Rmoved feature flag and obsolete endpoint * Removed obsolete method. --- .../Auth/Controllers/TwoFactorController.cs | 15 --------- src/Core/Constants.cs | 1 - src/Core/Services/IUserService.cs | 3 -- .../Services/Implementations/UserService.cs | 33 ------------------- 4 files changed, 52 deletions(-) diff --git a/src/Api/Auth/Controllers/TwoFactorController.cs b/src/Api/Auth/Controllers/TwoFactorController.cs index 96b64f16fc..4155489daa 100644 --- a/src/Api/Auth/Controllers/TwoFactorController.cs +++ b/src/Api/Auth/Controllers/TwoFactorController.cs @@ -409,21 +409,6 @@ public class TwoFactorController : Controller return response; } - /// - /// To be removed when the feature flag pm-17128-recovery-code-login is removed PM-18175. - /// - [Obsolete("Two Factor recovery is handled in the TwoFactorAuthenticationValidator.")] - [HttpPost("recover")] - [AllowAnonymous] - public async Task PostRecover([FromBody] TwoFactorRecoveryRequestModel model) - { - if (!await _userService.RecoverTwoFactorAsync(model.Email, model.MasterPasswordHash, model.RecoveryCode)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "Invalid information. Try again."); - } - } - [Obsolete("Leaving this for backwards compatibility on clients")] [HttpGet("get-device-verification-settings")] public Task GetDeviceVerificationSettings() diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 39bd3fea5d..352daee862 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -121,7 +121,6 @@ public static class FeatureFlagKeys public const string BrowserExtensionLoginApproval = "pm-14938-browser-extension-login-approvals"; public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor"; public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor"; - public const string RecoveryCodeLogin = "pm-17128-recovery-code-login"; public const string Otp6Digits = "pm-18612-otp-6-digits"; public const string FailedTwoFactorEmail = "pm-24425-send-2fa-failed-email"; diff --git a/src/Core/Services/IUserService.cs b/src/Core/Services/IUserService.cs index 8457a9c128..ef602be93a 100644 --- a/src/Core/Services/IUserService.cs +++ b/src/Core/Services/IUserService.cs @@ -90,9 +90,6 @@ public interface IUserService void SetTwoFactorProvider(User user, TwoFactorProviderType type, bool setEnabled = true); - [Obsolete("To be removed when the feature flag pm-17128-recovery-code-login is removed PM-18175.")] - Task RecoverTwoFactorAsync(string email, string masterPassword, string recoveryCode); - /// /// This method is used by the TwoFactorAuthenticationValidator to recover two /// factor for a user. This allows users to be logged in after a successful recovery diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 0da565c4ba..16e298d177 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -865,39 +865,6 @@ public class UserService : UserManager, IUserService } } - /// - /// To be removed when the feature flag pm-17128-recovery-code-login is removed PM-18175. - /// - [Obsolete("Two Factor recovery is handled in the TwoFactorAuthenticationValidator.")] - public async Task RecoverTwoFactorAsync(string email, string secret, string recoveryCode) - { - var user = await _userRepository.GetByEmailAsync(email); - if (user == null) - { - // No user exists. Do we want to send an email telling them this in the future? - return false; - } - - if (!await VerifySecretAsync(user, secret)) - { - return false; - } - - if (!CoreHelpers.FixedTimeEquals(user.TwoFactorRecoveryCode, recoveryCode)) - { - return false; - } - - user.TwoFactorProviders = null; - user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); - await SaveUserAsync(user); - await _mailService.SendRecoverTwoFactorEmail(user.Email, DateTime.UtcNow, _currentContext.IpAddress); - await _eventService.LogUserEventAsync(user.Id, EventType.User_Recovered2fa); - await CheckPoliciesOnTwoFactorRemovalAsync(user); - - return true; - } - public async Task RecoverTwoFactorAsync(User user, string recoveryCode) { if (!CoreHelpers.FixedTimeEquals( From a180317509b4b200185d2ffbbfdccf9f49560a79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Garc=C3=ADa?= Date: Tue, 2 Sep 2025 18:30:53 +0200 Subject: [PATCH 03/13] [PM-25182] Improve swagger OperationIDs: Part 1 (#6229) * Improve swagger OperationIDs: Part 1 * Fix tests and fmt * Improve docs and add more tests * Fmt * Improve Swagger OperationIDs for Auth * Fix review feedback * Use generic getcustomattributes * Format * replace swaggerexclude by split+obsolete * Format * Some remaining excludes --- dev/generate_openapi_files.ps1 | 9 ++ .../Auth/Controllers/AccountsController.cs | 32 ++++++- .../Controllers/AuthRequestsController.cs | 2 +- .../Controllers/EmergencyAccessController.cs | 18 +++- .../Auth/Controllers/TwoFactorController.cs | 67 +++++++++++++-- src/Api/Controllers/CollectionsController.cs | 26 +++++- src/Api/Controllers/DevicesController.cs | 55 ++++++++++-- src/Api/Controllers/InfoController.cs | 8 +- src/Api/Controllers/SettingsController.cs | 8 +- .../Utilities/ServiceCollectionExtensions.cs | 10 +-- src/Identity/Controllers/InfoController.cs | 8 +- src/Identity/Startup.cs | 10 +-- .../Swagger/ActionNameOperationFilter.cs | 25 ++++++ ...heckDuplicateOperationIdsDocumentFilter.cs | 80 +++++++++++++++++ src/SharedWeb/Swagger/SwaggerGenOptionsExt.cs | 33 +++++++ .../AuthRequestsControllerTests.cs | 2 +- .../Controllers/DevicesControllerTests.cs | 4 +- .../Controllers/CollectionsControllerTests.cs | 4 +- .../ActionNameOperationFilterTest.cs | 67 +++++++++++++++ ...DuplicateOperationIdsDocumentFilterTest.cs | 84 ++++++++++++++++++ test/SharedWeb.Test/SharedWeb.Test.csproj | 1 + test/SharedWeb.Test/SwaggerDocUtil.cs | 85 +++++++++++++++++++ 22 files changed, 583 insertions(+), 55 deletions(-) create mode 100644 src/SharedWeb/Swagger/ActionNameOperationFilter.cs create mode 100644 src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs create mode 100644 src/SharedWeb/Swagger/SwaggerGenOptionsExt.cs create mode 100644 test/SharedWeb.Test/ActionNameOperationFilterTest.cs create mode 100644 test/SharedWeb.Test/CheckDuplicateOperationIdsDocumentFilterTest.cs create mode 100644 test/SharedWeb.Test/SwaggerDocUtil.cs diff --git a/dev/generate_openapi_files.ps1 b/dev/generate_openapi_files.ps1 index 02470a0b1d..9eca7dc734 100644 --- a/dev/generate_openapi_files.ps1 +++ b/dev/generate_openapi_files.ps1 @@ -11,9 +11,18 @@ dotnet tool restore Set-Location "./src/Identity" dotnet build dotnet swagger tofile --output "../../identity.json" --host "https://identity.bitwarden.com" "./bin/Debug/net8.0/Identity.dll" "v1" +if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE +} # Api internal & public Set-Location "../../src/Api" dotnet build dotnet swagger tofile --output "../../api.json" --host "https://api.bitwarden.com" "./bin/Debug/net8.0/Api.dll" "internal" +if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE +} dotnet swagger tofile --output "../../api.public.json" --host "https://api.bitwarden.com" "./bin/Debug/net8.0/Api.dll" "public" +if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE +} diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index f197f1270b..0bed7c29c4 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -344,7 +344,6 @@ public class AccountsController : Controller } [HttpPut("profile")] - [HttpPost("profile")] public async Task PutProfile([FromBody] UpdateProfileRequestModel model) { var user = await _userService.GetUserByPrincipalAsync(User); @@ -363,8 +362,14 @@ public class AccountsController : Controller return response; } + [HttpPost("profile")] + [Obsolete("This endpoint is deprecated. Use PUT /profile instead.")] + public async Task PostProfile([FromBody] UpdateProfileRequestModel model) + { + return await PutProfile(model); + } + [HttpPut("avatar")] - [HttpPost("avatar")] public async Task PutAvatar([FromBody] UpdateAvatarRequestModel model) { var user = await _userService.GetUserByPrincipalAsync(User); @@ -382,6 +387,13 @@ public class AccountsController : Controller return response; } + [HttpPost("avatar")] + [Obsolete("This endpoint is deprecated. Use PUT /avatar instead.")] + public async Task PostAvatar([FromBody] UpdateAvatarRequestModel model) + { + return await PutAvatar(model); + } + [HttpGet("revision-date")] public async Task GetAccountRevisionDate() { @@ -430,7 +442,6 @@ public class AccountsController : Controller } [HttpDelete] - [HttpPost("delete")] public async Task Delete([FromBody] SecretVerificationRequestModel model) { var user = await _userService.GetUserByPrincipalAsync(User); @@ -467,6 +478,13 @@ public class AccountsController : Controller throw new BadRequestException(ModelState); } + [HttpPost("delete")] + [Obsolete("This endpoint is deprecated. Use DELETE / instead.")] + public async Task PostDelete([FromBody] SecretVerificationRequestModel model) + { + await Delete(model); + } + [AllowAnonymous] [HttpPost("delete-recover")] public async Task PostDeleteRecover([FromBody] DeleteRecoverRequestModel model) @@ -638,7 +656,6 @@ public class AccountsController : Controller await _twoFactorEmailService.SendNewDeviceVerificationEmailAsync(user); } - [HttpPost("verify-devices")] [HttpPut("verify-devices")] public async Task SetUserVerifyDevicesAsync([FromBody] SetVerifyDevicesRequestModel request) { @@ -654,6 +671,13 @@ public class AccountsController : Controller await _userService.SaveUserAsync(user); } + [HttpPost("verify-devices")] + [Obsolete("This endpoint is deprecated. Use PUT /verify-devices instead.")] + public async Task PostSetUserVerifyDevicesAsync([FromBody] SetVerifyDevicesRequestModel request) + { + await SetUserVerifyDevicesAsync(request); + } + private async Task> GetOrganizationIdsClaimingUserAsync(Guid userId) { var organizationsClaimingUser = await _userService.GetOrganizationsClaimingUserAsync(userId); diff --git a/src/Api/Auth/Controllers/AuthRequestsController.cs b/src/Api/Auth/Controllers/AuthRequestsController.cs index 3f91bd6eea..c62b817905 100644 --- a/src/Api/Auth/Controllers/AuthRequestsController.cs +++ b/src/Api/Auth/Controllers/AuthRequestsController.cs @@ -31,7 +31,7 @@ public class AuthRequestsController( private readonly IAuthRequestService _authRequestService = authRequestService; [HttpGet("")] - public async Task> Get() + public async Task> GetAll() { var userId = _userService.GetProperUserId(User).Value; var authRequests = await _authRequestRepository.GetManyByUserIdAsync(userId); diff --git a/src/Api/Auth/Controllers/EmergencyAccessController.cs b/src/Api/Auth/Controllers/EmergencyAccessController.cs index 53b57fe685..b849dc3e07 100644 --- a/src/Api/Auth/Controllers/EmergencyAccessController.cs +++ b/src/Api/Auth/Controllers/EmergencyAccessController.cs @@ -79,7 +79,6 @@ public class EmergencyAccessController : Controller } [HttpPut("{id}")] - [HttpPost("{id}")] public async Task Put(Guid id, [FromBody] EmergencyAccessUpdateRequestModel model) { var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); @@ -92,14 +91,27 @@ public class EmergencyAccessController : Controller await _emergencyAccessService.SaveAsync(model.ToEmergencyAccess(emergencyAccess), user); } + [HttpPost("{id}")] + [Obsolete("This endpoint is deprecated. Use PUT /{id} instead.")] + public async Task Post(Guid id, [FromBody] EmergencyAccessUpdateRequestModel model) + { + await Put(id, model); + } + [HttpDelete("{id}")] - [HttpPost("{id}/delete")] public async Task Delete(Guid id) { var userId = _userService.GetProperUserId(User); await _emergencyAccessService.DeleteAsync(id, userId.Value); } + [HttpPost("{id}/delete")] + [Obsolete("This endpoint is deprecated. Use DELETE /{id} instead.")] + public async Task PostDelete(Guid id) + { + await Delete(id); + } + [HttpPost("invite")] public async Task Invite([FromBody] EmergencyAccessInviteRequestModel model) { @@ -136,7 +148,7 @@ public class EmergencyAccessController : Controller } [HttpPost("{id}/approve")] - public async Task Accept(Guid id) + public async Task Approve(Guid id) { var user = await _userService.GetUserByPrincipalAsync(User); await _emergencyAccessService.ApproveAsync(id, user); diff --git a/src/Api/Auth/Controllers/TwoFactorController.cs b/src/Api/Auth/Controllers/TwoFactorController.cs index 4155489daa..886ed2cd20 100644 --- a/src/Api/Auth/Controllers/TwoFactorController.cs +++ b/src/Api/Auth/Controllers/TwoFactorController.cs @@ -110,7 +110,6 @@ public class TwoFactorController : Controller } [HttpPut("authenticator")] - [HttpPost("authenticator")] public async Task PutAuthenticator( [FromBody] UpdateTwoFactorAuthenticatorRequestModel model) { @@ -133,6 +132,14 @@ public class TwoFactorController : Controller return response; } + [HttpPost("authenticator")] + [Obsolete("This endpoint is deprecated. Use PUT /authenticator instead.")] + public async Task PostAuthenticator( + [FromBody] UpdateTwoFactorAuthenticatorRequestModel model) + { + return await PutAuthenticator(model); + } + [HttpDelete("authenticator")] public async Task DisableAuthenticator( [FromBody] TwoFactorAuthenticatorDisableRequestModel model) @@ -157,7 +164,6 @@ public class TwoFactorController : Controller } [HttpPut("yubikey")] - [HttpPost("yubikey")] public async Task PutYubiKey([FromBody] UpdateTwoFactorYubicoOtpRequestModel model) { var user = await CheckAsync(model, true); @@ -174,6 +180,13 @@ public class TwoFactorController : Controller return response; } + [HttpPost("yubikey")] + [Obsolete("This endpoint is deprecated. Use PUT /yubikey instead.")] + public async Task PostYubiKey([FromBody] UpdateTwoFactorYubicoOtpRequestModel model) + { + return await PutYubiKey(model); + } + [HttpPost("get-duo")] public async Task GetDuo([FromBody] SecretVerificationRequestModel model) { @@ -183,7 +196,6 @@ public class TwoFactorController : Controller } [HttpPut("duo")] - [HttpPost("duo")] public async Task PutDuo([FromBody] UpdateTwoFactorDuoRequestModel model) { var user = await CheckAsync(model, true); @@ -199,6 +211,13 @@ public class TwoFactorController : Controller return response; } + [HttpPost("duo")] + [Obsolete("This endpoint is deprecated. Use PUT /duo instead.")] + public async Task PostDuo([FromBody] UpdateTwoFactorDuoRequestModel model) + { + return await PutDuo(model); + } + [HttpPost("~/organizations/{id}/two-factor/get-duo")] public async Task GetOrganizationDuo(string id, [FromBody] SecretVerificationRequestModel model) @@ -217,7 +236,6 @@ public class TwoFactorController : Controller } [HttpPut("~/organizations/{id}/two-factor/duo")] - [HttpPost("~/organizations/{id}/two-factor/duo")] public async Task PutOrganizationDuo(string id, [FromBody] UpdateTwoFactorDuoRequestModel model) { @@ -243,6 +261,14 @@ public class TwoFactorController : Controller return response; } + [HttpPost("~/organizations/{id}/two-factor/duo")] + [Obsolete("This endpoint is deprecated. Use PUT /organizations/{id}/two-factor/duo instead.")] + public async Task PostOrganizationDuo(string id, + [FromBody] UpdateTwoFactorDuoRequestModel model) + { + return await PutOrganizationDuo(id, model); + } + [HttpPost("get-webauthn")] public async Task GetWebAuthn([FromBody] SecretVerificationRequestModel model) { @@ -261,7 +287,6 @@ public class TwoFactorController : Controller } [HttpPut("webauthn")] - [HttpPost("webauthn")] public async Task PutWebAuthn([FromBody] TwoFactorWebAuthnRequestModel model) { var user = await CheckAsync(model, false); @@ -277,6 +302,13 @@ public class TwoFactorController : Controller return response; } + [HttpPost("webauthn")] + [Obsolete("This endpoint is deprecated. Use PUT /webauthn instead.")] + public async Task PostWebAuthn([FromBody] TwoFactorWebAuthnRequestModel model) + { + return await PutWebAuthn(model); + } + [HttpDelete("webauthn")] public async Task DeleteWebAuthn( [FromBody] TwoFactorWebAuthnDeleteRequestModel model) @@ -349,7 +381,6 @@ public class TwoFactorController : Controller } [HttpPut("email")] - [HttpPost("email")] public async Task PutEmail([FromBody] UpdateTwoFactorEmailRequestModel model) { var user = await CheckAsync(model, false); @@ -367,8 +398,14 @@ public class TwoFactorController : Controller return response; } + [HttpPost("email")] + [Obsolete("This endpoint is deprecated. Use PUT /email instead.")] + public async Task PostEmail([FromBody] UpdateTwoFactorEmailRequestModel model) + { + return await PutEmail(model); + } + [HttpPut("disable")] - [HttpPost("disable")] public async Task PutDisable([FromBody] TwoFactorProviderRequestModel model) { var user = await CheckAsync(model, false); @@ -377,8 +414,14 @@ public class TwoFactorController : Controller return response; } + [HttpPost("disable")] + [Obsolete("This endpoint is deprecated. Use PUT /disable instead.")] + public async Task PostDisable([FromBody] TwoFactorProviderRequestModel model) + { + return await PutDisable(model); + } + [HttpPut("~/organizations/{id}/two-factor/disable")] - [HttpPost("~/organizations/{id}/two-factor/disable")] public async Task PutOrganizationDisable(string id, [FromBody] TwoFactorProviderRequestModel model) { @@ -401,6 +444,14 @@ public class TwoFactorController : Controller return response; } + [HttpPost("~/organizations/{id}/two-factor/disable")] + [Obsolete("This endpoint is deprecated. Use PUT /organizations/{id}/two-factor/disable instead.")] + public async Task PostOrganizationDisable(string id, + [FromBody] TwoFactorProviderRequestModel model) + { + return await PutOrganizationDisable(id, model); + } + [HttpPost("get-recover")] public async Task GetRecover([FromBody] SecretVerificationRequestModel model) { diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index 6d4e9c9fea..f037ab7034 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -102,7 +102,7 @@ public class CollectionsController : Controller } [HttpGet("")] - public async Task> Get(Guid orgId) + public async Task> GetAll(Guid orgId) { IEnumerable orgCollections; @@ -173,7 +173,6 @@ public class CollectionsController : Controller } [HttpPut("{id}")] - [HttpPost("{id}")] public async Task Put(Guid orgId, Guid id, [FromBody] UpdateCollectionRequestModel model) { var collection = await _collectionRepository.GetByIdAsync(id); @@ -198,6 +197,13 @@ public class CollectionsController : Controller return new CollectionAccessDetailsResponseModel(collectionWithPermissions); } + [HttpPost("{id}")] + [Obsolete("This endpoint is deprecated. Use PUT /{id} instead.")] + public async Task Post(Guid orgId, Guid id, [FromBody] UpdateCollectionRequestModel model) + { + return await Put(orgId, id, model); + } + [HttpPost("bulk-access")] public async Task PostBulkCollectionAccess(Guid orgId, [FromBody] BulkCollectionAccessRequestModel model) { @@ -222,7 +228,6 @@ public class CollectionsController : Controller } [HttpDelete("{id}")] - [HttpPost("{id}/delete")] public async Task Delete(Guid orgId, Guid id) { var collection = await _collectionRepository.GetByIdAsync(id); @@ -235,8 +240,14 @@ public class CollectionsController : Controller await _deleteCollectionCommand.DeleteAsync(collection); } + [HttpPost("{id}/delete")] + [Obsolete("This endpoint is deprecated. Use DELETE /{id} instead.")] + public async Task PostDelete(Guid orgId, Guid id) + { + await Delete(orgId, id); + } + [HttpDelete("")] - [HttpPost("delete")] public async Task DeleteMany(Guid orgId, [FromBody] CollectionBulkDeleteRequestModel model) { var collections = await _collectionRepository.GetManyByManyIdsAsync(model.Ids); @@ -248,4 +259,11 @@ public class CollectionsController : Controller await _deleteCollectionCommand.DeleteManyAsync(collections); } + + [HttpPost("delete")] + [Obsolete("This endpoint is deprecated. Use DELETE / instead.")] + public async Task PostDeleteMany(Guid orgId, [FromBody] CollectionBulkDeleteRequestModel model) + { + await DeleteMany(orgId, model); + } } diff --git a/src/Api/Controllers/DevicesController.cs b/src/Api/Controllers/DevicesController.cs index 07e8552268..1f2cda9cc4 100644 --- a/src/Api/Controllers/DevicesController.cs +++ b/src/Api/Controllers/DevicesController.cs @@ -75,7 +75,7 @@ public class DevicesController : Controller } [HttpGet("")] - public async Task> Get() + public async Task> GetAll() { var devicesWithPendingAuthData = await _deviceRepository.GetManyByUserIdWithDeviceAuth(_userService.GetProperUserId(User).Value); @@ -99,7 +99,6 @@ public class DevicesController : Controller } [HttpPut("{id}")] - [HttpPost("{id}")] public async Task Put(string id, [FromBody] DeviceRequestModel model) { var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); @@ -114,8 +113,14 @@ public class DevicesController : Controller return response; } + [HttpPost("{id}")] + [Obsolete("This endpoint is deprecated. Use PUT /{id} instead.")] + public async Task Post(string id, [FromBody] DeviceRequestModel model) + { + return await Put(id, model); + } + [HttpPut("{identifier}/keys")] - [HttpPost("{identifier}/keys")] public async Task PutKeys(string identifier, [FromBody] DeviceKeysRequestModel model) { var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); @@ -130,6 +135,13 @@ public class DevicesController : Controller return response; } + [HttpPost("{identifier}/keys")] + [Obsolete("This endpoint is deprecated. Use PUT /{identifier}/keys instead.")] + public async Task PostKeys(string identifier, [FromBody] DeviceKeysRequestModel model) + { + return await PutKeys(identifier, model); + } + [HttpPost("{identifier}/retrieve-keys")] [Obsolete("This endpoint is deprecated. The keys are on the regular device GET endpoints now.")] public async Task GetDeviceKeys(string identifier) @@ -187,7 +199,6 @@ public class DevicesController : Controller } [HttpPut("identifier/{identifier}/token")] - [HttpPost("identifier/{identifier}/token")] public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model) { var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); @@ -199,8 +210,14 @@ public class DevicesController : Controller await _deviceService.SaveAsync(model.ToDevice(device)); } + [HttpPost("identifier/{identifier}/token")] + [Obsolete("This endpoint is deprecated. Use PUT /identifier/{identifier}/token instead.")] + public async Task PostToken(string identifier, [FromBody] DeviceTokenRequestModel model) + { + await PutToken(identifier, model); + } + [HttpPut("identifier/{identifier}/web-push-auth")] - [HttpPost("identifier/{identifier}/web-push-auth")] public async Task PutWebPushAuth(string identifier, [FromBody] WebPushAuthRequestModel model) { var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); @@ -216,9 +233,15 @@ public class DevicesController : Controller ); } + [HttpPost("identifier/{identifier}/web-push-auth")] + [Obsolete("This endpoint is deprecated. Use PUT /identifier/{identifier}/web-push-auth instead.")] + public async Task PostWebPushAuth(string identifier, [FromBody] WebPushAuthRequestModel model) + { + await PutWebPushAuth(identifier, model); + } + [AllowAnonymous] [HttpPut("identifier/{identifier}/clear-token")] - [HttpPost("identifier/{identifier}/clear-token")] public async Task PutClearToken(string identifier) { var device = await _deviceRepository.GetByIdentifierAsync(identifier); @@ -230,8 +253,15 @@ public class DevicesController : Controller await _deviceService.ClearTokenAsync(device); } + [AllowAnonymous] + [HttpPost("identifier/{identifier}/clear-token")] + [Obsolete("This endpoint is deprecated. Use PUT /identifier/{identifier}/clear-token instead.")] + public async Task PostClearToken(string identifier) + { + await PutClearToken(identifier); + } + [HttpDelete("{id}")] - [HttpPost("{id}/deactivate")] public async Task Deactivate(string id) { var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); @@ -243,17 +273,24 @@ public class DevicesController : Controller await _deviceService.DeactivateAsync(device); } + [HttpPost("{id}/deactivate")] + [Obsolete("This endpoint is deprecated. Use DELETE /{id} instead.")] + public async Task PostDeactivate(string id) + { + await Deactivate(id); + } + [AllowAnonymous] [HttpGet("knowndevice")] public async Task GetByIdentifierQuery( [Required][FromHeader(Name = "X-Request-Email")] string Email, [Required][FromHeader(Name = "X-Device-Identifier")] string DeviceIdentifier) - => await GetByIdentifier(CoreHelpers.Base64UrlDecodeString(Email), DeviceIdentifier); + => await GetByEmailAndIdentifier(CoreHelpers.Base64UrlDecodeString(Email), DeviceIdentifier); [Obsolete("Path is deprecated due to encoding issues, use /knowndevice instead.")] [AllowAnonymous] [HttpGet("knowndevice/{email}/{identifier}")] - public async Task GetByIdentifier(string email, string identifier) + public async Task GetByEmailAndIdentifier(string email, string identifier) { if (string.IsNullOrWhiteSpace(email) || string.IsNullOrWhiteSpace(identifier)) { diff --git a/src/Api/Controllers/InfoController.cs b/src/Api/Controllers/InfoController.cs index edfd18c79e..590a3006c0 100644 --- a/src/Api/Controllers/InfoController.cs +++ b/src/Api/Controllers/InfoController.cs @@ -6,12 +6,18 @@ namespace Bit.Api.Controllers; public class InfoController : Controller { [HttpGet("~/alive")] - [HttpGet("~/now")] public DateTime GetAlive() { return DateTime.UtcNow; } + [HttpGet("~/now")] + [Obsolete("This endpoint is deprecated. Use GET /alive instead.")] + public DateTime GetNow() + { + return GetAlive(); + } + [HttpGet("~/version")] public JsonResult GetVersion() { diff --git a/src/Api/Controllers/SettingsController.cs b/src/Api/Controllers/SettingsController.cs index 8489b137e8..e872eeeeac 100644 --- a/src/Api/Controllers/SettingsController.cs +++ b/src/Api/Controllers/SettingsController.cs @@ -32,7 +32,6 @@ public class SettingsController : Controller } [HttpPut("domains")] - [HttpPost("domains")] public async Task PutDomains([FromBody] UpdateDomainsRequestModel model) { var user = await _userService.GetUserByPrincipalAsync(User); @@ -46,4 +45,11 @@ public class SettingsController : Controller var response = new DomainsResponseModel(user); return response; } + + [HttpPost("domains")] + [Obsolete("This endpoint is deprecated. Use PUT /domains instead.")] + public async Task PostDomains([FromBody] UpdateDomainsRequestModel model) + { + return await PutDomains(model); + } } diff --git a/src/Api/Utilities/ServiceCollectionExtensions.cs b/src/Api/Utilities/ServiceCollectionExtensions.cs index aa2710c42a..0d8c3dec38 100644 --- a/src/Api/Utilities/ServiceCollectionExtensions.cs +++ b/src/Api/Utilities/ServiceCollectionExtensions.cs @@ -82,15 +82,7 @@ public static class ServiceCollectionExtensions config.DescribeAllParametersInCamelCase(); // config.UseReferencedDefinitionsForEnums(); - config.SchemaFilter(); - config.SchemaFilter(); - - // These two filters require debug symbols/git, so only add them in development mode - if (environment.IsDevelopment()) - { - config.DocumentFilter(); - config.OperationFilter(); - } + config.InitializeSwaggerFilters(environment); var apiFilePath = Path.Combine(AppContext.BaseDirectory, "Api.xml"); config.IncludeXmlComments(apiFilePath, true); diff --git a/src/Identity/Controllers/InfoController.cs b/src/Identity/Controllers/InfoController.cs index 05cf3f2363..79dfd99c44 100644 --- a/src/Identity/Controllers/InfoController.cs +++ b/src/Identity/Controllers/InfoController.cs @@ -6,12 +6,18 @@ namespace Bit.Identity.Controllers; public class InfoController : Controller { [HttpGet("~/alive")] - [HttpGet("~/now")] public DateTime GetAlive() { return DateTime.UtcNow; } + [HttpGet("~/now")] + [Obsolete("This endpoint is deprecated. Use GET /alive instead.")] + public DateTime GetNow() + { + return GetAlive(); + } + [HttpGet("~/version")] public JsonResult GetVersion() { diff --git a/src/Identity/Startup.cs b/src/Identity/Startup.cs index ae628197e8..8da31d87d6 100644 --- a/src/Identity/Startup.cs +++ b/src/Identity/Startup.cs @@ -66,15 +66,7 @@ public class Startup services.AddSwaggerGen(config => { - config.SchemaFilter(); - config.SchemaFilter(); - - // These two filters require debug symbols/git, so only add them in development mode - if (Environment.IsDevelopment()) - { - config.DocumentFilter(); - config.OperationFilter(); - } + config.InitializeSwaggerFilters(Environment); config.SwaggerDoc("v1", new OpenApiInfo { Title = "Bitwarden Identity", Version = "v1" }); }); diff --git a/src/SharedWeb/Swagger/ActionNameOperationFilter.cs b/src/SharedWeb/Swagger/ActionNameOperationFilter.cs new file mode 100644 index 0000000000..b76e8864ba --- /dev/null +++ b/src/SharedWeb/Swagger/ActionNameOperationFilter.cs @@ -0,0 +1,25 @@ +using System.Text.Json; +using Microsoft.OpenApi.Any; +using Microsoft.OpenApi.Models; +using Swashbuckle.AspNetCore.SwaggerGen; + +namespace Bit.SharedWeb.Swagger; + +/// +/// Adds the action name (function name) as an extension to each operation in the Swagger document. +/// This can be useful for the code generation process, to generate more meaningful names for operations. +/// Note that we add both the original action name and a snake_case version, as the codegen templates +/// cannot do case conversions. +/// +public class ActionNameOperationFilter : IOperationFilter +{ + public void Apply(OpenApiOperation operation, OperationFilterContext context) + { + if (!context.ApiDescription.ActionDescriptor.RouteValues.TryGetValue("action", out var action)) return; + if (string.IsNullOrEmpty(action)) return; + + operation.Extensions.Add("x-action-name", new OpenApiString(action)); + // We can't do case changes in the codegen templates, so we also add the snake_case version of the action name + operation.Extensions.Add("x-action-name-snake-case", new OpenApiString(JsonNamingPolicy.SnakeCaseLower.ConvertName(action))); + } +} diff --git a/src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs b/src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs new file mode 100644 index 0000000000..3079a9171a --- /dev/null +++ b/src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs @@ -0,0 +1,80 @@ +using Microsoft.OpenApi.Models; +using Swashbuckle.AspNetCore.SwaggerGen; + +namespace Bit.SharedWeb.Swagger; + +/// +/// Checks for duplicate operation IDs in the Swagger document, and throws an error if any are found. +/// Operation IDs must be unique across the entire Swagger document according to the OpenAPI specification, +/// but we use controller action names to generate them, which can lead to duplicates if a Controller function +/// has multiple HTTP methods or if a Controller has overloaded functions. +/// +public class CheckDuplicateOperationIdsDocumentFilter(bool printDuplicates = true) : IDocumentFilter +{ + public bool PrintDuplicates { get; } = printDuplicates; + + public void Apply(OpenApiDocument swaggerDoc, DocumentFilterContext context) + { + var operationIdMap = new Dictionary>(); + + foreach (var (path, pathItem) in swaggerDoc.Paths) + { + foreach (var operation in pathItem.Operations) + { + if (!operationIdMap.TryGetValue(operation.Value.OperationId, out var list)) + { + list = []; + operationIdMap[operation.Value.OperationId] = list; + } + + list.Add((path, pathItem, operation.Key, operation.Value)); + + } + } + + // Find duplicates + var duplicates = operationIdMap.Where((kvp) => kvp.Value.Count > 1).ToList(); + if (duplicates.Count > 0) + { + if (PrintDuplicates) + { + Console.WriteLine($"\n######## Duplicate operationIds found in the schema ({duplicates.Count} found) ########\n"); + + Console.WriteLine("## Common causes of duplicate operation IDs:"); + Console.WriteLine("- Multiple HTTP methods (GET, POST, etc.) on the same controller function"); + Console.WriteLine(" Solution: Split the methods into separate functions, and if appropiate, mark the deprecated ones with [Obsolete]"); + Console.WriteLine(); + Console.WriteLine("- Overloaded controller functions with the same name"); + Console.WriteLine(" Solution: Rename the overloaded functions to have unique names, or combine them into a single function with optional parameters"); + Console.WriteLine(); + + Console.WriteLine("## The duplicate operation IDs are:"); + + foreach (var (operationId, duplicate) in duplicates) + { + Console.WriteLine($"- operationId: {operationId}"); + foreach (var (path, pathItem, method, operation) in duplicate) + { + Console.Write($" {method.ToString().ToUpper()} {path}"); + + + if (operation.Extensions.TryGetValue("x-source-file", out var sourceFile) && operation.Extensions.TryGetValue("x-source-line", out var sourceLine)) + { + var sourceFileString = ((Microsoft.OpenApi.Any.OpenApiString)sourceFile).Value; + var sourceLineString = ((Microsoft.OpenApi.Any.OpenApiInteger)sourceLine).Value; + + Console.WriteLine($" {sourceFileString}:{sourceLineString}"); + } + else + { + Console.WriteLine(); + } + } + Console.WriteLine("\n"); + } + } + + throw new InvalidOperationException($"Duplicate operation IDs found in Swagger schema"); + } + } +} diff --git a/src/SharedWeb/Swagger/SwaggerGenOptionsExt.cs b/src/SharedWeb/Swagger/SwaggerGenOptionsExt.cs new file mode 100644 index 0000000000..60803705d6 --- /dev/null +++ b/src/SharedWeb/Swagger/SwaggerGenOptionsExt.cs @@ -0,0 +1,33 @@ +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Swashbuckle.AspNetCore.SwaggerGen; + +namespace Bit.SharedWeb.Swagger; + +public static class SwaggerGenOptionsExt +{ + + public static void InitializeSwaggerFilters( + this SwaggerGenOptions config, IWebHostEnvironment environment) + { + config.SchemaFilter(); + config.SchemaFilter(); + + config.OperationFilter(); + + // Set the operation ID to the name of the controller followed by the name of the function. + // Note that the "Controller" suffix for the controllers, and the "Async" suffix for the actions + // are removed already, so we don't need to do that ourselves. + // TODO(Dani): This is disabled until we remove all the duplicate operation IDs. + // config.CustomOperationIds(e => $"{e.ActionDescriptor.RouteValues["controller"]}_{e.ActionDescriptor.RouteValues["action"]}"); + // config.DocumentFilter(); + + // These two filters require debug symbols/git, so only add them in development mode + if (environment.IsDevelopment()) + { + config.DocumentFilter(); + config.OperationFilter(); + } + } +} diff --git a/test/Api.Test/Auth/Controllers/AuthRequestsControllerTests.cs b/test/Api.Test/Auth/Controllers/AuthRequestsControllerTests.cs index 828911f6bd..1b8e7aba8e 100644 --- a/test/Api.Test/Auth/Controllers/AuthRequestsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AuthRequestsControllerTests.cs @@ -43,7 +43,7 @@ public class AuthRequestsControllerTests .Returns([authRequest]); // Act - var result = await sutProvider.Sut.Get(); + var result = await sutProvider.Sut.GetAll(); // Assert Assert.NotNull(result); diff --git a/test/Api.Test/Auth/Controllers/DevicesControllerTests.cs b/test/Api.Test/Auth/Controllers/DevicesControllerTests.cs index 540d23f98b..bed483f83a 100644 --- a/test/Api.Test/Auth/Controllers/DevicesControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/DevicesControllerTests.cs @@ -73,7 +73,7 @@ public class DevicesControllerTest _deviceRepositoryMock.GetManyByUserIdWithDeviceAuth(userId).Returns(devicesWithPendingAuthData); // Act - var result = await _sut.Get(); + var result = await _sut.GetAll(); // Assert Assert.NotNull(result); @@ -94,6 +94,6 @@ public class DevicesControllerTest _userServiceMock.GetProperUserId(Arg.Any()).Returns((Guid?)null); // Act & Assert - await Assert.ThrowsAsync(() => _sut.Get()); + await Assert.ThrowsAsync(() => _sut.GetAll()); } } diff --git a/test/Api.Test/Controllers/CollectionsControllerTests.cs b/test/Api.Test/Controllers/CollectionsControllerTests.cs index a3d34efb63..33b7e20327 100644 --- a/test/Api.Test/Controllers/CollectionsControllerTests.cs +++ b/test/Api.Test/Controllers/CollectionsControllerTests.cs @@ -177,7 +177,7 @@ public class CollectionsControllerTests .GetManySharedCollectionsByOrganizationIdAsync(organization.Id) .Returns(collections); - var response = await sutProvider.Sut.Get(organization.Id); + var response = await sutProvider.Sut.GetAll(organization.Id); await sutProvider.GetDependency().Received(1).GetManySharedCollectionsByOrganizationIdAsync(organization.Id); @@ -219,7 +219,7 @@ public class CollectionsControllerTests .GetManyByUserIdAsync(userId) .Returns(collections); - var result = await sutProvider.Sut.Get(organization.Id); + var result = await sutProvider.Sut.GetAll(organization.Id); await sutProvider.GetDependency().DidNotReceive().GetManyByOrganizationIdAsync(organization.Id); await sutProvider.GetDependency().Received(1).GetManyByUserIdAsync(userId); diff --git a/test/SharedWeb.Test/ActionNameOperationFilterTest.cs b/test/SharedWeb.Test/ActionNameOperationFilterTest.cs new file mode 100644 index 0000000000..c798adea8c --- /dev/null +++ b/test/SharedWeb.Test/ActionNameOperationFilterTest.cs @@ -0,0 +1,67 @@ +using Bit.SharedWeb.Swagger; +using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.ApiExplorer; +using Microsoft.OpenApi.Any; +using Microsoft.OpenApi.Models; +using Swashbuckle.AspNetCore.SwaggerGen; + +namespace SharedWeb.Test; + +public class ActionNameOperationFilterTest +{ + [Fact] + public void WithValidActionNameAddsActionNameExtensions() + { + // Arrange + var operation = new OpenApiOperation(); + var actionDescriptor = new ActionDescriptor(); + actionDescriptor.RouteValues["action"] = "GetUsers"; + + var apiDescription = new ApiDescription + { + ActionDescriptor = actionDescriptor + }; + + var context = new OperationFilterContext(apiDescription, null, null, null); + var filter = new ActionNameOperationFilter(); + + // Act + filter.Apply(operation, context); + + // Assert + Assert.True(operation.Extensions.ContainsKey("x-action-name")); + Assert.True(operation.Extensions.ContainsKey("x-action-name-snake-case")); + + var actionNameExt = operation.Extensions["x-action-name"] as OpenApiString; + var actionNameSnakeCaseExt = operation.Extensions["x-action-name-snake-case"] as OpenApiString; + + Assert.NotNull(actionNameExt); + Assert.NotNull(actionNameSnakeCaseExt); + Assert.Equal("GetUsers", actionNameExt.Value); + Assert.Equal("get_users", actionNameSnakeCaseExt.Value); + } + + [Fact] + public void WithMissingActionRouteValueDoesNotAddExtensions() + { + // Arrange + var operation = new OpenApiOperation(); + var actionDescriptor = new ActionDescriptor(); + // Not setting the "action" route value at all + + var apiDescription = new ApiDescription + { + ActionDescriptor = actionDescriptor + }; + + var context = new OperationFilterContext(apiDescription, null, null, null); + var filter = new ActionNameOperationFilter(); + + // Act + filter.Apply(operation, context); + + // Assert + Assert.False(operation.Extensions.ContainsKey("x-action-name")); + Assert.False(operation.Extensions.ContainsKey("x-action-name-snake-case")); + } +} diff --git a/test/SharedWeb.Test/CheckDuplicateOperationIdsDocumentFilterTest.cs b/test/SharedWeb.Test/CheckDuplicateOperationIdsDocumentFilterTest.cs new file mode 100644 index 0000000000..7b7c5771d3 --- /dev/null +++ b/test/SharedWeb.Test/CheckDuplicateOperationIdsDocumentFilterTest.cs @@ -0,0 +1,84 @@ +using Bit.SharedWeb.Swagger; +using Microsoft.AspNetCore.Mvc; +using Microsoft.OpenApi.Models; +using Swashbuckle.AspNetCore.SwaggerGen; + +namespace SharedWeb.Test; + +public class UniqueOperationIdsController : ControllerBase +{ + [HttpGet("unique-get")] + public void UniqueGetAction() { } + + [HttpPost("unique-post")] + public void UniquePostAction() { } +} + +public class OverloadedOperationIdsController : ControllerBase +{ + [HttpPut("another-duplicate")] + public void AnotherDuplicateAction() { } + + [HttpPatch("another-duplicate/{id}")] + public void AnotherDuplicateAction(int id) { } +} + +public class MultipleHttpMethodsController : ControllerBase +{ + [HttpGet("multi-method")] + [HttpPost("multi-method")] + [HttpPut("multi-method")] + public void MultiMethodAction() { } +} + +public class CheckDuplicateOperationIdsDocumentFilterTest +{ + [Fact] + public void UniqueOperationIdsDoNotThrowException() + { + // Arrange + var (swaggerDoc, context) = SwaggerDocUtil.CreateDocFromControllers(typeof(UniqueOperationIdsController)); + var filter = new CheckDuplicateOperationIdsDocumentFilter(); + filter.Apply(swaggerDoc, context); + // Act & Assert + var exception = Record.Exception(() => filter.Apply(swaggerDoc, context)); + Assert.Null(exception); + } + + [Fact] + public void DuplicateOperationIdsThrowInvalidOperationException() + { + // Arrange + var (swaggerDoc, context) = SwaggerDocUtil.CreateDocFromControllers(typeof(OverloadedOperationIdsController)); + var filter = new CheckDuplicateOperationIdsDocumentFilter(false); + + // Act & Assert + var exception = Assert.Throws(() => filter.Apply(swaggerDoc, context)); + Assert.Contains("Duplicate operation IDs found in Swagger schema", exception.Message); + } + + [Fact] + public void MultipleHttpMethodsThrowInvalidOperationException() + { + // Arrange + var (swaggerDoc, context) = SwaggerDocUtil.CreateDocFromControllers(typeof(MultipleHttpMethodsController)); + var filter = new CheckDuplicateOperationIdsDocumentFilter(false); + + // Act & Assert + var exception = Assert.Throws(() => filter.Apply(swaggerDoc, context)); + Assert.Contains("Duplicate operation IDs found in Swagger schema", exception.Message); + } + + [Fact] + public void EmptySwaggerDocDoesNotThrowException() + { + // Arrange + var swaggerDoc = new OpenApiDocument { Paths = [] }; + var context = new DocumentFilterContext([], null, null); + var filter = new CheckDuplicateOperationIdsDocumentFilter(false); + + // Act & Assert + var exception = Record.Exception(() => filter.Apply(swaggerDoc, context)); + Assert.Null(exception); + } +} diff --git a/test/SharedWeb.Test/SharedWeb.Test.csproj b/test/SharedWeb.Test/SharedWeb.Test.csproj index 8ae7a56a99..c631ac9227 100644 --- a/test/SharedWeb.Test/SharedWeb.Test.csproj +++ b/test/SharedWeb.Test/SharedWeb.Test.csproj @@ -9,6 +9,7 @@ all + diff --git a/test/SharedWeb.Test/SwaggerDocUtil.cs b/test/SharedWeb.Test/SwaggerDocUtil.cs new file mode 100644 index 0000000000..45a3033dec --- /dev/null +++ b/test/SharedWeb.Test/SwaggerDocUtil.cs @@ -0,0 +1,85 @@ +using System.Reflection; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ApiExplorer; +using Microsoft.AspNetCore.Mvc.ApplicationParts; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.OpenApi.Models; +using NSubstitute; +using Swashbuckle.AspNetCore.Swagger; +using Swashbuckle.AspNetCore.SwaggerGen; + +namespace SharedWeb.Test; + +public class SwaggerDocUtil +{ + /// + /// Creates an OpenApiDocument and DocumentFilterContext from the specified controller type by setting up + /// a minimal service collection and using the SwaggerProvider to generate the document. + /// + public static (OpenApiDocument, DocumentFilterContext) CreateDocFromControllers(params Type[] controllerTypes) + { + if (controllerTypes.Length == 0) + { + throw new ArgumentException("At least one controller type must be provided", nameof(controllerTypes)); + } + + var services = new ServiceCollection(); + services.AddLogging(); + services.AddSingleton(Substitute.For()); + services.AddControllers() + .ConfigureApplicationPartManager(manager => + { + // Clear existing parts and feature providers + manager.ApplicationParts.Clear(); + manager.FeatureProviders.Clear(); + + // Add a custom feature provider that only includes the specific controller types + manager.FeatureProviders.Add(new MultipleControllerFeatureProvider(controllerTypes)); + + // Add assembly parts for all unique assemblies containing the controllers + foreach (var assembly in controllerTypes.Select(t => t.Assembly).Distinct()) + { + manager.ApplicationParts.Add(new AssemblyPart(assembly)); + } + }); + services.AddSwaggerGen(config => + { + config.SwaggerDoc("v1", new OpenApiInfo { Title = "Test API", Version = "v1" }); + config.CustomOperationIds(e => $"{e.ActionDescriptor.RouteValues["controller"]}_{e.ActionDescriptor.RouteValues["action"]}"); + }); + var serviceProvider = services.BuildServiceProvider(); + + // Get API descriptions + var allApiDescriptions = serviceProvider.GetRequiredService() + .ApiDescriptionGroups.Items + .SelectMany(group => group.Items) + .ToList(); + + if (allApiDescriptions.Count == 0) + { + throw new InvalidOperationException("No API descriptions found for controller, ensure your controllers are defined correctly (public, not nested, inherit from ControllerBase, etc.)"); + } + + // Generate the swagger document and context + var document = serviceProvider.GetRequiredService().GetSwagger("v1"); + var schemaGenerator = serviceProvider.GetRequiredService(); + var context = new DocumentFilterContext(allApiDescriptions, schemaGenerator, new SchemaRepository()); + + return (document, context); + } +} + +public class MultipleControllerFeatureProvider(params Type[] controllerTypes) : ControllerFeatureProvider +{ + private readonly HashSet _allowedControllerTypes = [.. controllerTypes]; + + protected override bool IsController(TypeInfo typeInfo) + { + return _allowedControllerTypes.Contains(typeInfo.AsType()) + && typeInfo.IsClass + && !typeInfo.IsAbstract + && typeof(ControllerBase).IsAssignableFrom(typeInfo); + } +} From 53e5ddb1a719aa4ca23004196b569c12c0ac6722 Mon Sep 17 00:00:00 2001 From: Patrick-Pimentel-Bitwarden Date: Tue, 2 Sep 2025 12:44:28 -0400 Subject: [PATCH 04/13] fix(inactive-user-server-notification): [PM-25130] Inactive User Server Notify - Added feature flag. (#6270) --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 352daee862..ce18706bd4 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -187,6 +187,7 @@ public static class FeatureFlagKeys public const string PersistPopupView = "persist-popup-view"; public const string IpcChannelFramework = "ipc-channel-framework"; public const string PushNotificationsWhenLocked = "pm-19388-push-notifications-when-locked"; + public const string PushNotificationsWhenInactive = "pm-25130-receive-push-notifications-for-inactive-users"; /* Tools Team */ public const string DesktopSendUIRefresh = "desktop-send-ui-refresh"; From a5bed5dcaab3aba3aba588531561add60927b273 Mon Sep 17 00:00:00 2001 From: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> Date: Tue, 2 Sep 2025 15:02:02 -0500 Subject: [PATCH 05/13] [PM-25384] Add feature flag (#6271) --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index ce18706bd4..393ab15e4c 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -163,6 +163,7 @@ public static class FeatureFlagKeys public const string UserSdkForDecryption = "use-sdk-for-decryption"; public const string PM17987_BlockType0 = "pm-17987-block-type-0"; public const string ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings"; + public const string UnlockWithMasterPasswordUnlockData = "pm-23246-unlock-with-master-password-unlock-data"; /* Mobile Team */ public const string NativeCarouselFlow = "native-carousel-flow"; From d2d3e0f11b6950b172dd3e16ac1f47f310e9ec50 Mon Sep 17 00:00:00 2001 From: Ike <137194738+ike-kottlowski@users.noreply.github.com> Date: Tue, 2 Sep 2025 16:48:57 -0400 Subject: [PATCH 06/13] [PM-22678] Send email otp authentication method (#6255) feat(auth): email OTP validation, and generalize authentication interface - Generalized send authentication method interface - Made validate method async - Added email mail support for Handlebars - Modified email templates to match future implementation fix(auth): update constants, naming conventions, and error handling - Renamed constants for clarity - Updated claims naming convention - Fixed error message generation - Added customResponse for Rust consumption test(auth): add and fix tests for validators and email - Added tests for SendEmailOtpRequestValidator - Updated tests for SendAccessGrantValidator chore: apply dotnet formatting --- .../SendAccessClaimsPrincipalExtensions.cs | 6 +- src/Core/Identity/Claims.cs | 7 +- .../Auth/SendAccessEmailOtpEmail.html.hbs | 28 ++ .../Auth/SendAccessEmailOtpEmail.text.hbs | 9 + .../Mail/Auth/DefaultEmailOtpViewModel.cs | 12 + src/Core/Services/IMailService.cs | 1 + .../Implementations/HandlebarsMailService.cs | 21 ++ .../NoopImplementations/NoopMailService.cs | 5 + src/Identity/IdentityServer/ApiResources.cs | 2 +- .../ISendAuthenticationMethodValidator.cs | 15 + .../ISendPasswordRequestValidator.cs | 16 - .../SendAccess/SendAccessConstants.cs | 37 ++- .../SendAccess/SendAccessGrantValidator.cs | 38 +-- .../SendEmailOtpRequestValidator.cs | 134 ++++++++ .../SendPasswordRequestValidator.cs | 16 +- .../Utilities/ServiceCollectionExtensions.cs | 4 +- ...endAccessClaimsPrincipalExtensionsTests.cs | 8 +- .../Services/HandlebarsMailServiceTests.cs | 15 +- ...endAccessGrantValidatorIntegrationTests.cs | 4 +- ...EmailOtpReqestValidatorIntegrationTests.cs | 256 +++++++++++++++ .../SendAccessGrantValidatorTests.cs | 30 +- .../SendEmailOtpRequestValidatorTests.cs | 310 ++++++++++++++++++ .../SendPasswordRequestValidatorTests.cs | 297 +++++++++++++++++ .../SendPasswordRequestValidatorTests.cs | 32 +- 24 files changed, 1213 insertions(+), 90 deletions(-) create mode 100644 src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs create mode 100644 src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs create mode 100644 src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs create mode 100644 src/Identity/IdentityServer/RequestValidators/SendAccess/ISendAuthenticationMethodValidator.cs delete mode 100644 src/Identity/IdentityServer/RequestValidators/SendAccess/ISendPasswordRequestValidator.cs create mode 100644 src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs create mode 100644 test/Identity.IntegrationTest/RequestValidation/SendEmailOtpReqestValidatorIntegrationTests.cs rename test/Identity.Test/IdentityServer/{ => SendAccess}/SendAccessGrantValidatorTests.cs (90%) create mode 100644 test/Identity.Test/IdentityServer/SendAccess/SendEmailOtpRequestValidatorTests.cs create mode 100644 test/Identity.Test/IdentityServer/SendAccess/SendPasswordRequestValidatorTests.cs diff --git a/src/Core/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensions.cs b/src/Core/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensions.cs index 1feadaf081..7ae7355ba4 100644 --- a/src/Core/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensions.cs +++ b/src/Core/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensions.cs @@ -9,12 +9,12 @@ public static class SendAccessClaimsPrincipalExtensions { ArgumentNullException.ThrowIfNull(user); - var sendIdClaim = user.FindFirst(Claims.SendId) - ?? throw new InvalidOperationException("Send ID claim not found."); + var sendIdClaim = user.FindFirst(Claims.SendAccessClaims.SendId) + ?? throw new InvalidOperationException("send_id claim not found."); if (!Guid.TryParse(sendIdClaim.Value, out var sendGuid)) { - throw new InvalidOperationException("Invalid Send ID claim value."); + throw new InvalidOperationException("Invalid send_id claim value."); } return sendGuid; diff --git a/src/Core/Identity/Claims.cs b/src/Core/Identity/Claims.cs index ef3d5e450c..39a036f3f9 100644 --- a/src/Core/Identity/Claims.cs +++ b/src/Core/Identity/Claims.cs @@ -39,6 +39,9 @@ public static class Claims public const string ManageResetPassword = "manageresetpassword"; public const string ManageScim = "managescim"; } - - public const string SendId = "send_id"; + public static class SendAccessClaims + { + public const string SendId = "send_id"; + public const string Email = "send_email"; + } } diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs new file mode 100644 index 0000000000..5bf1f24218 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs @@ -0,0 +1,28 @@ +{{#>FullHtmlLayout}} + + + + + + + + + + + + + +
+ Verify your email to access this Bitwarden Send. +
+
+ Your verification code is: {{Token}} +
+
+ This code can only be used once and expires in 5 minutes. After that you'll need to verify your email again. +
+
+
+ {{TheDate}} at {{TheTime}} {{TimeZone}} +
+{{/FullHtmlLayout}} \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs new file mode 100644 index 0000000000..f83008c30b --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs @@ -0,0 +1,9 @@ +{{#>BasicTextLayout}} +Verify your email to access this Bitwarden Send. + +Your verification code is: {{Token}} + +This code can only be used once and expires in 5 minutes. After that you'll need to verify your email again. + +Date : {{TheDate}} at {{TheTime}} {{TimeZone}} +{{/BasicTextLayout}} \ No newline at end of file diff --git a/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs b/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs new file mode 100644 index 0000000000..5faf550e60 --- /dev/null +++ b/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs @@ -0,0 +1,12 @@ +namespace Bit.Core.Models.Mail.Auth; + +/// +/// Send email OTP view model +/// +public class DefaultEmailOtpViewModel : BaseMailModel +{ + public string? Token { get; set; } + public string? TheDate { get; set; } + public string? TheTime { get; set; } + public string? TimeZone { get; set; } +} diff --git a/src/Core/Services/IMailService.cs b/src/Core/Services/IMailService.cs index 32aaac84b7..a38328dc9d 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Services/IMailService.cs @@ -30,6 +30,7 @@ public interface IMailService Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail); Task SendChangeEmailEmailAsync(string newEmailAddress, string token); Task SendTwoFactorEmailAsync(string email, string accountEmail, string token, string deviceIp, string deviceType, TwoFactorEmailPurpose purpose); + Task SendSendEmailOtpEmailAsync(string email, string token, string subject); Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType type, DateTime utcNow, string ip); Task SendNoMasterPasswordHintEmailAsync(string email); Task SendMasterPasswordHintEmailAsync(string email, string hint); diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index f06a37fa3b..394b5c5125 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -15,6 +15,7 @@ using Bit.Core.Billing.Models.Mail; using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations; using Bit.Core.Models.Mail; +using Bit.Core.Models.Mail.Auth; using Bit.Core.Models.Mail.Billing; using Bit.Core.Models.Mail.FamiliesForEnterprise; using Bit.Core.Models.Mail.Provider; @@ -199,6 +200,26 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + public async Task SendSendEmailOtpEmailAsync(string email, string token, string subject) + { + var message = CreateDefaultMessage(subject, email); + var requestDateTime = DateTime.UtcNow; + var model = new DefaultEmailOtpViewModel + { + Token = token, + TheDate = requestDateTime.ToLongDateString(), + TheTime = requestDateTime.ToShortTimeString(), + TimeZone = _utcTimeZoneDisplay, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + }; + await AddMessageContentAsync(message, "Auth.SendAccessEmailOtpEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + // TODO - PM-25380 change to string constant + message.Category = "SendEmailOtp"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType failedType, DateTime utcNow, string ip) { // Check if we've sent this email within the last hour diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Services/NoopImplementations/NoopMailService.cs index 5847aaf929..bc73fb5398 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailService.cs @@ -93,6 +93,11 @@ public class NoopMailService : IMailService return Task.FromResult(0); } + public Task SendSendEmailOtpEmailAsync(string email, string token, string subject) + { + return Task.FromResult(0); + } + public Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType failedType, DateTime utcNow, string ip) { return Task.FromResult(0); diff --git a/src/Identity/IdentityServer/ApiResources.cs b/src/Identity/IdentityServer/ApiResources.cs index eea53734cb..61f3dd10ba 100644 --- a/src/Identity/IdentityServer/ApiResources.cs +++ b/src/Identity/IdentityServer/ApiResources.cs @@ -29,7 +29,7 @@ public class ApiResources }), new(ApiScopes.ApiSendAccess, [ JwtClaimTypes.Subject, - Claims.SendId + Claims.SendAccessClaims.SendId ]), new(ApiScopes.Internal, new[] { JwtClaimTypes.Subject }), new(ApiScopes.ApiPush, new[] { JwtClaimTypes.Subject }), diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/ISendAuthenticationMethodValidator.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/ISendAuthenticationMethodValidator.cs new file mode 100644 index 0000000000..1ffb68ceca --- /dev/null +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/ISendAuthenticationMethodValidator.cs @@ -0,0 +1,15 @@ +using Bit.Core.Tools.Models.Data; +using Duende.IdentityServer.Validation; + +namespace Bit.Identity.IdentityServer.RequestValidators.SendAccess; + +public interface ISendAuthenticationMethodValidator where T : SendAuthenticationMethod +{ + /// + /// + /// request context + /// SendAuthenticationRecord that contains the information to be compared against the context + /// the sendId being accessed + /// returns the result of the validation; A failed result will be an error a successful will contain the claims and a success + Task ValidateRequestAsync(ExtensionGrantValidationContext context, T authMethod, Guid sendId); +} diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/ISendPasswordRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/ISendPasswordRequestValidator.cs deleted file mode 100644 index a6f33175bd..0000000000 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/ISendPasswordRequestValidator.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Bit.Core.Tools.Models.Data; -using Duende.IdentityServer.Validation; - -namespace Bit.Identity.IdentityServer.RequestValidators.SendAccess; - -public interface ISendPasswordRequestValidator -{ - /// - /// Validates the send password hash against the client hashed password. - /// If this method fails then it will automatically set the context.Result to an invalid grant result. - /// - /// request context - /// resource password authentication method containing the hash of the Send being retrieved - /// returns the result of the validation; A failed result will be an error a successful will contain the claims and a success - GrantValidationResult ValidateSendPassword(ExtensionGrantValidationContext context, ResourcePassword resourcePassword, Guid sendId); -} diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs index 952f4146ed..fae7ba4215 100644 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs @@ -1,4 +1,5 @@ -using Duende.IdentityServer.Validation; +using Bit.Core.Auth.Identity.TokenProviders; +using Duende.IdentityServer.Validation; namespace Bit.Identity.IdentityServer.RequestValidators.SendAccess; @@ -34,7 +35,7 @@ public static class SendAccessConstants public static class GrantValidatorResults { /// - /// The sendId is valid and the request is well formed. + /// The sendId is valid and the request is well formed. Not returned in any response. /// public const string ValidSendGuid = "valid_send_guid"; /// @@ -66,8 +67,40 @@ public static class SendAccessConstants /// public const string EmailRequired = "email_required"; /// + /// Represents the error code indicating that an email address is invalid. + /// + public const string EmailInvalid = "email_invalid"; + /// /// Represents the status indicating that both email and OTP are required, and the OTP has been sent. /// public const string EmailOtpSent = "email_and_otp_required_otp_sent"; + /// + /// Represents the status indicating that both email and OTP are required, and the OTP is invalid. + /// + public const string EmailOtpInvalid = "otp_invalid"; + /// + /// For what ever reason the OTP was not able to be generated + /// + public const string OtpGenerationFailed = "otp_generation_failed"; + } + + /// + /// These are the constants for the OTP token that is generated during the email otp authentication process. + /// These items are required by to aid in the creation of a unique lookup key. + /// Look up key format is: {TokenProviderName}_{Purpose}_{TokenUniqueIdentifier} + /// + public static class OtpToken + { + public const string TokenProviderName = "send_access"; + public const string Purpose = "email_otp"; + /// + /// This will be send_id {0} and email {1} + /// + public const string TokenUniqueIdentifier = "{0}_{1}"; + } + + public static class OtpEmail + { + public const string Subject = "Your Bitwarden Send verification code is {0}"; } } diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessGrantValidator.cs index 7cfa2acd2a..5fe0b7b724 100644 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessGrantValidator.cs @@ -13,7 +13,8 @@ namespace Bit.Identity.IdentityServer.RequestValidators.SendAccess; public class SendAccessGrantValidator( ISendAuthenticationQuery _sendAuthenticationQuery, - ISendPasswordRequestValidator _sendPasswordRequestValidator, + ISendAuthenticationMethodValidator _sendPasswordRequestValidator, + ISendAuthenticationMethodValidator _sendEmailOtpRequestValidator, IFeatureService _featureService) : IExtensionGrantValidator { @@ -61,16 +62,14 @@ public class SendAccessGrantValidator( // automatically issue access token context.Result = BuildBaseSuccessResult(sendIdGuid); return; - case ResourcePassword rp: - // Validate if the password is correct, or if we need to respond with a 400 stating a password has is required - context.Result = _sendPasswordRequestValidator.ValidateSendPassword(context, rp, sendIdGuid); + // Validate if the password is correct, or if we need to respond with a 400 stating a password is invalid or required. + context.Result = await _sendPasswordRequestValidator.ValidateRequestAsync(context, rp, sendIdGuid); return; case EmailOtp eo: - // TODO PM-22678: We will either send the OTP here or validate it based on if otp exists in the request. - // SendOtpToEmail(eo.Emails) or ValidateOtp(eo.Emails); - // break; - + // Validate if the request has the correct email and OTP. If not, respond with a 400 and information about the failure. + context.Result = await _sendEmailOtpRequestValidator.ValidateRequestAsync(context, eo, sendIdGuid); + return; default: // shouldn’t ever hit this throw new InvalidOperationException($"Unknown auth method: {method.GetType()}"); @@ -114,28 +113,27 @@ public class SendAccessGrantValidator( /// /// Builds an error result for the specified error type. /// - /// The error type. + /// This error is a constant string from /// The error result. private static GrantValidationResult BuildErrorResult(string error) { + var customResponse = new Dictionary + { + { SendAccessConstants.SendAccessError, error } + }; + return error switch { // Request is the wrong shape SendAccessConstants.GrantValidatorResults.MissingSendId => new GrantValidationResult( TokenRequestErrors.InvalidRequest, - errorDescription: _sendGrantValidatorErrorDescriptions[SendAccessConstants.GrantValidatorResults.MissingSendId], - new Dictionary - { - { SendAccessConstants.SendAccessError, SendAccessConstants.GrantValidatorResults.MissingSendId} - }), + errorDescription: _sendGrantValidatorErrorDescriptions[error], + customResponse), // Request is correct shape but data is bad SendAccessConstants.GrantValidatorResults.InvalidSendId => new GrantValidationResult( TokenRequestErrors.InvalidGrant, - errorDescription: _sendGrantValidatorErrorDescriptions[SendAccessConstants.GrantValidatorResults.InvalidSendId], - new Dictionary - { - { SendAccessConstants.SendAccessError, SendAccessConstants.GrantValidatorResults.InvalidSendId } - }), + errorDescription: _sendGrantValidatorErrorDescriptions[error], + customResponse), // should never get here _ => new GrantValidationResult(TokenRequestErrors.InvalidRequest) }; @@ -145,7 +143,7 @@ public class SendAccessGrantValidator( { var claims = new List { - new(Claims.SendId, sendId.ToString()), + new(Claims.SendAccessClaims.SendId, sendId.ToString()), new(Claims.Type, IdentityClientType.Send.ToString()) }; diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs new file mode 100644 index 0000000000..e26556eb80 --- /dev/null +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs @@ -0,0 +1,134 @@ +using System.Security.Claims; +using Bit.Core.Auth.Identity.TokenProviders; +using Bit.Core.Identity; +using Bit.Core.Services; +using Bit.Core.Tools.Models.Data; +using Bit.Identity.IdentityServer.Enums; +using Duende.IdentityServer.Models; +using Duende.IdentityServer.Validation; + +namespace Bit.Identity.IdentityServer.RequestValidators.SendAccess; + +public class SendEmailOtpRequestValidator( + IOtpTokenProvider otpTokenProvider, + IMailService mailService) : ISendAuthenticationMethodValidator +{ + + /// + /// static object that contains the error messages for the SendEmailOtpRequestValidator. + /// + private static readonly Dictionary _sendEmailOtpValidatorErrorDescriptions = new() + { + { SendAccessConstants.EmailOtpValidatorResults.EmailRequired, $"{SendAccessConstants.TokenRequest.Email} is required." }, + { SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent, "email otp sent." }, + { SendAccessConstants.EmailOtpValidatorResults.EmailInvalid, $"{SendAccessConstants.TokenRequest.Email} is invalid." }, + { SendAccessConstants.EmailOtpValidatorResults.EmailOtpInvalid, $"{SendAccessConstants.TokenRequest.Email} otp is invalid." }, + }; + + public async Task ValidateRequestAsync(ExtensionGrantValidationContext context, EmailOtp authMethod, Guid sendId) + { + var request = context.Request.Raw; + // get email + var email = request.Get(SendAccessConstants.TokenRequest.Email); + + // It is an invalid request if the email is missing which indicated bad shape. + if (string.IsNullOrEmpty(email)) + { + // Request is the wrong shape and doesn't contain an email field. + return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailRequired); + } + + // email must be in the list of emails in the EmailOtp array + if (!authMethod.Emails.Contains(email)) + { + return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailInvalid); + } + + // get otp from request + var requestOtp = request.Get(SendAccessConstants.TokenRequest.Otp); + var uniqueIdentifierForTokenCache = string.Format(SendAccessConstants.OtpToken.TokenUniqueIdentifier, sendId, email); + if (string.IsNullOrEmpty(requestOtp)) + { + // Since the request doesn't have an OTP, generate one + var token = await otpTokenProvider.GenerateTokenAsync( + SendAccessConstants.OtpToken.TokenProviderName, + SendAccessConstants.OtpToken.Purpose, + uniqueIdentifierForTokenCache); + + // Verify that the OTP is generated + if (string.IsNullOrEmpty(token)) + { + return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.OtpGenerationFailed); + } + + await mailService.SendSendEmailOtpEmailAsync( + email, + token, + string.Format(SendAccessConstants.OtpEmail.Subject, token)); + return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent); + } + + // validate request otp + var otpResult = await otpTokenProvider.ValidateTokenAsync( + requestOtp, + SendAccessConstants.OtpToken.TokenProviderName, + SendAccessConstants.OtpToken.Purpose, + uniqueIdentifierForTokenCache); + + // If OTP is invalid return error result + if (!otpResult) + { + return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailOtpInvalid); + } + + return BuildSuccessResult(sendId, email!); + } + + private static GrantValidationResult BuildErrorResult(string error) + { + switch (error) + { + case SendAccessConstants.EmailOtpValidatorResults.EmailRequired: + case SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent: + return new GrantValidationResult(TokenRequestErrors.InvalidRequest, + errorDescription: _sendEmailOtpValidatorErrorDescriptions[error], + new Dictionary + { + { SendAccessConstants.SendAccessError, error } + }); + case SendAccessConstants.EmailOtpValidatorResults.EmailOtpInvalid: + case SendAccessConstants.EmailOtpValidatorResults.EmailInvalid: + return new GrantValidationResult( + TokenRequestErrors.InvalidGrant, + errorDescription: _sendEmailOtpValidatorErrorDescriptions[error], + new Dictionary + { + { SendAccessConstants.SendAccessError, error } + }); + default: + return new GrantValidationResult( + TokenRequestErrors.InvalidRequest, + errorDescription: error); + } + } + + /// + /// Builds a successful validation result for the Send password send_access grant. + /// + /// Guid of the send being accessed. + /// successful grant validation result + private static GrantValidationResult BuildSuccessResult(Guid sendId, string email) + { + var claims = new List + { + new(Claims.SendAccessClaims.SendId, sendId.ToString()), + new(Claims.SendAccessClaims.Email, email), + new(Claims.Type, IdentityClientType.Send.ToString()) + }; + + return new GrantValidationResult( + subject: sendId.ToString(), + authenticationMethod: CustomGrantTypes.SendAccess, + claims: claims); + } +} diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendPasswordRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendPasswordRequestValidator.cs index 3449b4cb56..4eade01a49 100644 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendPasswordRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendPasswordRequestValidator.cs @@ -8,7 +8,7 @@ using Duende.IdentityServer.Validation; namespace Bit.Identity.IdentityServer.RequestValidators.SendAccess; -public class SendPasswordRequestValidator(ISendPasswordHasher sendPasswordHasher) : ISendPasswordRequestValidator +public class SendPasswordRequestValidator(ISendPasswordHasher sendPasswordHasher) : ISendAuthenticationMethodValidator { private readonly ISendPasswordHasher _sendPasswordHasher = sendPasswordHasher; @@ -21,7 +21,7 @@ public class SendPasswordRequestValidator(ISendPasswordHasher sendPasswordHasher { SendAccessConstants.PasswordValidatorResults.RequestPasswordIsRequired, $"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is required." } }; - public GrantValidationResult ValidateSendPassword(ExtensionGrantValidationContext context, ResourcePassword resourcePassword, Guid sendId) + public Task ValidateRequestAsync(ExtensionGrantValidationContext context, ResourcePassword resourcePassword, Guid sendId) { var request = context.Request.Raw; var clientHashedPassword = request.Get(SendAccessConstants.TokenRequest.ClientB64HashedPassword); @@ -30,13 +30,13 @@ public class SendPasswordRequestValidator(ISendPasswordHasher sendPasswordHasher if (clientHashedPassword == null) { // Request is the wrong shape and doesn't contain a passwordHashB64 field. - return new GrantValidationResult( + return Task.FromResult(new GrantValidationResult( TokenRequestErrors.InvalidRequest, errorDescription: _sendPasswordValidatorErrorDescriptions[SendAccessConstants.PasswordValidatorResults.RequestPasswordIsRequired], new Dictionary { { SendAccessConstants.SendAccessError, SendAccessConstants.PasswordValidatorResults.RequestPasswordIsRequired } - }); + })); } // _sendPasswordHasher.PasswordHashMatches checks for an empty string so no need to do it before we make the call. @@ -46,16 +46,16 @@ public class SendPasswordRequestValidator(ISendPasswordHasher sendPasswordHasher if (!hashMatches) { // Request is the correct shape but the passwordHashB64 doesn't match, hash could be empty. - return new GrantValidationResult( + return Task.FromResult(new GrantValidationResult( TokenRequestErrors.InvalidGrant, errorDescription: _sendPasswordValidatorErrorDescriptions[SendAccessConstants.PasswordValidatorResults.RequestPasswordDoesNotMatch], new Dictionary { { SendAccessConstants.SendAccessError, SendAccessConstants.PasswordValidatorResults.RequestPasswordDoesNotMatch } - }); + })); } - return BuildSendPasswordSuccessResult(sendId); + return Task.FromResult(BuildSendPasswordSuccessResult(sendId)); } /// @@ -67,7 +67,7 @@ public class SendPasswordRequestValidator(ISendPasswordHasher sendPasswordHasher { var claims = new List { - new(Claims.SendId, sendId.ToString()), + new(Claims.SendAccessClaims.SendId, sendId.ToString()), new(Claims.Type, IdentityClientType.Send.ToString()) }; diff --git a/src/Identity/Utilities/ServiceCollectionExtensions.cs b/src/Identity/Utilities/ServiceCollectionExtensions.cs index d4f2ad8045..95c067d884 100644 --- a/src/Identity/Utilities/ServiceCollectionExtensions.cs +++ b/src/Identity/Utilities/ServiceCollectionExtensions.cs @@ -1,6 +1,7 @@ using Bit.Core.Auth.Repositories; using Bit.Core.IdentityServer; using Bit.Core.Settings; +using Bit.Core.Tools.Models.Data; using Bit.Core.Utilities; using Bit.Identity.IdentityServer; using Bit.Identity.IdentityServer.ClientProviders; @@ -26,7 +27,8 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); - services.AddTransient(); + services.AddTransient, SendPasswordRequestValidator>(); + services.AddTransient, SendEmailOtpRequestValidator>(); var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity); var identityServerBuilder = services diff --git a/test/Core.Test/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensionsTests.cs b/test/Core.Test/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensionsTests.cs index 27a0bc1bbc..bf5322d916 100644 --- a/test/Core.Test/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensionsTests.cs +++ b/test/Core.Test/Auth/UserFeatures/SendAccess/SendAccessClaimsPrincipalExtensionsTests.cs @@ -12,7 +12,7 @@ public class SendAccessClaimsPrincipalExtensionsTests { // Arrange var guid = Guid.NewGuid(); - var claims = new[] { new Claim(Claims.SendId, guid.ToString()) }; + var claims = new[] { new Claim(Claims.SendAccessClaims.SendId, guid.ToString()) }; var principal = new ClaimsPrincipal(new ClaimsIdentity(claims)); // Act @@ -30,19 +30,19 @@ public class SendAccessClaimsPrincipalExtensionsTests // Act & Assert var ex = Assert.Throws(() => principal.GetSendId()); - Assert.Equal("Send ID claim not found.", ex.Message); + Assert.Equal("send_id claim not found.", ex.Message); } [Fact] public void GetSendId_ThrowsInvalidOperationException_WhenClaimValueIsInvalid() { // Arrange - var claims = new[] { new Claim(Claims.SendId, "not-a-guid") }; + var claims = new[] { new Claim(Claims.SendAccessClaims.SendId, "not-a-guid") }; var principal = new ClaimsPrincipal(new ClaimsIdentity(claims)); // Act & Assert var ex = Assert.Throws(() => principal.GetSendId()); - Assert.Equal("Invalid Send ID claim value.", ex.Message); + Assert.Equal("Invalid send_id claim value.", ex.Message); } [Fact] diff --git a/test/Core.Test/Services/HandlebarsMailServiceTests.cs b/test/Core.Test/Services/HandlebarsMailServiceTests.cs index 849a5130a3..242bcc60f3 100644 --- a/test/Core.Test/Services/HandlebarsMailServiceTests.cs +++ b/test/Core.Test/Services/HandlebarsMailServiceTests.cs @@ -247,11 +247,18 @@ public class HandlebarsMailServiceTests } } - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. [Fact] - public void ServiceExists() + public async Task SendSendEmailOtpEmailAsync_SendsEmail() { - Assert.NotNull(_sut); + // Arrange + var email = "test@example.com"; + var token = "aToken"; + var subject = string.Format("Your Bitwarden Send verification code is {0}", token); + + // Act + await _sut.SendSendEmailOtpEmailAsync(email, token, subject); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Any()); } } diff --git a/test/Identity.IntegrationTest/RequestValidation/SendAccessGrantValidatorIntegrationTests.cs b/test/Identity.IntegrationTest/RequestValidation/SendAccessGrantValidatorIntegrationTests.cs index 4b8c267861..3b0cf2c282 100644 --- a/test/Identity.IntegrationTest/RequestValidation/SendAccessGrantValidatorIntegrationTests.cs +++ b/test/Identity.IntegrationTest/RequestValidation/SendAccessGrantValidatorIntegrationTests.cs @@ -213,8 +213,8 @@ public class SendAccessGrantValidatorIntegrationTests(IdentityApplicationFactory services.AddSingleton(sendAuthQuery); // Mock password validator to return success - var passwordValidator = Substitute.For(); - passwordValidator.ValidateSendPassword( + var passwordValidator = Substitute.For>(); + passwordValidator.ValidateRequestAsync( Arg.Any(), Arg.Any(), Arg.Any()) diff --git a/test/Identity.IntegrationTest/RequestValidation/SendEmailOtpReqestValidatorIntegrationTests.cs b/test/Identity.IntegrationTest/RequestValidation/SendEmailOtpReqestValidatorIntegrationTests.cs new file mode 100644 index 0000000000..9d9bc03ef5 --- /dev/null +++ b/test/Identity.IntegrationTest/RequestValidation/SendEmailOtpReqestValidatorIntegrationTests.cs @@ -0,0 +1,256 @@ +using Bit.Core.Auth.Identity.TokenProviders; +using Bit.Core.Enums; +using Bit.Core.IdentityServer; +using Bit.Core.Services; +using Bit.Core.Tools.Models.Data; +using Bit.Core.Tools.SendFeatures.Queries.Interfaces; +using Bit.Core.Utilities; +using Bit.Identity.IdentityServer.Enums; +using Bit.Identity.IdentityServer.RequestValidators.SendAccess; +using Bit.IntegrationTestCommon.Factories; +using Duende.IdentityModel; +using NSubstitute; +using Xunit; + +namespace Bit.Identity.IntegrationTest.RequestValidation; + +public class SendEmailOtpRequestValidatorIntegrationTests : IClassFixture +{ + private readonly IdentityApplicationFactory _factory; + + public SendEmailOtpRequestValidatorIntegrationTests(IdentityApplicationFactory factory) + { + _factory = factory; + } + + [Fact] + public async Task SendAccess_EmailOtpProtectedSend_MissingEmail_ReturnsInvalidRequest() + { + // Arrange + var sendId = Guid.NewGuid(); + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + var featureService = Substitute.For(); + featureService.IsEnabled(Arg.Any()).Returns(true); + services.AddSingleton(featureService); + + var sendAuthQuery = Substitute.For(); + sendAuthQuery.GetAuthenticationMethod(sendId) + .Returns(new EmailOtp(["test@example.com"])); + services.AddSingleton(sendAuthQuery); + }); + }).CreateClient(); + + var requestBody = CreateTokenRequestBody(sendId); // No email + + // Act + var response = await client.PostAsync("/connect/token", requestBody); + + // Assert + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains(OidcConstants.TokenErrors.InvalidRequest, content); + Assert.Contains("email is required", content); + } + + [Fact] + public async Task SendAccess_EmailOtpProtectedSend_EmailWithoutOtp_SendsOtpEmail() + { + // Arrange + var sendId = Guid.NewGuid(); + var email = "test@example.com"; + var generatedToken = "123456"; + + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + var featureService = Substitute.For(); + featureService.IsEnabled(Arg.Any()).Returns(true); + services.AddSingleton(featureService); + + var sendAuthQuery = Substitute.For(); + sendAuthQuery.GetAuthenticationMethod(sendId) + .Returns(new EmailOtp([email])); + services.AddSingleton(sendAuthQuery); + + // Mock OTP token provider + var otpProvider = Substitute.For>(); + otpProvider.GenerateTokenAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(generatedToken); + services.AddSingleton(otpProvider); + + // Mock mail service + var mailService = Substitute.For(); + services.AddSingleton(mailService); + }); + }).CreateClient(); + + var requestBody = CreateTokenRequestBody(sendId, sendEmail: email); // Email but no OTP + + // Act + var response = await client.PostAsync("/connect/token", requestBody); + + // Assert + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains(OidcConstants.TokenErrors.InvalidRequest, content); + Assert.Contains("email otp sent", content); + } + + [Fact] + public async Task SendAccess_EmailOtpProtectedSend_ValidOtp_ReturnsAccessToken() + { + // Arrange + var sendId = Guid.NewGuid(); + var email = "test@example.com"; + var otp = "123456"; + + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + var featureService = Substitute.For(); + featureService.IsEnabled(Arg.Any()).Returns(true); + services.AddSingleton(featureService); + + var sendAuthQuery = Substitute.For(); + sendAuthQuery.GetAuthenticationMethod(sendId) + .Returns(new EmailOtp(new[] { email })); + services.AddSingleton(sendAuthQuery); + + // Mock OTP token provider to validate successfully + var otpProvider = Substitute.For>(); + otpProvider.ValidateTokenAsync(otp, Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(true); + services.AddSingleton(otpProvider); + + var mailService = Substitute.For(); + services.AddSingleton(mailService); + }); + }).CreateClient(); + + var requestBody = CreateTokenRequestBody(sendId, sendEmail: email, emailOtp: otp); + + // Act + var response = await client.PostAsync("/connect/token", requestBody); + + // Assert + Assert.True(response.IsSuccessStatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains(OidcConstants.TokenResponse.AccessToken, content); + Assert.Contains(OidcConstants.TokenResponse.BearerTokenType, content); + } + + [Fact] + public async Task SendAccess_EmailOtpProtectedSend_InvalidOtp_ReturnsInvalidGrant() + { + // Arrange + var sendId = Guid.NewGuid(); + var email = "test@example.com"; + var invalidOtp = "wrong123"; + + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + var featureService = Substitute.For(); + featureService.IsEnabled(Arg.Any()).Returns(true); + services.AddSingleton(featureService); + + var sendAuthQuery = Substitute.For(); + sendAuthQuery.GetAuthenticationMethod(sendId) + .Returns(new EmailOtp(new[] { email })); + services.AddSingleton(sendAuthQuery); + + // Mock OTP token provider to validate as false + var otpProvider = Substitute.For>(); + otpProvider.ValidateTokenAsync(invalidOtp, Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(false); + services.AddSingleton(otpProvider); + + var mailService = Substitute.For(); + services.AddSingleton(mailService); + }); + }).CreateClient(); + + var requestBody = CreateTokenRequestBody(sendId, sendEmail: email, emailOtp: invalidOtp); + + // Act + var response = await client.PostAsync("/connect/token", requestBody); + + // Assert + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains(OidcConstants.TokenErrors.InvalidGrant, content); + Assert.Contains("email otp is invalid", content); + } + + [Fact] + public async Task SendAccess_EmailOtpProtectedSend_OtpGenerationFails_ReturnsInvalidRequest() + { + // Arrange + var sendId = Guid.NewGuid(); + var email = "test@example.com"; + + var client = _factory.WithWebHostBuilder(builder => + { + builder.ConfigureServices(services => + { + var featureService = Substitute.For(); + featureService.IsEnabled(Arg.Any()).Returns(true); + services.AddSingleton(featureService); + + var sendAuthQuery = Substitute.For(); + sendAuthQuery.GetAuthenticationMethod(sendId) + .Returns(new EmailOtp(new[] { email })); + services.AddSingleton(sendAuthQuery); + + // Mock OTP token provider to fail generation + var otpProvider = Substitute.For>(); + otpProvider.GenerateTokenAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns((string)null); + services.AddSingleton(otpProvider); + + var mailService = Substitute.For(); + services.AddSingleton(mailService); + }); + }).CreateClient(); + + var requestBody = CreateTokenRequestBody(sendId, sendEmail: email); // Email but no OTP + + // Act + var response = await client.PostAsync("/connect/token", requestBody); + + // Assert + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains(OidcConstants.TokenErrors.InvalidRequest, content); + } + + private static FormUrlEncodedContent CreateTokenRequestBody(Guid sendId, + string sendEmail = null, string emailOtp = null) + { + var sendIdBase64 = CoreHelpers.Base64UrlEncode(sendId.ToByteArray()); + var parameters = new List> + { + new(OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess), + new(OidcConstants.TokenRequest.ClientId, BitwardenClient.Send ), + new(OidcConstants.TokenRequest.Scope, ApiScopes.ApiSendAccess), + new("deviceType", ((int)DeviceType.FirefoxBrowser).ToString()), + new(SendAccessConstants.TokenRequest.SendId, sendIdBase64) + }; + + if (!string.IsNullOrEmpty(sendEmail)) + { + parameters.Add(new KeyValuePair( + SendAccessConstants.TokenRequest.Email, sendEmail)); + } + + if (!string.IsNullOrEmpty(emailOtp)) + { + parameters.Add(new KeyValuePair( + SendAccessConstants.TokenRequest.Otp, emailOtp)); + } + + return new FormUrlEncodedContent(parameters); + } +} diff --git a/test/Identity.Test/IdentityServer/SendAccessGrantValidatorTests.cs b/test/Identity.Test/IdentityServer/SendAccess/SendAccessGrantValidatorTests.cs similarity index 90% rename from test/Identity.Test/IdentityServer/SendAccessGrantValidatorTests.cs rename to test/Identity.Test/IdentityServer/SendAccess/SendAccessGrantValidatorTests.cs index c3d422c51a..e651709c47 100644 --- a/test/Identity.Test/IdentityServer/SendAccessGrantValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/SendAccess/SendAccessGrantValidatorTests.cs @@ -17,7 +17,7 @@ using Duende.IdentityServer.Validation; using NSubstitute; using Xunit; -namespace Bit.Identity.Test.IdentityServer; +namespace Bit.Identity.Test.IdentityServer.SendAccess; [SutProviderCustomize] public class SendAccessGrantValidatorTests @@ -167,7 +167,7 @@ public class SendAccessGrantValidatorTests // get the claims from the subject var claims = subject.Claims.ToList(); Assert.NotEmpty(claims); - Assert.Contains(claims, c => c.Type == Claims.SendId && c.Value == sendId.ToString()); + Assert.Contains(claims, c => c.Type == Claims.SendAccessClaims.SendId && c.Value == sendId.ToString()); Assert.Contains(claims, c => c.Type == Claims.Type && c.Value == IdentityClientType.Send.ToString()); } @@ -189,8 +189,8 @@ public class SendAccessGrantValidatorTests .GetAuthenticationMethod(sendId) .Returns(resourcePassword); - sutProvider.GetDependency() - .ValidateSendPassword(context, resourcePassword, sendId) + sutProvider.GetDependency>() + .ValidateRequestAsync(context, resourcePassword, sendId) .Returns(expectedResult); // Act @@ -198,15 +198,16 @@ public class SendAccessGrantValidatorTests // Assert Assert.Equal(expectedResult, context.Result); - sutProvider.GetDependency() + await sutProvider.GetDependency>() .Received(1) - .ValidateSendPassword(context, resourcePassword, sendId); + .ValidateRequestAsync(context, resourcePassword, sendId); } [Theory, BitAutoData] - public async Task ValidateAsync_EmailOtpMethod_NotImplemented_ThrowsError( + public async Task ValidateAsync_EmailOtpMethod_CallsEmailOtp( [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, SutProvider sutProvider, + GrantValidationResult expectedResult, Guid sendId, EmailOtp emailOtp) { @@ -216,15 +217,22 @@ public class SendAccessGrantValidatorTests sendId, tokenRequest); - sutProvider.GetDependency() .GetAuthenticationMethod(sendId) .Returns(emailOtp); + sutProvider.GetDependency>() + .ValidateRequestAsync(context, emailOtp, sendId) + .Returns(expectedResult); + // Act + await sutProvider.Sut.ValidateAsync(context); + // Assert - // Currently the EmailOtp case doesn't set a result, so it should be null - await Assert.ThrowsAsync(async () => await sutProvider.Sut.ValidateAsync(context)); + Assert.Equal(expectedResult, context.Result); + await sutProvider.GetDependency>() + .Received(1) + .ValidateRequestAsync(context, emailOtp, sendId); } [Theory, BitAutoData] @@ -256,7 +264,7 @@ public class SendAccessGrantValidatorTests public void GrantType_ReturnsCorrectType() { // Arrange & Act - var validator = new SendAccessGrantValidator(null!, null!, null!); + var validator = new SendAccessGrantValidator(null!, null!, null!, null!); // Assert Assert.Equal(CustomGrantTypes.SendAccess, ((IExtensionGrantValidator)validator).GrantType); diff --git a/test/Identity.Test/IdentityServer/SendAccess/SendEmailOtpRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/SendAccess/SendEmailOtpRequestValidatorTests.cs new file mode 100644 index 0000000000..2fd21fd4cf --- /dev/null +++ b/test/Identity.Test/IdentityServer/SendAccess/SendEmailOtpRequestValidatorTests.cs @@ -0,0 +1,310 @@ +using System.Collections.Specialized; +using Bit.Core.Auth.Identity.TokenProviders; +using Bit.Core.Enums; +using Bit.Core.Identity; +using Bit.Core.IdentityServer; +using Bit.Core.Services; +using Bit.Core.Tools.Models.Data; +using Bit.Core.Utilities; +using Bit.Identity.IdentityServer.Enums; +using Bit.Identity.IdentityServer.RequestValidators.SendAccess; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Duende.IdentityModel; +using Duende.IdentityServer.Validation; +using NSubstitute; +using Xunit; + +namespace Bit.Identity.Test.IdentityServer.SendAccess; + +[SutProviderCustomize] +public class SendEmailOtpRequestValidatorTests +{ + [Theory, BitAutoData] + public async Task ValidateRequestAsync_MissingEmail_ReturnsInvalidRequest( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + EmailOtp emailOtp, + Guid sendId) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId); + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, emailOtp, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidRequest, result.Error); + Assert.Equal("email is required.", result.ErrorDescription); + + // Verify no OTP generation or email sending occurred + await sutProvider.GetDependency>() + .DidNotReceive() + .GenerateTokenAsync(Arg.Any(), Arg.Any(), Arg.Any()); + + await sutProvider.GetDependency() + .DidNotReceive() + .SendSendEmailOtpEmailAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateRequestAsync_EmailNotInList_ReturnsInvalidRequest( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + EmailOtp emailOtp, + string email, + Guid sendId) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, email); + var emailOTP = new EmailOtp(["user@test.dev"]); + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, emailOtp, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error); + Assert.Equal("email is invalid.", result.ErrorDescription); + + // Verify no OTP generation or email sending occurred + await sutProvider.GetDependency>() + .DidNotReceive() + .GenerateTokenAsync(Arg.Any(), Arg.Any(), Arg.Any()); + + await sutProvider.GetDependency() + .DidNotReceive() + .SendSendEmailOtpEmailAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateRequestAsync_EmailWithoutOtp_GeneratesAndSendsOtp( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + EmailOtp emailOtp, + Guid sendId, + string email, + string generatedToken) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, email); + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + var expectedUniqueId = string.Format(SendAccessConstants.OtpToken.TokenUniqueIdentifier, sendId, email); + + sutProvider.GetDependency>() + .GenerateTokenAsync( + SendAccessConstants.OtpToken.TokenProviderName, + SendAccessConstants.OtpToken.Purpose, + expectedUniqueId) + .Returns(generatedToken); + + emailOtp = emailOtp with { Emails = [email] }; + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, emailOtp, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidRequest, result.Error); + Assert.Equal("email otp sent.", result.ErrorDescription); + + // Verify OTP generation + await sutProvider.GetDependency>() + .Received(1) + .GenerateTokenAsync( + SendAccessConstants.OtpToken.TokenProviderName, + SendAccessConstants.OtpToken.Purpose, + expectedUniqueId); + + // Verify email sending + await sutProvider.GetDependency() + .Received(1) + .SendSendEmailOtpEmailAsync(email, generatedToken, Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateRequestAsync_OtpGenerationFails_ReturnsGenerationFailedError( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + EmailOtp emailOtp, + Guid sendId, + string email) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, email); + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + emailOtp = emailOtp with { Emails = [email] }; + + sutProvider.GetDependency>() + .GenerateTokenAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns((string)null); // Generation fails + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, emailOtp, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidRequest, result.Error); + + // Verify no email was sent + await sutProvider.GetDependency() + .DidNotReceive() + .SendSendEmailOtpEmailAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateRequestAsync_ValidOtp_ReturnsSuccess( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + EmailOtp emailOtp, + Guid sendId, + string email, + string otp) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, email, otp); + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + emailOtp = emailOtp with { Emails = [email] }; + + var expectedUniqueId = string.Format(SendAccessConstants.OtpToken.TokenUniqueIdentifier, sendId, email); + + sutProvider.GetDependency>() + .ValidateTokenAsync( + otp, + SendAccessConstants.OtpToken.TokenProviderName, + SendAccessConstants.OtpToken.Purpose, + expectedUniqueId) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, emailOtp, sendId); + + // Assert + Assert.False(result.IsError); + var sub = result.Subject; + Assert.Equal(sendId.ToString(), sub.Claims.First(c => c.Type == Claims.SendAccessClaims.SendId).Value); + + // Verify claims + Assert.Contains(sub.Claims, c => c.Type == Claims.SendAccessClaims.SendId && c.Value == sendId.ToString()); + Assert.Contains(sub.Claims, c => c.Type == Claims.SendAccessClaims.Email && c.Value == email); + Assert.Contains(sub.Claims, c => c.Type == Claims.Type && c.Value == IdentityClientType.Send.ToString()); + + // Verify OTP validation was called + await sutProvider.GetDependency>() + .Received(1) + .ValidateTokenAsync(otp, SendAccessConstants.OtpToken.TokenProviderName, SendAccessConstants.OtpToken.Purpose, expectedUniqueId); + + // Verify no email was sent (validation only) + await sutProvider.GetDependency() + .DidNotReceive() + .SendSendEmailOtpEmailAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateRequestAsync_InvalidOtp_ReturnsInvalidGrant( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + EmailOtp emailOtp, + Guid sendId, + string email, + string invalidOtp) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, email, invalidOtp); + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + emailOtp = emailOtp with { Emails = [email] }; + + var expectedUniqueId = string.Format(SendAccessConstants.OtpToken.TokenUniqueIdentifier, sendId, email); + + sutProvider.GetDependency>() + .ValidateTokenAsync(invalidOtp, + SendAccessConstants.OtpToken.TokenProviderName, + SendAccessConstants.OtpToken.Purpose, + expectedUniqueId) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, emailOtp, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error); + Assert.Equal("email otp is invalid.", result.ErrorDescription); + + // Verify OTP validation was attempted + await sutProvider.GetDependency>() + .Received(1) + .ValidateTokenAsync(invalidOtp, + SendAccessConstants.OtpToken.TokenProviderName, + SendAccessConstants.OtpToken.Purpose, + expectedUniqueId); + } + + [Fact] + public void Constructor_WithValidParameters_CreatesInstance() + { + // Arrange + var otpTokenProvider = Substitute.For>(); + var mailService = Substitute.For(); + + // Act + var validator = new SendEmailOtpRequestValidator(otpTokenProvider, mailService); + + // Assert + Assert.NotNull(validator); + } + + private static NameValueCollection CreateValidatedTokenRequest( + Guid sendId, + string sendEmail = null, + string otpCode = null) + { + var sendIdBase64 = CoreHelpers.Base64UrlEncode(sendId.ToByteArray()); + + var rawRequestParameters = new NameValueCollection + { + { OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess }, + { OidcConstants.TokenRequest.ClientId, BitwardenClient.Send }, + { OidcConstants.TokenRequest.Scope, ApiScopes.ApiSendAccess }, + { "device_type", ((int)DeviceType.FirefoxBrowser).ToString() }, + { SendAccessConstants.TokenRequest.SendId, sendIdBase64 } + }; + + if (sendEmail != null) + { + rawRequestParameters.Add(SendAccessConstants.TokenRequest.Email, sendEmail); + } + + if (otpCode != null && sendEmail != null) + { + rawRequestParameters.Add(SendAccessConstants.TokenRequest.Otp, otpCode); + } + + return rawRequestParameters; + } +} diff --git a/test/Identity.Test/IdentityServer/SendAccess/SendPasswordRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/SendAccess/SendPasswordRequestValidatorTests.cs new file mode 100644 index 0000000000..e2b8b49830 --- /dev/null +++ b/test/Identity.Test/IdentityServer/SendAccess/SendPasswordRequestValidatorTests.cs @@ -0,0 +1,297 @@ +using System.Collections.Specialized; +using Bit.Core.Auth.UserFeatures.SendAccess; +using Bit.Core.Enums; +using Bit.Core.Identity; +using Bit.Core.IdentityServer; +using Bit.Core.KeyManagement.Sends; +using Bit.Core.Tools.Models.Data; +using Bit.Core.Utilities; +using Bit.Identity.IdentityServer.Enums; +using Bit.Identity.IdentityServer.RequestValidators.SendAccess; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Duende.IdentityModel; +using Duende.IdentityServer.Validation; +using NSubstitute; +using Xunit; + +namespace Bit.Identity.Test.IdentityServer.SendAccess; + +[SutProviderCustomize] +public class SendPasswordRequestValidatorTests +{ + [Theory, BitAutoData] + public async Task ValidateSendPassword_MissingPasswordHash_ReturnsInvalidRequest( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + ResourcePassword resourcePassword, + Guid sendId) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId); + + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidRequest, result.Error); + Assert.Equal($"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is required.", result.ErrorDescription); + + // Verify password hasher was not called + sutProvider.GetDependency() + .DidNotReceive() + .PasswordHashMatches(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateSendPassword_PasswordHashMismatch_ReturnsInvalidGrant( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + ResourcePassword resourcePassword, + Guid sendId, + string clientPasswordHash) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, clientPasswordHash); + + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + sutProvider.GetDependency() + .PasswordHashMatches(resourcePassword.Hash, clientPasswordHash) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error); + Assert.Equal($"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is invalid.", result.ErrorDescription); + + // Verify password hasher was called with correct parameters + sutProvider.GetDependency() + .Received(1) + .PasswordHashMatches(resourcePassword.Hash, clientPasswordHash); + } + + [Theory, BitAutoData] + public async Task ValidateSendPassword_PasswordHashMatches_ReturnsSuccess( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + ResourcePassword resourcePassword, + Guid sendId, + string clientPasswordHash) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, clientPasswordHash); + + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + sutProvider.GetDependency() + .PasswordHashMatches(resourcePassword.Hash, clientPasswordHash) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); + + // Assert + Assert.False(result.IsError); + + var sub = result.Subject; + Assert.Equal(sendId, sub.GetSendId()); + + // Verify claims + Assert.Contains(sub.Claims, c => c.Type == Claims.SendAccessClaims.SendId && c.Value == sendId.ToString()); + Assert.Contains(sub.Claims, c => c.Type == Claims.Type && c.Value == IdentityClientType.Send.ToString()); + + // Verify password hasher was called + sutProvider.GetDependency() + .Received(1) + .PasswordHashMatches(resourcePassword.Hash, clientPasswordHash); + } + + [Theory, BitAutoData] + public async Task ValidateSendPassword_EmptyPasswordHash_CallsPasswordHasher( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + ResourcePassword resourcePassword, + Guid sendId) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, string.Empty); + + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + sutProvider.GetDependency() + .PasswordHashMatches(resourcePassword.Hash, string.Empty) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error); + + // Verify password hasher was called with empty string + sutProvider.GetDependency() + .Received(1) + .PasswordHashMatches(resourcePassword.Hash, string.Empty); + } + + [Theory, BitAutoData] + public async Task ValidateSendPassword_WhitespacePasswordHash_CallsPasswordHasher( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + ResourcePassword resourcePassword, + Guid sendId) + { + // Arrange + var whitespacePassword = " "; + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, whitespacePassword); + + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + sutProvider.GetDependency() + .PasswordHashMatches(resourcePassword.Hash, whitespacePassword) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); + + // Assert + Assert.True(result.IsError); + + // Verify password hasher was called with whitespace string + sutProvider.GetDependency() + .Received(1) + .PasswordHashMatches(resourcePassword.Hash, whitespacePassword); + } + + [Theory, BitAutoData] + public async Task ValidateSendPassword_MultiplePasswordHashParameters_ReturnsInvalidGrant( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + ResourcePassword resourcePassword, + Guid sendId) + { + // Arrange + var firstPassword = "first-password"; + var secondPassword = "second-password"; + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, firstPassword, secondPassword); + + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + sutProvider.GetDependency() + .PasswordHashMatches(resourcePassword.Hash, firstPassword) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); + + // Assert + Assert.True(result.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error); + + // Verify password hasher was called with first value + sutProvider.GetDependency() + .Received(1) + .PasswordHashMatches(resourcePassword.Hash, $"{firstPassword},{secondPassword}"); + } + + [Theory, BitAutoData] + public async Task ValidateSendPassword_SuccessResult_ContainsCorrectClaims( + SutProvider sutProvider, + [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + ResourcePassword resourcePassword, + Guid sendId, + string clientPasswordHash) + { + // Arrange + tokenRequest.Raw = CreateValidatedTokenRequest(sendId, clientPasswordHash); + + var context = new ExtensionGrantValidationContext + { + Request = tokenRequest + }; + + sutProvider.GetDependency() + .PasswordHashMatches(Arg.Any(), Arg.Any()) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); + + // Assert + Assert.False(result.IsError); + var sub = result.Subject; + + var sendIdClaim = sub.Claims.FirstOrDefault(c => c.Type == Claims.SendAccessClaims.SendId); + Assert.NotNull(sendIdClaim); + Assert.Equal(sendId.ToString(), sendIdClaim.Value); + + var typeClaim = sub.Claims.FirstOrDefault(c => c.Type == Claims.Type); + Assert.NotNull(typeClaim); + Assert.Equal(IdentityClientType.Send.ToString(), typeClaim.Value); + } + + [Fact] + public void Constructor_WithValidParameters_CreatesInstance() + { + // Arrange + var sendPasswordHasher = Substitute.For(); + + // Act + var validator = new SendPasswordRequestValidator(sendPasswordHasher); + + // Assert + Assert.NotNull(validator); + } + + private static NameValueCollection CreateValidatedTokenRequest( + Guid sendId, + params string[] passwordHash) + { + var sendIdBase64 = CoreHelpers.Base64UrlEncode(sendId.ToByteArray()); + + var rawRequestParameters = new NameValueCollection + { + { OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess }, + { OidcConstants.TokenRequest.ClientId, BitwardenClient.Send }, + { OidcConstants.TokenRequest.Scope, ApiScopes.ApiSendAccess }, + { "device_type", ((int)DeviceType.FirefoxBrowser).ToString() }, + { SendAccessConstants.TokenRequest.SendId, sendIdBase64 } + }; + + if (passwordHash != null && passwordHash.Length > 0) + { + foreach (var hash in passwordHash) + { + rawRequestParameters.Add(SendAccessConstants.TokenRequest.ClientB64HashedPassword, hash); + } + } + + return rawRequestParameters; + } +} diff --git a/test/Identity.Test/IdentityServer/SendPasswordRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/SendPasswordRequestValidatorTests.cs index a776a70178..ccee33d8c7 100644 --- a/test/Identity.Test/IdentityServer/SendPasswordRequestValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/SendPasswordRequestValidatorTests.cs @@ -21,7 +21,7 @@ namespace Bit.Identity.Test.IdentityServer; public class SendPasswordRequestValidatorTests { [Theory, BitAutoData] - public void ValidateSendPassword_MissingPasswordHash_ReturnsInvalidRequest( + public async Task ValidateSendPassword_MissingPasswordHash_ReturnsInvalidRequest( SutProvider sutProvider, [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, ResourcePassword resourcePassword, @@ -36,7 +36,7 @@ public class SendPasswordRequestValidatorTests }; // Act - var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId); + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); // Assert Assert.True(result.IsError); @@ -50,7 +50,7 @@ public class SendPasswordRequestValidatorTests } [Theory, BitAutoData] - public void ValidateSendPassword_PasswordHashMismatch_ReturnsInvalidGrant( + public async Task ValidateSendPassword_PasswordHashMismatch_ReturnsInvalidGrant( SutProvider sutProvider, [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, ResourcePassword resourcePassword, @@ -70,7 +70,7 @@ public class SendPasswordRequestValidatorTests .Returns(false); // Act - var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId); + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); // Assert Assert.True(result.IsError); @@ -84,7 +84,7 @@ public class SendPasswordRequestValidatorTests } [Theory, BitAutoData] - public void ValidateSendPassword_PasswordHashMatches_ReturnsSuccess( + public async Task ValidateSendPassword_PasswordHashMatches_ReturnsSuccess( SutProvider sutProvider, [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, ResourcePassword resourcePassword, @@ -104,7 +104,7 @@ public class SendPasswordRequestValidatorTests .Returns(true); // Act - var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId); + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); // Assert Assert.False(result.IsError); @@ -113,7 +113,7 @@ public class SendPasswordRequestValidatorTests Assert.Equal(sendId, sub.GetSendId()); // Verify claims - Assert.Contains(sub.Claims, c => c.Type == Claims.SendId && c.Value == sendId.ToString()); + Assert.Contains(sub.Claims, c => c.Type == Claims.SendAccessClaims.SendId && c.Value == sendId.ToString()); Assert.Contains(sub.Claims, c => c.Type == Claims.Type && c.Value == IdentityClientType.Send.ToString()); // Verify password hasher was called @@ -123,7 +123,7 @@ public class SendPasswordRequestValidatorTests } [Theory, BitAutoData] - public void ValidateSendPassword_EmptyPasswordHash_CallsPasswordHasher( + public async Task ValidateSendPassword_EmptyPasswordHash_CallsPasswordHasher( SutProvider sutProvider, [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, ResourcePassword resourcePassword, @@ -142,7 +142,7 @@ public class SendPasswordRequestValidatorTests .Returns(false); // Act - var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId); + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); // Assert Assert.True(result.IsError); @@ -155,7 +155,7 @@ public class SendPasswordRequestValidatorTests } [Theory, BitAutoData] - public void ValidateSendPassword_WhitespacePasswordHash_CallsPasswordHasher( + public async Task ValidateSendPassword_WhitespacePasswordHash_CallsPasswordHasher( SutProvider sutProvider, [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, ResourcePassword resourcePassword, @@ -175,7 +175,7 @@ public class SendPasswordRequestValidatorTests .Returns(false); // Act - var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId); + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); // Assert Assert.True(result.IsError); @@ -187,7 +187,7 @@ public class SendPasswordRequestValidatorTests } [Theory, BitAutoData] - public void ValidateSendPassword_MultiplePasswordHashParameters_ReturnsInvalidGrant( + public async Task ValidateSendPassword_MultiplePasswordHashParameters_ReturnsInvalidGrant( SutProvider sutProvider, [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, ResourcePassword resourcePassword, @@ -208,7 +208,7 @@ public class SendPasswordRequestValidatorTests .Returns(true); // Act - var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId); + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); // Assert Assert.True(result.IsError); @@ -221,7 +221,7 @@ public class SendPasswordRequestValidatorTests } [Theory, BitAutoData] - public void ValidateSendPassword_SuccessResult_ContainsCorrectClaims( + public async Task ValidateSendPassword_SuccessResult_ContainsCorrectClaims( SutProvider sutProvider, [AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, ResourcePassword resourcePassword, @@ -241,13 +241,13 @@ public class SendPasswordRequestValidatorTests .Returns(true); // Act - var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId); + var result = await sutProvider.Sut.ValidateRequestAsync(context, resourcePassword, sendId); // Assert Assert.False(result.IsError); var sub = result.Subject; - var sendIdClaim = sub.Claims.FirstOrDefault(c => c.Type == Claims.SendId); + var sendIdClaim = sub.Claims.FirstOrDefault(c => c.Type == Claims.SendAccessClaims.SendId); Assert.NotNull(sendIdClaim); Assert.Equal(sendId.ToString(), sendIdClaim.Value); From 0bfbfaa17c36373b8cc16ea4a81c8979886b533f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Garc=C3=ADa?= Date: Wed, 3 Sep 2025 11:38:01 +0200 Subject: [PATCH 07/13] Improve Swagger OperationIDs for Tools (#6239) --- .../Controllers/ImportCiphersController.cs | 2 +- src/Api/Tools/Controllers/SendsController.cs | 2 +- .../ImportCiphersControllerTests.cs | 20 +++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/Api/Tools/Controllers/ImportCiphersController.cs b/src/Api/Tools/Controllers/ImportCiphersController.cs index 0f29a9aee3..88028420b7 100644 --- a/src/Api/Tools/Controllers/ImportCiphersController.cs +++ b/src/Api/Tools/Controllers/ImportCiphersController.cs @@ -63,7 +63,7 @@ public class ImportCiphersController : Controller } [HttpPost("import-organization")] - public async Task PostImport([FromQuery] string organizationId, + public async Task PostImportOrganization([FromQuery] string organizationId, [FromBody] ImportOrganizationCiphersRequestModel model) { if (!_globalSettings.SelfHosted && diff --git a/src/Api/Tools/Controllers/SendsController.cs b/src/Api/Tools/Controllers/SendsController.cs index 43239b3995..c02e9b0c20 100644 --- a/src/Api/Tools/Controllers/SendsController.cs +++ b/src/Api/Tools/Controllers/SendsController.cs @@ -192,7 +192,7 @@ public class SendsController : Controller } [HttpGet("")] - public async Task> Get() + public async Task> GetAll() { var userId = _userService.GetProperUserId(User).Value; var sends = await _sendRepository.GetManyByUserIdAsync(userId); diff --git a/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs b/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs index 53d9d2a1f8..4908bb6847 100644 --- a/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs +++ b/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs @@ -126,7 +126,7 @@ public class ImportCiphersControllerTests }; // Act - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.PostImport(Arg.Any(), model)); + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.PostImportOrganization(Arg.Any(), model)); // Assert Assert.Equal("You cannot import this much data at once.", exception.Message); @@ -186,7 +186,7 @@ public class ImportCiphersControllerTests .Returns(existingCollections.Select(c => new Collection { Id = orgIdGuid }).ToList()); // Act - await sutProvider.Sut.PostImport(orgId, request); + await sutProvider.Sut.PostImportOrganization(orgId, request); // Assert await sutProvider.GetDependency() @@ -257,7 +257,7 @@ public class ImportCiphersControllerTests .Returns(existingCollections.Select(c => new Collection { Id = orgIdGuid }).ToList()); // Act - await sutProvider.Sut.PostImport(orgId, request); + await sutProvider.Sut.PostImportOrganization(orgId, request); // Assert await sutProvider.GetDependency() @@ -324,7 +324,7 @@ public class ImportCiphersControllerTests // Act var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.PostImport(orgId, request)); + sutProvider.Sut.PostImportOrganization(orgId, request)); // Assert Assert.IsType(exception); @@ -387,7 +387,7 @@ public class ImportCiphersControllerTests // Act var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.PostImport(orgId, request)); + sutProvider.Sut.PostImportOrganization(orgId, request)); // Assert Assert.IsType(exception); @@ -457,7 +457,7 @@ public class ImportCiphersControllerTests // Act // User imports into collections and creates new collections // User has ImportCiphers and Create ciphers permission - await sutProvider.Sut.PostImport(orgId.ToString(), request); + await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request); // Assert await sutProvider.GetDependency() @@ -535,7 +535,7 @@ public class ImportCiphersControllerTests // User has ImportCiphers permission only and doesn't have Create permission var exception = await Assert.ThrowsAsync(async () => { - await sutProvider.Sut.PostImport(orgId.ToString(), request); + await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request); }); // Assert @@ -610,7 +610,7 @@ public class ImportCiphersControllerTests // Act // User imports/creates a new collection - existing collections not affected // User has create permissions and doesn't need import permissions - await sutProvider.Sut.PostImport(orgId.ToString(), request); + await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request); // Assert await sutProvider.GetDependency() @@ -685,7 +685,7 @@ public class ImportCiphersControllerTests // Act // User import into existing collection // User has ImportCiphers permission only and doesn't need create permission - await sutProvider.Sut.PostImport(orgId.ToString(), request); + await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request); // Assert await sutProvider.GetDependency() @@ -753,7 +753,7 @@ public class ImportCiphersControllerTests // import ciphers only and no collections // User has Create permissions // expected to be successful - await sutProvider.Sut.PostImport(orgId.ToString(), request); + await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request); // Assert await sutProvider.GetDependency() From d627b0a0643650bb39132985c63ea0dfd4d253ed Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 3 Sep 2025 12:01:39 +0200 Subject: [PATCH 08/13] [deps] Tools: Update aws-sdk-net monorepo (#6272) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- src/Core/Core.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 25e74d8aee..04dd7781bc 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -21,8 +21,8 @@ - - + + From 99058891d0ba39a7a6901799e81a6472b3556d88 Mon Sep 17 00:00:00 2001 From: Patrick-Pimentel-Bitwarden Date: Wed, 3 Sep 2025 09:12:26 -0400 Subject: [PATCH 09/13] Auth/pm 24434/enhance email (#6157) * fix(emails): [PM-24434] Email Enhancement - Added seconds to new device logged in email --- src/Core/Services/Implementations/HandlebarsMailService.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 394b5c5125..8de0e99bd3 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -559,7 +559,7 @@ public class HandlebarsMailService : IMailService SiteName = _globalSettings.SiteName, DeviceType = deviceType, TheDate = timestamp.ToLongDateString(), - TheTime = timestamp.ToShortTimeString(), + TheTime = timestamp.ToString("hh:mm:ss tt"), TimeZone = _utcTimeZoneDisplay, IpAddress = ip }; From 1dade9d4b868fb73907c0d280fd19bb0191692cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:57:53 +0100 Subject: [PATCH 10/13] [PM-24233] Use BulkResourceCreationService in CipherRepository (#6201) * Add constant for CipherRepositoryBulkResourceCreation in FeatureFlagKeys * Add bulk creation methods for Ciphers, Folders, and CollectionCiphers in BulkResourceCreationService - Implemented CreateCiphersAsync, CreateFoldersAsync, CreateCollectionCiphersAsync, and CreateTempCiphersAsync methods for bulk insertion. - Added helper methods to build DataTables for Ciphers, Folders, and CollectionCiphers. - Enhanced error handling for empty collections during bulk operations. * Refactor CipherRepository to utilize BulkResourceCreationService - Introduced IFeatureService to manage feature flag checks for bulk operations. - Updated methods to conditionally use BulkResourceCreationService for creating Ciphers, Folders, and CollectionCiphers based on feature flag status. - Enhanced existing bulk copy logic to maintain functionality while integrating feature flag checks. * Add InlineFeatureService to DatabaseDataAttribute for feature flag management - Introduced EnabledFeatureFlags property to DatabaseDataAttribute for configuring feature flags. - Integrated InlineFeatureService to provide feature flag checks within the service collection. - Enhanced GetData method to utilize feature flags for conditional service registration. * Add tests for bulk creation of Ciphers in CipherRepositoryTests - Implemented tests for bulk creation of Ciphers, Folders, and Collections with feature flag checks. - Added test cases for updating multiple Ciphers to validate bulk update functionality. - Enhanced existing test structure to ensure comprehensive coverage of bulk operations in the CipherRepository. * Refactor BulkResourceCreationService to use dynamic types for DataColumns - Updated DataColumn definitions in BulkResourceCreationService to utilize the actual types of properties from the cipher object instead of hardcoded types. - Simplified the assignment of nullable properties to directly use their values, improving code readability and maintainability. * Update BulkResourceCreationService to use specific types for DataColumns - Changed DataColumn definitions to use specific types (short and string) instead of dynamic types based on cipher properties. - Improved handling of nullable properties when assigning values to DataTable rows, ensuring proper handling of DBNull for null values. * Refactor CipherRepositoryTests for improved clarity and consistency - Renamed test methods to better reflect their purpose and improve readability. - Updated test data to use more descriptive names for users, folders, and collections. - Enhanced test structure with clear Arrange, Act, and Assert sections for better understanding of test flow. - Ensured all tests validate the expected outcomes for bulk operations with feature flag checks. * Update CipherRepositoryBulkResourceCreation feature flag key * Refactor DatabaseDataAttribute usage in CipherRepositoryTests to use array syntax for EnabledFeatureFlags * Update CipherRepositoryTests to use GenerateComb for generating unique IDs * Refactor CipherRepository methods to accept a boolean parameter for enabling bulk resource creation based on feature flags. Update tests to verify functionality with and without the feature flag enabled. * Refactor CipherRepository and related services to support new methods for bulk resource creation without boolean parameters. --- src/Core/Constants.cs | 1 + .../RotateUserAccountkeysCommand.cs | 15 +- .../ImportFeatures/ImportCiphersCommand.cs | 20 +- .../Vault/Repositories/ICipherRepository.cs | 22 ++ .../Services/Implementations/CipherService.cs | 10 +- .../Helpers/BulkResourceCreationService.cs | 190 ++++++++++++++++ .../Vault/Repositories/CipherRepository.cs | 211 ++++++++++++++++++ .../Vault/Repositories/CipherRepository.cs | 41 ++++ .../ImportCiphersAsyncCommandTests.cs | 136 ++++++++++- .../Vault/Services/CipherServiceTests.cs | 53 +++++ .../Repositories/CipherRepositoryTests.cs | 157 +++++++++++++ 11 files changed, 849 insertions(+), 7 deletions(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 393ab15e4c..2993f6a094 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -114,6 +114,7 @@ public static class FeatureFlagKeys public const string SeparateCustomRolePermissions = "pm-19917-separate-custom-role-permissions"; public const string CreateDefaultLocation = "pm-19467-create-default-location"; public const string DirectoryConnectorPreventUserRemoval = "pm-24592-directory-connector-prevent-user-removal"; + public const string CipherRepositoryBulkResourceCreation = "pm-24951-cipher-repository-bulk-resource-creation-service"; /* Auth Team */ public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence"; diff --git a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs index 6967c9bf85..011fc2932f 100644 --- a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs +++ b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs @@ -25,6 +25,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly IWebAuthnCredentialRepository _credentialRepository; private readonly IPasswordHasher _passwordHasher; + private readonly IFeatureService _featureService; /// /// Instantiates a new @@ -45,7 +46,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository, IDeviceRepository deviceRepository, IPasswordHasher passwordHasher, - IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository) + IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository, + IFeatureService featureService) { _userService = userService; _userRepository = userRepository; @@ -59,6 +61,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand _identityErrorDescriber = errors; _credentialRepository = credentialRepository; _passwordHasher = passwordHasher; + _featureService = featureService; } /// @@ -100,7 +103,15 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand List saveEncryptedDataActions = new(); if (model.Ciphers.Any()) { - saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers)); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation_vNext(user.Id, model.Ciphers)); + } + else + { + saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers)); + } } if (model.Folders.Any()) diff --git a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs index c7f7e3aff7..ce269bc68c 100644 --- a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs +++ b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs @@ -108,7 +108,15 @@ public class ImportCiphersCommand : IImportCiphersCommand } // Create it all - await _cipherRepository.CreateAsync(importingUserId, ciphers, newFolders); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + await _cipherRepository.CreateAsync_vNext(importingUserId, ciphers, newFolders); + } + else + { + await _cipherRepository.CreateAsync(importingUserId, ciphers, newFolders); + } // push await _pushService.PushSyncVaultAsync(importingUserId); @@ -183,7 +191,15 @@ public class ImportCiphersCommand : IImportCiphersCommand } // Create it all - await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers, newCollectionUsers); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + await _cipherRepository.CreateAsync_vNext(ciphers, newCollections, collectionCiphers, newCollectionUsers); + } + else + { + await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers, newCollectionUsers); + } // push await _pushService.PushSyncVaultAsync(importingUserId); diff --git a/src/Core/Vault/Repositories/ICipherRepository.cs b/src/Core/Vault/Repositories/ICipherRepository.cs index 5a04a6651d..60b6e21f1d 100644 --- a/src/Core/Vault/Repositories/ICipherRepository.cs +++ b/src/Core/Vault/Repositories/ICipherRepository.cs @@ -32,12 +32,28 @@ public interface ICipherRepository : IRepository Task DeleteByUserIdAsync(Guid userId); Task DeleteByOrganizationIdAsync(Guid organizationId); Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); + /// + /// + /// This version uses the bulk resource creation service to create the temp table. + /// + Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers); /// /// Create ciphers and folders for the specified UserId. Must not be used to create organization owned items. /// Task CreateAsync(Guid userId, IEnumerable ciphers, IEnumerable folders); + /// + /// + /// This version uses the bulk resource creation service to create the temp tables. + /// + Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, IEnumerable folders); Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, IEnumerable collectionUsers); + /// + /// + /// This version uses the bulk resource creation service to create the temp tables. + /// + Task CreateAsync_vNext(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers, IEnumerable collectionUsers); Task SoftDeleteAsync(IEnumerable ids, Guid userId); Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); Task RestoreAsync(IEnumerable ids, Guid userId); @@ -68,4 +84,10 @@ public interface ICipherRepository : IRepository /// A list of ciphers with updated data UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid userId, IEnumerable ciphers); + /// + /// + /// This version uses the bulk resource creation service to create the temp table. + /// + UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext(Guid userId, + IEnumerable ciphers); } diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index 51ed4b0ce7..2a4cc6c137 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -642,7 +642,15 @@ public class CipherService : ICipherService cipherIds.Add(cipher.Id); } - await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + await _cipherRepository.UpdateCiphersAsync_vNext(sharingUserId, cipherInfos.Select(c => c.cipher)); + } + else + { + await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); + } await _collectionCipherRepository.UpdateCollectionsForCiphersAsync(cipherIds, sharingUserId, organizationId, collectionIds); diff --git a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs index 139960ceba..3610c1c484 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs @@ -1,5 +1,6 @@ using System.Data; using Bit.Core.Entities; +using Bit.Core.Vault.Entities; using Microsoft.Data.SqlClient; namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers; @@ -15,6 +16,38 @@ public static class BulkResourceCreationService await bulkCopy.WriteToServerAsync(dataTable); } + public static async Task CreateCiphersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable ciphers, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[Cipher]"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + public static async Task CreateFoldersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable folders, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[Folder]"; + var dataTable = BuildFoldersTable(bulkCopy, folders, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + public static async Task CreateCollectionCiphersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collectionCiphers, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[CollectionCipher]"; + var dataTable = BuildCollectionCiphersTable(bulkCopy, collectionCiphers, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + public static async Task CreateTempCiphersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable ciphers, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "#TempCipher"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + private static DataTable BuildCollectionsUsersTable(SqlBulkCopy bulkCopy, IEnumerable collectionUsers, string errorMessage) { var collectionUser = collectionUsers.FirstOrDefault(); @@ -126,4 +159,161 @@ public static class BulkResourceCreationService return collectionsTable; } + + private static DataTable BuildCiphersTable(SqlBulkCopy bulkCopy, IEnumerable ciphers, string errorMessage) + { + var c = ciphers.FirstOrDefault(); + + if (c == null) + { + throw new ApplicationException(errorMessage); + } + + var ciphersTable = new DataTable("CipherDataTable"); + + var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); + ciphersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(c.UserId), typeof(Guid)); + ciphersTable.Columns.Add(userIdColumn); + var organizationId = new DataColumn(nameof(c.OrganizationId), typeof(Guid)); + ciphersTable.Columns.Add(organizationId); + var typeColumn = new DataColumn(nameof(c.Type), typeof(short)); + ciphersTable.Columns.Add(typeColumn); + var dataColumn = new DataColumn(nameof(c.Data), typeof(string)); + ciphersTable.Columns.Add(dataColumn); + var favoritesColumn = new DataColumn(nameof(c.Favorites), typeof(string)); + ciphersTable.Columns.Add(favoritesColumn); + var foldersColumn = new DataColumn(nameof(c.Folders), typeof(string)); + ciphersTable.Columns.Add(foldersColumn); + var attachmentsColumn = new DataColumn(nameof(c.Attachments), typeof(string)); + ciphersTable.Columns.Add(attachmentsColumn); + var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); + ciphersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); + ciphersTable.Columns.Add(revisionDateColumn); + var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); + ciphersTable.Columns.Add(deletedDateColumn); + var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); + ciphersTable.Columns.Add(repromptColumn); + var keyColummn = new DataColumn(nameof(c.Key), typeof(string)); + ciphersTable.Columns.Add(keyColummn); + + foreach (DataColumn col in ciphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + ciphersTable.PrimaryKey = keys; + + foreach (var cipher in ciphers) + { + var row = ciphersTable.NewRow(); + + row[idColumn] = cipher.Id; + row[userIdColumn] = cipher.UserId.HasValue ? (object)cipher.UserId.Value : DBNull.Value; + row[organizationId] = cipher.OrganizationId.HasValue ? (object)cipher.OrganizationId.Value : DBNull.Value; + row[typeColumn] = (short)cipher.Type; + row[dataColumn] = cipher.Data; + row[favoritesColumn] = cipher.Favorites; + row[foldersColumn] = cipher.Folders; + row[attachmentsColumn] = cipher.Attachments; + row[creationDateColumn] = cipher.CreationDate; + row[revisionDateColumn] = cipher.RevisionDate; + row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; + row[repromptColumn] = cipher.Reprompt.HasValue ? cipher.Reprompt.Value : DBNull.Value; + row[keyColummn] = cipher.Key; + + ciphersTable.Rows.Add(row); + } + + return ciphersTable; + } + + private static DataTable BuildFoldersTable(SqlBulkCopy bulkCopy, IEnumerable folders, string errorMessage) + { + var f = folders.FirstOrDefault(); + + if (f == null) + { + throw new ApplicationException(errorMessage); + } + + var foldersTable = new DataTable("FolderDataTable"); + + var idColumn = new DataColumn(nameof(f.Id), f.Id.GetType()); + foldersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(f.UserId), f.UserId.GetType()); + foldersTable.Columns.Add(userIdColumn); + var nameColumn = new DataColumn(nameof(f.Name), typeof(string)); + foldersTable.Columns.Add(nameColumn); + var creationDateColumn = new DataColumn(nameof(f.CreationDate), f.CreationDate.GetType()); + foldersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(f.RevisionDate), f.RevisionDate.GetType()); + foldersTable.Columns.Add(revisionDateColumn); + + foreach (DataColumn col in foldersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + foldersTable.PrimaryKey = keys; + + foreach (var folder in folders) + { + var row = foldersTable.NewRow(); + + row[idColumn] = folder.Id; + row[userIdColumn] = folder.UserId; + row[nameColumn] = folder.Name; + row[creationDateColumn] = folder.CreationDate; + row[revisionDateColumn] = folder.RevisionDate; + + foldersTable.Rows.Add(row); + } + + return foldersTable; + } + + private static DataTable BuildCollectionCiphersTable(SqlBulkCopy bulkCopy, IEnumerable collectionCiphers, string errorMessage) + { + var cc = collectionCiphers.FirstOrDefault(); + + if (cc == null) + { + throw new ApplicationException(errorMessage); + } + + var collectionCiphersTable = new DataTable("CollectionCipherDataTable"); + + var collectionIdColumn = new DataColumn(nameof(cc.CollectionId), cc.CollectionId.GetType()); + collectionCiphersTable.Columns.Add(collectionIdColumn); + var cipherIdColumn = new DataColumn(nameof(cc.CipherId), cc.CipherId.GetType()); + collectionCiphersTable.Columns.Add(cipherIdColumn); + + foreach (DataColumn col in collectionCiphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[2]; + keys[0] = collectionIdColumn; + keys[1] = cipherIdColumn; + collectionCiphersTable.PrimaryKey = keys; + + foreach (var collectionCipher in collectionCiphers) + { + var row = collectionCiphersTable.NewRow(); + + row[collectionIdColumn] = collectionCipher.CollectionId; + row[cipherIdColumn] = collectionCipher.CipherId; + + collectionCiphersTable.Rows.Add(row); + } + + return collectionCiphersTable; + } } diff --git a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs index 180a90fd41..8c1f04affc 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs @@ -10,6 +10,7 @@ using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; +using Bit.Infrastructure.Dapper.AdminConsole.Helpers; using Bit.Infrastructure.Dapper.Repositories; using Bit.Infrastructure.Dapper.Vault.Helpers; using Dapper; @@ -408,6 +409,52 @@ public class CipherRepository : Repository, ICipherRepository }; } + /// + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext( + Guid userId, IEnumerable ciphers) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + // Create temp table + var sqlCreateTemp = @" + SELECT TOP 0 * + INTO #TempCipher + FROM [dbo].[Cipher]"; + + await using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } + + // Bulk copy data into temp table + await BulkResourceCreationService.CreateTempCiphersAsync(connection, transaction, ciphers); + + // Update cipher table from temp table + var sql = @" + UPDATE + [dbo].[Cipher] + SET + [Data] = TC.[Data], + [Attachments] = TC.[Attachments], + [RevisionDate] = TC.[RevisionDate], + [Key] = TC.[Key] + FROM + [dbo].[Cipher] C + INNER JOIN + #TempCipher TC ON C.Id = TC.Id + WHERE + C.[UserId] = @UserId + + DROP TABLE #TempCipher"; + + await using (var cmd = new SqlCommand(sql, connection, transaction)) + { + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; + cmd.ExecuteNonQuery(); + } + }; + } + public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) { if (!ciphers.Any()) @@ -490,6 +537,83 @@ public class CipherRepository : Repository, ICipherRepository } } + public async Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + // 1. Create temp tables to bulk copy into. + + var sqlCreateTemp = @" + SELECT TOP 0 * + INTO #TempCipher + FROM [dbo].[Cipher]"; + + using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } + + // 2. Bulk copy into temp tables. + await BulkResourceCreationService.CreateTempCiphersAsync(connection, transaction, ciphers); + + // 3. Insert into real tables from temp tables and clean up. + + // Intentionally not including Favorites, Folders, and CreationDate + // since those are not meant to be bulk updated at this time + var sql = @" + UPDATE + [dbo].[Cipher] + SET + [UserId] = TC.[UserId], + [OrganizationId] = TC.[OrganizationId], + [Type] = TC.[Type], + [Data] = TC.[Data], + [Attachments] = TC.[Attachments], + [RevisionDate] = TC.[RevisionDate], + [DeletedDate] = TC.[DeletedDate], + [Key] = TC.[Key] + FROM + [dbo].[Cipher] C + INNER JOIN + #TempCipher TC ON C.Id = TC.Id + WHERE + C.[UserId] = @UserId + + DROP TABLE #TempCipher"; + + using (var cmd = new SqlCommand(sql, connection, transaction)) + { + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; + cmd.ExecuteNonQuery(); + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDate]", + new { Id = userId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + public async Task CreateAsync(Guid userId, IEnumerable ciphers, IEnumerable folders) { if (!ciphers.Any()) @@ -538,6 +662,44 @@ public class CipherRepository : Repository, ICipherRepository } } + public async Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, IEnumerable folders) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + if (folders.Any()) + { + await BulkResourceCreationService.CreateFoldersAsync(connection, transaction, folders); + } + + await BulkResourceCreationService.CreateCiphersAsync(connection, transaction, ciphers); + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDate]", + new { Id = userId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, IEnumerable collectionUsers) { @@ -607,6 +769,55 @@ public class CipherRepository : Repository, ICipherRepository } } + public async Task CreateAsync_vNext(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers, IEnumerable collectionUsers) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + await BulkResourceCreationService.CreateCiphersAsync(connection, transaction, ciphers); + + if (collections.Any()) + { + await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections); + } + + if (collectionCiphers.Any()) + { + await BulkResourceCreationService.CreateCollectionCiphersAsync(connection, transaction, collectionCiphers); + } + + if (collectionUsers.Any()) + { + await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers); + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", + new { OrganizationId = ciphers.First().OrganizationId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) { using (var connection = new SqlConnection(ConnectionString)) diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs index 3fae537a1e..d595fe7cfe 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs @@ -167,6 +167,16 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular create method. + /// + public async Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, + IEnumerable folders) + { + await CreateAsync(userId, ciphers, folders); + } + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, @@ -205,6 +215,18 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular create method. + /// + public async Task CreateAsync_vNext(IEnumerable ciphers, + IEnumerable collections, + IEnumerable collectionCiphers, + IEnumerable collectionUsers) + { + await CreateAsync(ciphers, collections, collectionCiphers, collectionUsers); + } + public async Task DeleteAsync(IEnumerable ids, Guid userId) { await ToggleCipherStates(ids, userId, CipherStateAction.HardDelete); @@ -907,6 +929,15 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular update method. + /// + public async Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers) + { + await UpdateCiphersAsync(userId, ciphers); + } + public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -970,6 +1001,16 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular update method. + /// + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext( + Guid userId, IEnumerable ciphers) + { + return UpdateForKeyRotation(userId, ciphers); + } + public async Task UpsertAsync(CipherDetails cipher) { if (cipher.Id.Equals(default)) diff --git a/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs b/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs index 0cb0deaf52..11f637d207 100644 --- a/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs +++ b/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs @@ -47,7 +47,41 @@ public class ImportCiphersAsyncCommandTests await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); // Assert - await sutProvider.GetDependency().Received(1).CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); + } + + [Theory, BitAutoData] + public async Task ImportIntoIndividualVaultAsync_WithBulkResourceCreationServiceEnabled_Success( + Guid importingUserId, + List ciphers, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency() + .AnyPoliciesApplicableToUserAsync(importingUserId, PolicyType.OrganizationDataOwnership) + .Returns(false); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(importingUserId) + .Returns(new List()); + + var folders = new List { new Folder { UserId = importingUserId } }; + + var folderRelationships = new List>(); + + // Act + await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .CreateAsync_vNext(importingUserId, ciphers, Arg.Any>()); await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); } @@ -77,7 +111,45 @@ public class ImportCiphersAsyncCommandTests await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); - await sutProvider.GetDependency().Received(1).CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); + } + + [Theory, BitAutoData] + public async Task ImportIntoIndividualVaultAsync_WithBulkResourceCreationServiceEnabled_WithPolicyRequirementsEnabled_WithOrganizationDataOwnershipPolicyDisabled_Success( + Guid importingUserId, + List ciphers, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PolicyRequirements) + .Returns(true); + + sutProvider.GetDependency() + .GetAsync(importingUserId) + .Returns(new OrganizationDataOwnershipPolicyRequirement( + OrganizationDataOwnershipState.Disabled, + [])); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(importingUserId) + .Returns(new List()); + + var folders = new List { new Folder { UserId = importingUserId } }; + + var folderRelationships = new List>(); + + await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync_vNext(importingUserId, ciphers, Arg.Any>()); await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); } @@ -187,6 +259,66 @@ public class ImportCiphersAsyncCommandTests await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); } + [Theory, BitAutoData] + public async Task ImportIntoOrganizationalVaultAsync_WithBulkResourceCreationServiceEnabled_Success( + Organization organization, + Guid importingUserId, + OrganizationUser importingOrganizationUser, + List collections, + List ciphers, + SutProvider sutProvider) + { + organization.MaxCollections = null; + importingOrganizationUser.OrganizationId = organization.Id; + + foreach (var collection in collections) + { + collection.OrganizationId = organization.Id; + } + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organization.Id; + } + + KeyValuePair[] collectionRelationships = { + new(0, 0), + new(1, 1), + new(2, 2) + }; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, importingUserId) + .Returns(importingOrganizationUser); + + // Set up a collection that already exists in the organization + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organization.Id) + .Returns(new List { collections[0] }); + + await sutProvider.Sut.ImportIntoOrganizationalVaultAsync(collections, ciphers, collectionRelationships, importingUserId); + + await sutProvider.GetDependency().Received(1).CreateAsync_vNext( + ciphers, + Arg.Is>(cols => cols.Count() == collections.Count - 1 && + !cols.Any(c => c.Id == collections[0].Id) && // Check that the collection that already existed in the organization was not added + cols.All(c => collections.Any(x => c.Name == x.Name))), + Arg.Is>(c => c.Count() == ciphers.Count), + Arg.Is>(cus => + cus.Count() == collections.Count - 1 && + !cus.Any(cu => cu.CollectionId == collections[0].Id) && // Check that access was not added for the collection that already existed in the organization + cus.All(cu => cu.OrganizationUserId == importingOrganizationUser.Id && cu.Manage == true))); + await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); + } + [Theory, BitAutoData] public async Task ImportIntoOrganizationalVaultAsync_ThrowsBadRequestException( Organization organization, diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index 55db5a9143..44c86389e3 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -674,6 +674,32 @@ public class CipherServiceTests Arg.Is>(arg => !arg.Except(ciphers).Any())); } + [Theory] + [BitAutoData("")] + [BitAutoData("Correct Time")] + public async Task ShareManyAsync_CorrectRevisionDate_WithBulkResourceCreationServiceEnabled_Passes(string revisionDateString, + SutProvider sutProvider, IEnumerable ciphers, Organization organization, List collectionIds) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(organization.Id) + .Returns(new Organization + { + PlanType = PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + + var cipherInfos = ciphers.Select(c => (c, + string.IsNullOrEmpty(revisionDateString) ? null : (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organization.Id, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync_vNext(sharingUserId, + Arg.Is>(arg => !arg.Except(ciphers).Any())); + } + [Theory] [BitAutoData] public async Task RestoreAsync_UpdatesUserCipher(Guid restoringUserId, CipherDetails cipher, SutProvider sutProvider) @@ -1094,6 +1120,33 @@ public class CipherServiceTests Arg.Is>(arg => !arg.Except(ciphers).Any())); } + [Theory, BitAutoData] + public async Task ShareManyAsync_PaidOrgWithAttachment_WithBulkResourceCreationServiceEnabled_Passes(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(organizationId) + .Returns(new Organization + { + PlanType = PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + ciphers.FirstOrDefault().Attachments = + "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," + + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; + + var cipherInfos = ciphers.Select(c => (c, + (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync_vNext(sharingUserId, + Arg.Is>(arg => !arg.Except(ciphers).Any())); + } + private class SaveDetailsAsyncDependencies { public CipherDetails CipherDetails { get; set; } diff --git a/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs b/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs index 0a186e43be..2a31398a02 100644 --- a/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs @@ -8,11 +8,13 @@ using Bit.Core.Models.Data; using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Repositories; using Bit.Core.Repositories; +using Bit.Core.Utilities; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Enums; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; using Xunit; +using CipherType = Bit.Core.Vault.Enums.CipherType; namespace Bit.Infrastructure.IntegrationTest.Repositories; @@ -975,6 +977,161 @@ public class CipherRepositoryTests Assert.Equal("new_attachments", updatedCipher2.Attachments); } + [DatabaseTheory, DatabaseData] + public async Task CreateAsync_vNext_WithFolders_Works( + IUserRepository userRepository, ICipherRepository cipherRepository, IFolderRepository folderRepository) + { + // Arrange + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"{Guid.NewGuid()}@example.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var folder1 = new Folder { Id = CoreHelpers.GenerateComb(), UserId = user.Id, Name = "Test Folder 1" }; + var folder2 = new Folder { Id = CoreHelpers.GenerateComb(), UserId = user.Id, Name = "Test Folder 2" }; + var cipher1 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, UserId = user.Id, Data = "" }; + var cipher2 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.SecureNote, UserId = user.Id, Data = "" }; + + // Act + await cipherRepository.CreateAsync_vNext( + userId: user.Id, + ciphers: [cipher1, cipher2], + folders: [folder1, folder2]); + + // Assert + var readCipher1 = await cipherRepository.GetByIdAsync(cipher1.Id); + var readCipher2 = await cipherRepository.GetByIdAsync(cipher2.Id); + Assert.NotNull(readCipher1); + Assert.NotNull(readCipher2); + + var readFolder1 = await folderRepository.GetByIdAsync(folder1.Id); + var readFolder2 = await folderRepository.GetByIdAsync(folder2.Id); + Assert.NotNull(readFolder1); + Assert.NotNull(readFolder2); + } + + [DatabaseTheory, DatabaseData] + public async Task CreateAsync_vNext_WithCollectionsAndUsers_Works( + IOrganizationRepository orgRepository, + IOrganizationUserRepository orgUserRepository, + ICollectionRepository collectionRepository, + ICollectionCipherRepository collectionCipherRepository, + ICipherRepository cipherRepository, + IUserRepository userRepository) + { + // Arrange + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"{Guid.NewGuid()}@example.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var org = await orgRepository.CreateAsync(new Organization + { + Name = "Test Organization", + BillingEmail = user.Email, + Plan = "Test" + }); + + var orgUser = await orgUserRepository.CreateAsync(new OrganizationUser + { + UserId = user.Id, + OrganizationId = org.Id, + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.Owner, + }); + + var collection = new Collection { Id = CoreHelpers.GenerateComb(), Name = "Test Collection", OrganizationId = org.Id }; + var cipher = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, OrganizationId = org.Id, Data = "" }; + var collectionCipher = new CollectionCipher { CollectionId = collection.Id, CipherId = cipher.Id }; + var collectionUser = new CollectionUser + { + CollectionId = collection.Id, + OrganizationUserId = orgUser.Id, + HidePasswords = false, + ReadOnly = false, + Manage = true + }; + + // Act + await cipherRepository.CreateAsync_vNext( + ciphers: [cipher], + collections: [collection], + collectionCiphers: [collectionCipher], + collectionUsers: [collectionUser]); + + // Assert + var orgCiphers = await cipherRepository.GetManyByOrganizationIdAsync(org.Id); + Assert.Contains(orgCiphers, c => c.Id == cipher.Id); + + var collCiphers = await collectionCipherRepository.GetManyByOrganizationIdAsync(org.Id); + Assert.Contains(collCiphers, cc => cc.CipherId == cipher.Id && cc.CollectionId == collection.Id); + + var collectionsInOrg = await collectionRepository.GetManyByOrganizationIdAsync(org.Id); + Assert.Contains(collectionsInOrg, c => c.Id == collection.Id); + + var collectionUsers = await collectionRepository.GetManyUsersByIdAsync(collection.Id); + var foundCollectionUser = collectionUsers.FirstOrDefault(cu => cu.Id == orgUser.Id); + Assert.NotNull(foundCollectionUser); + Assert.True(foundCollectionUser.Manage); + Assert.False(foundCollectionUser.ReadOnly); + Assert.False(foundCollectionUser.HidePasswords); + } + + [DatabaseTheory, DatabaseData] + public async Task UpdateCiphersAsync_vNext_Works( + IUserRepository userRepository, ICipherRepository cipherRepository) + { + // Arrange + var expectedNewType = CipherType.SecureNote; + var expectedNewAttachments = "bulk_new_attachments"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"{Guid.NewGuid()}@example.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var c1 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, UserId = user.Id, Data = "" }; + var c2 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, UserId = user.Id, Data = "" }; + await cipherRepository.CreateAsync( + userId: user.Id, + ciphers: [c1, c2], + folders: []); + + c1.Type = expectedNewType; + c2.Attachments = expectedNewAttachments; + + // Act + await cipherRepository.UpdateCiphersAsync_vNext(user.Id, [c1, c2]); + + // Assert + var updated1 = await cipherRepository.GetByIdAsync(c1.Id); + Assert.NotNull(updated1); + Assert.Equal(c1.Id, updated1.Id); + Assert.Equal(expectedNewType, updated1.Type); + Assert.Equal(c1.UserId, updated1.UserId); + Assert.Equal(c1.Data, updated1.Data); + Assert.Equal(c1.OrganizationId, updated1.OrganizationId); + Assert.Equal(c1.Attachments, updated1.Attachments); + + var updated2 = await cipherRepository.GetByIdAsync(c2.Id); + Assert.NotNull(updated2); + Assert.Equal(c2.Id, updated2.Id); + Assert.Equal(c2.Type, updated2.Type); + Assert.Equal(c2.UserId, updated2.UserId); + Assert.Equal(c2.Data, updated2.Data); + Assert.Equal(c2.OrganizationId, updated2.OrganizationId); + Assert.Equal(expectedNewAttachments, updated2.Attachments); + } + [DatabaseTheory, DatabaseData] public async Task DeleteCipherWithSecurityTaskAsync_Works( IOrganizationRepository organizationRepository, From fa8d65cc1f572fad047e3da17eebb299fa097ef2 Mon Sep 17 00:00:00 2001 From: cyprain-okeke <108260115+cyprain-okeke@users.noreply.github.com> Date: Wed, 3 Sep 2025 20:33:32 +0530 Subject: [PATCH 11/13] [PM 19727] Update InvoiceUpcoming email content (#6168) * changes to implement the email * Refactoring and fix the unit testing * refactor the code and remove used method * Fix the failing test * Update the email templates * remove the extra space here * Refactor the descriptions * Fix the wrong subject header * Add the in the hyperlink rather than just Help center --- .../Implementations/UpcomingInvoiceHandler.cs | 40 +- .../Billing/Extensions/InvoiceExtensions.cs | 76 ++++ .../Handlebars/Layouts/ProviderFull.html.hbs | 211 ++++++++++ .../ProviderInvoiceUpcoming.html.hbs | 89 ++++ .../ProviderInvoiceUpcoming.text.hbs | 41 ++ .../Models/Mail/InvoiceUpcomingViewModel.cs | 5 + src/Core/Services/IMailService.cs | 8 + .../Implementations/HandlebarsMailService.cs | 42 ++ .../NoopImplementations/NoopMailService.cs | 9 + .../Extensions/InvoiceExtensionsTests.cs | 394 ++++++++++++++++++ 10 files changed, 914 insertions(+), 1 deletion(-) create mode 100644 src/Core/Billing/Extensions/InvoiceExtensions.cs create mode 100644 src/Core/MailTemplates/Handlebars/Layouts/ProviderFull.html.hbs create mode 100644 src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.html.hbs create mode 100644 src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.text.hbs create mode 100644 test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 9b1d110b5e..9f6fda7d3f 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -18,6 +19,7 @@ using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; public class UpcomingInvoiceHandler( + IGetPaymentMethodQuery getPaymentMethodQuery, ILogger logger, IMailService mailService, IOrganizationRepository organizationRepository, @@ -137,7 +139,7 @@ public class UpcomingInvoiceHandler( await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id); - await SendUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice); + await SendProviderUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice, subscription, providerId.Value); } } @@ -158,6 +160,42 @@ public class UpcomingInvoiceHandler( } } + private async Task SendProviderUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice, Subscription subscription, Guid providerId) + { + var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); + + var items = invoice.FormatForProvider(subscription); + + if (invoice.NextPaymentAttempt.HasValue && invoice.AmountDue > 0) + { + var provider = await providerRepository.GetByIdAsync(providerId); + if (provider == null) + { + logger.LogWarning("Provider {ProviderId} not found for invoice upcoming email", providerId); + return; + } + + var collectionMethod = subscription.CollectionMethod; + var paymentMethod = await getPaymentMethodQuery.Run(provider); + + var hasPaymentMethod = paymentMethod != null; + var paymentMethodDescription = paymentMethod?.Match( + bankAccount => $"Bank account ending in {bankAccount.Last4}", + card => $"{card.Brand} ending in {card.Last4}", + payPal => $"PayPal account {payPal.Email}" + ); + + await mailService.SendProviderInvoiceUpcoming( + validEmails, + invoice.AmountDue / 100M, + invoice.NextPaymentAttempt.Value, + items, + collectionMethod, + hasPaymentMethod, + paymentMethodDescription); + } + } + private async Task AlignOrganizationTaxConcernsAsync( Organization organization, Subscription subscription, diff --git a/src/Core/Billing/Extensions/InvoiceExtensions.cs b/src/Core/Billing/Extensions/InvoiceExtensions.cs new file mode 100644 index 0000000000..bb9f7588bf --- /dev/null +++ b/src/Core/Billing/Extensions/InvoiceExtensions.cs @@ -0,0 +1,76 @@ +using System.Text.RegularExpressions; +using Stripe; + +namespace Bit.Core.Billing.Extensions; + +public static class InvoiceExtensions +{ + /// + /// Formats invoice line items specifically for provider invoices, standardizing product descriptions + /// and ensuring consistent tax representation. + /// + /// The Stripe invoice containing line items + /// The associated subscription (for future extensibility) + /// A list of formatted invoice item descriptions + public static List FormatForProvider(this Invoice invoice, Subscription subscription) + { + var items = new List(); + + // Return empty list if no line items + if (invoice.Lines == null) + { + return items; + } + + foreach (var line in invoice.Lines.Data ?? new List()) + { + // Skip null lines or lines without description + if (line?.Description == null) + { + continue; + } + + var description = line.Description; + + // Handle Provider Portal and Business Unit Portal service lines + if (description.Contains("Provider Portal") || description.Contains("Business Unit")) + { + var priceMatch = Regex.Match(description, @"\(at \$[\d,]+\.?\d* / month\)"); + var priceInfo = priceMatch.Success ? priceMatch.Value : ""; + + var standardizedDescription = $"{line.Quantity} × Manage service provider {priceInfo}"; + items.Add(standardizedDescription); + } + // Handle tax lines + else if (description.ToLower().Contains("tax")) + { + var priceMatch = Regex.Match(description, @"\(at \$[\d,]+\.?\d* / month\)"); + var priceInfo = priceMatch.Success ? priceMatch.Value : ""; + + // If no price info found in description, calculate from amount + if (string.IsNullOrEmpty(priceInfo) && line.Quantity > 0) + { + var pricePerItem = (line.Amount / 100m) / line.Quantity; + priceInfo = $"(at ${pricePerItem:F2} / month)"; + } + + var taxDescription = $"{line.Quantity} × Tax {priceInfo}"; + items.Add(taxDescription); + } + // Handle other line items as-is + else + { + items.Add(description); + } + } + + // Add fallback tax from invoice-level tax if present and not already included + if (invoice.Tax.HasValue && invoice.Tax.Value > 0) + { + var taxAmount = invoice.Tax.Value / 100m; + items.Add($"1 × Tax (at ${taxAmount:F2} / month)"); + } + + return items; + } +} diff --git a/src/Core/MailTemplates/Handlebars/Layouts/ProviderFull.html.hbs b/src/Core/MailTemplates/Handlebars/Layouts/ProviderFull.html.hbs new file mode 100644 index 0000000000..33e32c2bb0 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/Layouts/ProviderFull.html.hbs @@ -0,0 +1,211 @@ + + + + + + Bitwarden + + + + + {{! Yahoo center fix }} + + + + +
+ {{! 600px container }} + + + {{! Left column (center fix) }} + + {{! Right column (center fix) }} + +
+ + + + + +
+ Bitwarden +
+ + + + + + +
+ + {{>@partial-block}} + +
+ + + + + + + + + + + +
+
+ + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.html.hbs b/src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.html.hbs new file mode 100644 index 0000000000..d9061d1ffe --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.html.hbs @@ -0,0 +1,89 @@ +{{#>ProviderFull}} + + + + + {{#unless (eq CollectionMethod "send_invoice")}} + + + + {{/unless}} + {{#if Items}} + {{#unless (eq CollectionMethod "send_invoice")}} + + + + {{/unless}} + {{/if}} + + + + {{#unless (eq CollectionMethod "send_invoice")}} + + + + + {{/unless}} + + + + {{#if (eq CollectionMethod "send_invoice")}} + + + + {{/if}} + {{#unless (eq CollectionMethod "send_invoice")}} + + + + {{/unless}} +
+ {{#if (eq CollectionMethod "send_invoice")}} +
Your subscription will renew soon
+
On {{date DueDate 'MMMM dd, yyyy'}} we'll send you an invoice with a summary of the charges including tax.
+ {{else}} +
Your subscription will renew on {{date DueDate 'MMMM dd, yyyy'}}
+ {{#if HasPaymentMethod}} +
To avoid any interruption in service, please ensure your {{PaymentMethodDescription}} can be charged for the following amount:
+ {{else}} +
To avoid any interruption in service, please add a payment method that can be charged for the following amount:
+ {{/if}} + {{/if}} +
+ {{usd AmountDue}} +
+ Summary Of Charges
+
+ {{#each Items}} +
{{this}}
+ {{/each}} +
+ {{#if (eq CollectionMethod "send_invoice")}} +
To avoid any interruption in service for you or your clients, please pay the invoice by the due date, or contact Bitwarden Customer Support to sign up for auto-pay.
+ {{else}} + + {{/if}} +
+ + + + +
+ Update payment method +
+
+ {{#if (eq CollectionMethod "send_invoice")}} + + + + +
+ Contact Bitwarden Support +
+ {{/if}} +
+ For assistance managing your subscription, please visit the Help Center or contact Bitwarden Customer Support. +
+ For assistance managing your subscription, please visit the Help Center or contact Bitwarden Customer Support. +
+{{/ProviderFull}} \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.text.hbs b/src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.text.hbs new file mode 100644 index 0000000000..c666e287a5 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/ProviderInvoiceUpcoming.text.hbs @@ -0,0 +1,41 @@ +{{#>BasicTextLayout}} +{{#if (eq CollectionMethod "send_invoice")}} +Your subscription will renew soon + +On {{date DueDate 'MMMM dd, yyyy'}} we'll send you an invoice with a summary of the charges including tax. +{{else}} +Your subscription will renew on {{date DueDate 'MMMM dd, yyyy'}} + + {{#if HasPaymentMethod}} +To avoid any interruption in service, please ensure your {{PaymentMethodDescription}} can be charged for the following amount: + {{else}} +To avoid any interruption in service, please add a payment method that can be charged for the following amount: + {{/if}} + +{{usd AmountDue}} +{{/if}} +{{#if Items}} +{{#unless (eq CollectionMethod "send_invoice")}} + +Summary Of Charges +------------------ +{{#each Items}} +{{this}} +{{/each}} +{{/unless}} +{{/if}} + +{{#if (eq CollectionMethod "send_invoice")}} +To avoid any interruption in service for you or your clients, please pay the invoice by the due date, or contact Bitwarden Customer Support to sign up for auto-pay. + +Contact Bitwarden Support: {{{ContactUrl}}} + +For assistance managing your subscription, please visit the **Help center** (https://bitwarden.com/help/update-billing-info) or **contact Bitwarden Customer Support** (https://bitwarden.com/contact/). +{{else}} + +{{/if}} + +{{#unless (eq CollectionMethod "send_invoice")}} +For assistance managing your subscription, please visit the **Help center** (https://bitwarden.com/help/update-billing-info) or **contact Bitwarden Customer Support** (https://bitwarden.com/contact/). +{{/unless}} +{{/BasicTextLayout}} \ No newline at end of file diff --git a/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs b/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs index 50f8256b3d..b63213b811 100644 --- a/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs +++ b/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs @@ -10,4 +10,9 @@ public class InvoiceUpcomingViewModel : BaseMailModel public List Items { get; set; } public bool MentionInvoices { get; set; } public string UpdateBillingInfoUrl { get; set; } = "https://bitwarden.com/help/update-billing-info/"; + public string CollectionMethod { get; set; } + public bool HasPaymentMethod { get; set; } + public string PaymentMethodDescription { get; set; } + public string HelpUrl { get; set; } = "https://bitwarden.com/help/"; + public string ContactUrl { get; set; } = "https://bitwarden.com/contact/"; } diff --git a/src/Core/Services/IMailService.cs b/src/Core/Services/IMailService.cs index a38328dc9d..6e61c4f8dd 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Services/IMailService.cs @@ -59,6 +59,14 @@ public interface IMailService DateTime dueDate, List items, bool mentionInvoices); + Task SendProviderInvoiceUpcoming( + IEnumerable emails, + decimal amount, + DateTime dueDate, + List items, + string? collectionMethod, + bool hasPaymentMethod, + string? paymentMethodDescription); Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices); Task SendAddedCreditAsync(string email, decimal amount); Task SendLicenseExpiredAsync(IEnumerable emails, string? organizationName = null); diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 8de0e99bd3..0410bad19e 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -478,6 +478,33 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + public async Task SendProviderInvoiceUpcoming( + IEnumerable emails, + decimal amount, + DateTime dueDate, + List items, + string? collectionMethod = null, + bool hasPaymentMethod = true, + string? paymentMethodDescription = null) + { + var message = CreateDefaultMessage("Your upcoming Bitwarden invoice", emails); + var model = new InvoiceUpcomingViewModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + AmountDue = amount, + DueDate = dueDate, + Items = items, + MentionInvoices = false, + CollectionMethod = collectionMethod, + HasPaymentMethod = hasPaymentMethod, + PaymentMethodDescription = paymentMethodDescription + }; + await AddMessageContentAsync(message, "ProviderInvoiceUpcoming", model); + message.Category = "ProviderInvoiceUpcoming"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) { var message = CreateDefaultMessage("Payment Failed", email); @@ -708,6 +735,8 @@ public class HandlebarsMailService : IMailService Handlebars.RegisterTemplate("SecurityTasksHtmlLayout", securityTasksHtmlLayoutSource); var securityTasksTextLayoutSource = await ReadSourceAsync("Layouts.SecurityTasks.text"); Handlebars.RegisterTemplate("SecurityTasksTextLayout", securityTasksTextLayoutSource); + var providerFullHtmlLayoutSource = await ReadSourceAsync("Layouts.ProviderFull.html"); + Handlebars.RegisterTemplate("ProviderFull", providerFullHtmlLayoutSource); Handlebars.RegisterHelper("date", (writer, context, parameters) => { @@ -863,6 +892,19 @@ public class HandlebarsMailService : IMailService writer.WriteSafeString(string.Empty); } }); + + // Equality comparison helper for conditional templates. + Handlebars.RegisterHelper("eq", (context, arguments) => + { + if (arguments.Length != 2) + { + return false; + } + + var value1 = arguments[0]?.ToString(); + var value2 = arguments[1]?.ToString(); + return string.Equals(value1, value2, StringComparison.OrdinalIgnoreCase); + }); } public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Services/NoopImplementations/NoopMailService.cs index bc73fb5398..7ec05bb1f9 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailService.cs @@ -137,6 +137,15 @@ public class NoopMailService : IMailService List items, bool mentionInvoices) => Task.FromResult(0); + public Task SendProviderInvoiceUpcoming( + IEnumerable emails, + decimal amount, + DateTime dueDate, + List items, + string? collectionMethod = null, + bool hasPaymentMethod = true, + string? paymentMethodDescription = null) => Task.FromResult(0); + public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) { return Task.FromResult(0); diff --git a/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs b/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs new file mode 100644 index 0000000000..a30e5e896c --- /dev/null +++ b/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs @@ -0,0 +1,394 @@ +using Bit.Core.Billing.Extensions; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Extensions; + +public class InvoiceExtensionsTests +{ + private static Invoice CreateInvoiceWithLines(params InvoiceLineItem[] lineItems) + { + return new Invoice + { + Lines = new StripeList + { + Data = lineItems?.ToList() ?? new List() + } + }; + } + + #region FormatForProvider Tests + + [Fact] + public void FormatForProvider_NullLines_ReturnsEmptyList() + { + // Arrange + var invoice = new Invoice + { + Lines = null + }; + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.NotNull(result); + Assert.Empty(result); + } + + [Fact] + public void FormatForProvider_EmptyLines_ReturnsEmptyList() + { + // Arrange + var invoice = CreateInvoiceWithLines(); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.NotNull(result); + Assert.Empty(result); + } + + [Fact] + public void FormatForProvider_NullLineItem_SkipsNullLine() + { + // Arrange + var invoice = CreateInvoiceWithLines(null); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.NotNull(result); + Assert.Empty(result); + } + + [Fact] + public void FormatForProvider_LineWithNullDescription_SkipsLine() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem { Description = null, Quantity = 1, Amount = 1000 } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.NotNull(result); + Assert.Empty(result); + } + + [Fact] + public void FormatForProvider_ProviderPortalTeams_FormatsCorrectly() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Provider Portal - Teams (at $6.00 / month)", + Quantity = 5, + Amount = 3000 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("5 × Manage service provider (at $6.00 / month)", result[0]); + } + + [Fact] + public void FormatForProvider_ProviderPortalEnterprise_FormatsCorrectly() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Provider Portal - Enterprise (at $4.00 / month)", + Quantity = 10, + Amount = 4000 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("10 × Manage service provider (at $4.00 / month)", result[0]); + } + + [Fact] + public void FormatForProvider_ProviderPortalWithoutPriceInfo_FormatsWithoutPrice() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Provider Portal - Teams", + Quantity = 3, + Amount = 1800 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("3 × Manage service provider ", result[0]); + } + + [Fact] + public void FormatForProvider_BusinessUnitPortalEnterprise_FormatsCorrectly() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Business Unit Portal - Enterprise (at $5.00 / month)", + Quantity = 8, + Amount = 4000 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("8 × Manage service provider (at $5.00 / month)", result[0]); + } + + [Fact] + public void FormatForProvider_BusinessUnitPortalGeneric_FormatsCorrectly() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Business Unit Portal (at $3.00 / month)", + Quantity = 2, + Amount = 600 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("2 × Manage service provider (at $3.00 / month)", result[0]); + } + + [Fact] + public void FormatForProvider_TaxLineWithPriceInfo_FormatsCorrectly() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Tax (at $2.00 / month)", + Quantity = 1, + Amount = 200 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("1 × Tax (at $2.00 / month)", result[0]); + } + + [Fact] + public void FormatForProvider_TaxLineWithoutPriceInfo_CalculatesPrice() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Tax", + Quantity = 2, + Amount = 400 // $4.00 total, $2.00 per item + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("2 × Tax (at $2.00 / month)", result[0]); + } + + [Fact] + public void FormatForProvider_TaxLineWithZeroQuantity_DoesNotCalculatePrice() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Tax", + Quantity = 0, + Amount = 200 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("0 × Tax ", result[0]); + } + + [Fact] + public void FormatForProvider_OtherLineItem_ReturnsAsIs() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Some other service", + Quantity = 1, + Amount = 1000 + } + ); + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("Some other service", result[0]); + } + + [Fact] + public void FormatForProvider_InvoiceLevelTax_AddsToResult() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Provider Portal - Teams", + Quantity = 1, + Amount = 600 + } + ); + invoice.Tax = 120; // $1.20 in cents + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Equal(2, result.Count); + Assert.Equal("1 × Manage service provider ", result[0]); + Assert.Equal("1 × Tax (at $1.20 / month)", result[1]); + } + + [Fact] + public void FormatForProvider_NoInvoiceLevelTax_DoesNotAddTax() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Provider Portal - Teams", + Quantity = 1, + Amount = 600 + } + ); + invoice.Tax = null; + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("1 × Manage service provider ", result[0]); + } + + [Fact] + public void FormatForProvider_ZeroInvoiceLevelTax_DoesNotAddTax() + { + // Arrange + var invoice = CreateInvoiceWithLines( + new InvoiceLineItem + { + Description = "Provider Portal - Teams", + Quantity = 1, + Amount = 600 + } + ); + invoice.Tax = 0; + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Single(result); + Assert.Equal("1 × Manage service provider ", result[0]); + } + + [Fact] + public void FormatForProvider_ComplexScenario_HandlesAllLineTypes() + { + // Arrange + var lineItems = new StripeList(); + lineItems.Data = new List + { + new InvoiceLineItem + { + Description = "Provider Portal - Teams (at $6.00 / month)", Quantity = 5, Amount = 3000 + }, + new InvoiceLineItem + { + Description = "Provider Portal - Enterprise (at $4.00 / month)", Quantity = 10, Amount = 4000 + }, + new InvoiceLineItem { Description = "Tax", Quantity = 1, Amount = 800 }, + new InvoiceLineItem { Description = "Custom Service", Quantity = 2, Amount = 2000 } + }; + + var invoice = new Invoice + { + Lines = lineItems, + Tax = 200 // Additional $2.00 tax + }; + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Equal(5, result.Count); + Assert.Equal("5 × Manage service provider (at $6.00 / month)", result[0]); + Assert.Equal("10 × Manage service provider (at $4.00 / month)", result[1]); + Assert.Equal("1 × Tax (at $8.00 / month)", result[2]); + Assert.Equal("Custom Service", result[3]); + Assert.Equal("1 × Tax (at $2.00 / month)", result[4]); + } + + #endregion +} From ef8c7f656d87a7b0e3307489f92c571719d01012 Mon Sep 17 00:00:00 2001 From: Kyle Denney <4227399+kdenney@users.noreply.github.com> Date: Wed, 3 Sep 2025 10:03:49 -0500 Subject: [PATCH 12/13] [PM-24350] fix tax calculation (#6251) --- .../Services/ProviderBillingService.cs | 5 +- .../OrganizationCreateRequestModel.cs | 3 +- .../Request/Accounts/PremiumRequestModel.cs | 3 +- .../Accounts/TaxInfoUpdateRequestModel.cs | 3 +- .../Implementations/StripeEventService.cs | 2 +- .../Implementations/UpcomingInvoiceHandler.cs | 4 +- .../Billing/Extensions/BillingExtensions.cs | 13 + .../Extensions/ServiceCollectionExtensions.cs | 3 - .../Services/OrganizationBillingService.cs | 6 +- .../Commands/UpdateBillingAddressCommand.cs | 2 +- .../Implementations/SubscriberService.cs | 10 +- .../Tax/Commands/PreviewTaxAmountCommand.cs | 14 +- .../Tax/Services/IAutomaticTaxFactory.cs | 11 - .../Tax/Services/IAutomaticTaxStrategy.cs | 33 -- .../Implementations/AutomaticTaxFactory.cs | 50 -- .../BusinessUseAutomaticTaxStrategy.cs | 96 ---- .../PersonalUseAutomaticTaxStrategy.cs | 64 --- src/Core/Constants.cs | 13 + src/Core/Models/Business/TaxInfo.cs | 2 +- .../Implementations/StripePaymentService.cs | 24 +- .../Commands/PreviewTaxAmountCommandTests.cs | 267 +++++++++- .../Tax/Services/AutomaticTaxFactoryTests.cs | 105 ---- .../BusinessUseAutomaticTaxStrategyTests.cs | 492 ------------------ .../Tax/Services/FakeAutomaticTaxStrategy.cs | 35 -- .../PersonalUseAutomaticTaxStrategyTests.cs | 217 -------- .../Services/StripePaymentServiceTests.cs | 358 ++++++++++++- 26 files changed, 663 insertions(+), 1172 deletions(-) delete mode 100644 src/Core/Billing/Tax/Services/IAutomaticTaxFactory.cs delete mode 100644 src/Core/Billing/Tax/Services/IAutomaticTaxStrategy.cs delete mode 100644 src/Core/Billing/Tax/Services/Implementations/AutomaticTaxFactory.cs delete mode 100644 src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs delete mode 100644 src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs delete mode 100644 test/Core.Test/Billing/Tax/Services/AutomaticTaxFactoryTests.cs delete mode 100644 test/Core.Test/Billing/Tax/Services/BusinessUseAutomaticTaxStrategyTests.cs delete mode 100644 test/Core.Test/Billing/Tax/Services/FakeAutomaticTaxStrategy.cs delete mode 100644 test/Core.Test/Billing/Tax/Services/PersonalUseAutomaticTaxStrategyTests.cs diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index 8c0b2c8275..5169d6cfd1 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -3,6 +3,7 @@ using System.Globalization; using Bit.Commercial.Core.Billing.Providers.Models; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -282,7 +283,7 @@ public class ProviderBillingService( ] }; - if (providerCustomer.Address is not { Country: "US" }) + if (providerCustomer.Address is not { Country: Constants.CountryAbbreviations.UnitedStates }) { customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; } @@ -525,7 +526,7 @@ public class ProviderBillingService( } }; - if (taxInfo.BillingAddressCountry is not "US") + if (taxInfo.BillingAddressCountry is not Constants.CountryAbbreviations.UnitedStates) { options.TaxExempt = StripeConstants.TaxExempt.Reverse; } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs index 10f938adfe..7754c44c8c 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs @@ -3,6 +3,7 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; +using Bit.Core; using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; @@ -139,7 +140,7 @@ public class OrganizationCreateRequestModel : IValidatableObject new string[] { nameof(BillingAddressCountry) }); } - if (PlanType != PlanType.Free && BillingAddressCountry == "US" && + if (PlanType != PlanType.Free && BillingAddressCountry == Constants.CountryAbbreviations.UnitedStates && string.IsNullOrWhiteSpace(BillingAddressPostalCode)) { yield return new ValidationResult("Zip / postal code is required.", diff --git a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs index 4e9882d67c..8e9aac8cc2 100644 --- a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs @@ -2,6 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; +using Bit.Core; using Bit.Core.Settings; using Enums = Bit.Core.Enums; @@ -35,7 +36,7 @@ public class PremiumRequestModel : IValidatableObject { yield return new ValidationResult("Payment token or license is required."); } - if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) + if (Country == Constants.CountryAbbreviations.UnitedStates && string.IsNullOrWhiteSpace(PostalCode)) { yield return new ValidationResult("Zip / postal code is required.", new string[] { nameof(PostalCode) }); diff --git a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs index 5f58453a6d..d3e3f5ec55 100644 --- a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs +++ b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs @@ -2,6 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; +using Bit.Core; namespace Bit.Api.Models.Request.Accounts; @@ -13,7 +14,7 @@ public class TaxInfoUpdateRequestModel : IValidatableObject public virtual IEnumerable Validate(ValidationContext validationContext) { - if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) + if (Country == Constants.CountryAbbreviations.UnitedStates && string.IsNullOrWhiteSpace(PostalCode)) { yield return new ValidationResult("Zip / postal code is required.", new string[] { nameof(PostalCode) }); diff --git a/src/Billing/Services/Implementations/StripeEventService.cs b/src/Billing/Services/Implementations/StripeEventService.cs index 7e2984e423..7eef357e14 100644 --- a/src/Billing/Services/Implementations/StripeEventService.cs +++ b/src/Billing/Services/Implementations/StripeEventService.cs @@ -218,7 +218,7 @@ public class StripeEventService : IStripeEventService private static string GetCustomerRegion(IDictionary customerMetadata) { - const string defaultRegion = "US"; + const string defaultRegion = Core.Constants.CountryAbbreviations.UnitedStates; if (customerMetadata is null) { diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 9f6fda7d3f..e5675f7c0a 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -203,7 +203,7 @@ public class UpcomingInvoiceHandler( { var nonUSBusinessUse = organization.PlanType.GetProductTier() != ProductTierType.Families && - subscription.Customer.Address.Country != "US"; + subscription.Customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates; if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) { @@ -248,7 +248,7 @@ public class UpcomingInvoiceHandler( Subscription subscription, string eventId) { - if (subscription.Customer.Address.Country != "US" && + if (subscription.Customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) { try diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs index 55db9dde18..7f81bfd33f 100644 --- a/src/Core/Billing/Extensions/BillingExtensions.cs +++ b/src/Core/Billing/Extensions/BillingExtensions.cs @@ -22,6 +22,19 @@ public static class BillingExtensions _ => throw new BillingException($"PlanType {planType} could not be matched to a ProductTierType") }; + public static bool IsBusinessProductTierType(this PlanType planType) + => IsBusinessProductTierType(planType.GetProductTier()); + + public static bool IsBusinessProductTierType(this ProductTierType productTierType) + => productTierType switch + { + ProductTierType.Free => false, + ProductTierType.Families => false, + ProductTierType.Enterprise => true, + ProductTierType.Teams => true, + ProductTierType.TeamsStarter => true + }; + public static bool IsBillable(this Provider provider) => provider is { diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 39ee3ec1ec..147e96105a 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -25,9 +25,6 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); - services.AddKeyedTransient(AutomaticTaxFactory.PersonalUse); - services.AddKeyedTransient(AutomaticTaxFactory.BusinessUse); - services.AddTransient(); services.AddLicenseServices(); services.AddPricingClient(); services.AddTransient(); diff --git a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs index 0e42803aaf..446f9563f9 100644 --- a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs @@ -275,7 +275,7 @@ public class OrganizationBillingService( if (planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families && - customerSetup.TaxInformation.Country != "US") + customerSetup.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates) { customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; } @@ -514,14 +514,14 @@ public class OrganizationBillingService( customer = customer switch { - { Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await + { Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse } => await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Expand = expansions, TaxExempt = StripeConstants.TaxExempt.Reverse }), - { Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await + { Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: StripeConstants.TaxExempt.Reverse } => await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs index fdf519523a..f4eca40cae 100644 --- a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs @@ -84,7 +84,7 @@ public class UpdateBillingAddressCommand( State = billingAddress.State }, Expand = ["subscriptions", "tax_ids"], - TaxExempt = billingAddress.Country != "US" + TaxExempt = billingAddress.Country != Core.Constants.CountryAbbreviations.UnitedStates ? StripeConstants.TaxExempt.Reverse : StripeConstants.TaxExempt.None }); diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 63a9352020..84d41f829c 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -801,15 +801,13 @@ public class SubscriberService( _ => false }; - - if (isBusinessUseSubscriber) { switch (customer) { case { - Address.Country: not "US", + Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not TaxExempt.Reverse }: await stripeAdapter.CustomerUpdateAsync(customer.Id, @@ -817,7 +815,7 @@ public class SubscriberService( break; case { - Address.Country: "US", + Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: TaxExempt.Reverse }: await stripeAdapter.CustomerUpdateAsync(customer.Id, @@ -840,8 +838,8 @@ public class SubscriberService( { User => true, Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families || - customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false), - Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false), + customer.Address.Country == Core.Constants.CountryAbbreviations.UnitedStates || (customer.TaxIds?.Any() ?? false), + Provider => customer.Address.Country == Core.Constants.CountryAbbreviations.UnitedStates || (customer.TaxIds?.Any() ?? false), _ => false }; diff --git a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs index 6e061293c7..94d3724d73 100644 --- a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs +++ b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs @@ -95,17 +95,11 @@ public class PreviewTaxAmountCommand( } } - if (planType.GetProductTier() == ProductTierType.Families) + options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; + if (parameters.PlanType.IsBusinessProductTierType() && + parameters.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates) { - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - } - else - { - options.AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = options.CustomerDetails.Address.Country == "US" || - options.CustomerDetails.TaxIds is [_, ..] - }; + options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse; } var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); diff --git a/src/Core/Billing/Tax/Services/IAutomaticTaxFactory.cs b/src/Core/Billing/Tax/Services/IAutomaticTaxFactory.cs deleted file mode 100644 index c0a31efb3c..0000000000 --- a/src/Core/Billing/Tax/Services/IAutomaticTaxFactory.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Bit.Core.Billing.Tax.Models; - -namespace Bit.Core.Billing.Tax.Services; - -/// -/// Responsible for defining the correct automatic tax strategy for either personal use of business use. -/// -public interface IAutomaticTaxFactory -{ - Task CreateAsync(AutomaticTaxFactoryParameters parameters); -} diff --git a/src/Core/Billing/Tax/Services/IAutomaticTaxStrategy.cs b/src/Core/Billing/Tax/Services/IAutomaticTaxStrategy.cs deleted file mode 100644 index 557bb1d30c..0000000000 --- a/src/Core/Billing/Tax/Services/IAutomaticTaxStrategy.cs +++ /dev/null @@ -1,33 +0,0 @@ -#nullable enable -using Stripe; - -namespace Bit.Core.Billing.Tax.Services; - -public interface IAutomaticTaxStrategy -{ - /// - /// - /// - /// - /// - /// Returns if changes are to be applied to the subscription, returns null - /// otherwise. - /// - SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription); - - /// - /// Modifies an existing object with the automatic tax flag set correctly. - /// - /// - /// - void SetCreateOptions(SubscriptionCreateOptions options, Customer customer); - - /// - /// Modifies an existing object with the automatic tax flag set correctly. - /// - /// - /// - void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription); - - void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options); -} diff --git a/src/Core/Billing/Tax/Services/Implementations/AutomaticTaxFactory.cs b/src/Core/Billing/Tax/Services/Implementations/AutomaticTaxFactory.cs deleted file mode 100644 index 6086a16b79..0000000000 --- a/src/Core/Billing/Tax/Services/Implementations/AutomaticTaxFactory.cs +++ /dev/null @@ -1,50 +0,0 @@ -#nullable enable -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Models; -using Bit.Core.Entities; -using Bit.Core.Services; - -namespace Bit.Core.Billing.Tax.Services.Implementations; - -public class AutomaticTaxFactory( - IFeatureService featureService, - IPricingClient pricingClient) : IAutomaticTaxFactory -{ - public const string BusinessUse = "business-use"; - public const string PersonalUse = "personal-use"; - - private readonly Lazy>> _personalUsePlansTask = new(async () => - { - var plans = await Task.WhenAll( - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)); - - return plans.Select(plan => plan.PasswordManager.StripePlanId); - }); - - public async Task CreateAsync(AutomaticTaxFactoryParameters parameters) - { - if (parameters.Subscriber is User) - { - return new PersonalUseAutomaticTaxStrategy(featureService); - } - - if (parameters.PlanType.HasValue) - { - var plan = await pricingClient.GetPlanOrThrow(parameters.PlanType.Value); - return plan.CanBeUsedByBusiness - ? new BusinessUseAutomaticTaxStrategy(featureService) - : new PersonalUseAutomaticTaxStrategy(featureService); - } - - var personalUsePlans = await _personalUsePlansTask.Value; - - if (parameters.Prices != null && parameters.Prices.Any(x => personalUsePlans.Any(y => y == x))) - { - return new PersonalUseAutomaticTaxStrategy(featureService); - } - - return new BusinessUseAutomaticTaxStrategy(featureService); - } -} diff --git a/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs b/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs deleted file mode 100644 index 6affc57354..0000000000 --- a/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs +++ /dev/null @@ -1,96 +0,0 @@ -#nullable enable -using Bit.Core.Billing.Extensions; -using Bit.Core.Services; -using Stripe; - -namespace Bit.Core.Billing.Tax.Services.Implementations; - -public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy -{ - public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) - { - if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - { - return null; - } - - var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); - if (subscription.AutomaticTax.Enabled == shouldBeEnabled) - { - return null; - } - - var options = new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = shouldBeEnabled - }, - DefaultTaxRates = [] - }; - - return options; - } - - public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) - { - options.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = ShouldBeEnabled(customer) - }; - } - - public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) - { - if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - { - return; - } - - var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); - - if (subscription.AutomaticTax.Enabled == shouldBeEnabled) - { - return; - } - - options.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = shouldBeEnabled - }; - options.DefaultTaxRates = []; - } - - public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) - { - options.AutomaticTax ??= new InvoiceAutomaticTaxOptions(); - - if (options.CustomerDetails.Address.Country == "US") - { - options.AutomaticTax.Enabled = true; - return; - } - - options.AutomaticTax.Enabled = options.CustomerDetails.TaxIds != null && options.CustomerDetails.TaxIds.Any(); - } - - private bool ShouldBeEnabled(Customer customer) - { - if (!customer.HasRecognizedTaxLocation()) - { - return false; - } - - if (customer.Address.Country == "US") - { - return true; - } - - if (customer.TaxIds == null) - { - throw new ArgumentNullException(nameof(customer.TaxIds), "`customer.tax_ids` must be expanded."); - } - - return customer.TaxIds.Any(); - } -} diff --git a/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs b/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs deleted file mode 100644 index 615222259e..0000000000 --- a/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs +++ /dev/null @@ -1,64 +0,0 @@ -#nullable enable -using Bit.Core.Billing.Extensions; -using Bit.Core.Services; -using Stripe; - -namespace Bit.Core.Billing.Tax.Services.Implementations; - -public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy -{ - public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) - { - options.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = ShouldBeEnabled(customer) - }; - } - - public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) - { - if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - { - return; - } - options.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = ShouldBeEnabled(subscription.Customer) - }; - options.DefaultTaxRates = []; - } - - public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) - { - if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - { - return null; - } - - if (subscription.AutomaticTax.Enabled == ShouldBeEnabled(subscription.Customer)) - { - return null; - } - - var options = new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = ShouldBeEnabled(subscription.Customer), - }, - DefaultTaxRates = [] - }; - - return options; - } - - public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) - { - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - } - - private static bool ShouldBeEnabled(Customer customer) - { - return customer.HasRecognizedTaxLocation(); - } -} diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 2993f6a094..9ddbf5c600 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -52,6 +52,19 @@ public static class Constants /// regardless of whether there is a proration or not. ///
public const string AlwaysInvoice = "always_invoice"; + + /// + /// Used primarily to determine whether a customer's business is inside or outside the United States + /// for billing purposes. + /// + public static class CountryAbbreviations + { + /// + /// Abbreviation for The United States. + /// This value must match what Stripe uses for the `Country` field value for the United States. + /// + public const string UnitedStates = "US"; + } } public static class AuthConstants diff --git a/src/Core/Models/Business/TaxInfo.cs b/src/Core/Models/Business/TaxInfo.cs index 4daa9a268a..4f95bb393d 100644 --- a/src/Core/Models/Business/TaxInfo.cs +++ b/src/Core/Models/Business/TaxInfo.cs @@ -13,5 +13,5 @@ public class TaxInfo public string BillingAddressCity { get; set; } public string BillingAddressState { get; set; } public string BillingAddressPostalCode { get; set; } - public string BillingAddressCountry { get; set; } = "US"; + public string BillingAddressCountry { get; set; } = Constants.CountryAbbreviations.UnitedStates; } diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 440fb5c546..ec45944bd2 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -9,11 +9,9 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Responses; using Bit.Core.Billing.Tax.Services; -using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -21,7 +19,6 @@ using Bit.Core.Models.BitStripe; using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Settings; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using PaymentMethod = Stripe.PaymentMethod; @@ -41,8 +38,6 @@ public class StripePaymentService : IPaymentService private readonly IFeatureService _featureService; private readonly ITaxService _taxService; private readonly IPricingClient _pricingClient; - private readonly IAutomaticTaxFactory _automaticTaxFactory; - private readonly IAutomaticTaxStrategy _personalUseTaxStrategy; public StripePaymentService( ITransactionRepository transactionRepository, @@ -52,9 +47,7 @@ public class StripePaymentService : IPaymentService IGlobalSettings globalSettings, IFeatureService featureService, ITaxService taxService, - IPricingClient pricingClient, - IAutomaticTaxFactory automaticTaxFactory, - [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy) + IPricingClient pricingClient) { _transactionRepository = transactionRepository; _logger = logger; @@ -64,8 +57,6 @@ public class StripePaymentService : IPaymentService _featureService = featureService; _taxService = taxService; _pricingClient = pricingClient; - _automaticTaxFactory = automaticTaxFactory; - _personalUseTaxStrategy = personalUseTaxStrategy; } private async Task ChangeOrganizationSponsorship( @@ -137,7 +128,7 @@ public class StripePaymentService : IPaymentService { if (sub.Customer is { - Address.Country: not "US", + Address.Country: not Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse }) { @@ -987,8 +978,6 @@ public class StripePaymentService : IPaymentService } } - _personalUseTaxStrategy.SetInvoiceCreatePreviewOptions(options); - try { var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); @@ -1152,9 +1141,12 @@ public class StripePaymentService : IPaymentService } } - var automaticTaxFactoryParameters = new AutomaticTaxFactoryParameters(parameters.PasswordManager.Plan); - var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxFactoryParameters); - automaticTaxStrategy.SetInvoiceCreatePreviewOptions(options); + options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; + if (parameters.PasswordManager.Plan.IsBusinessProductTierType() && + parameters.TaxInformation.Country != Constants.CountryAbbreviations.UnitedStates) + { + options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse; + } try { diff --git a/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs b/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs index ee5625d522..1de180cea1 100644 --- a/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs +++ b/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs @@ -181,7 +181,7 @@ public class PreviewTaxAmountCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == plan.PasswordManager.StripeSeatPlanId && options.SubscriptionDetails.Items[0].Quantity == 1 && - options.AutomaticTax.Enabled == false + options.AutomaticTax.Enabled == true )) .Returns(expectedInvoice); @@ -273,4 +273,269 @@ public class PreviewTaxAmountCommandTests var badRequest = result.AsT1; Assert.Equal("We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance.", badRequest.Response); } + + [Fact] + public async Task Run_USBased_PersonalUse_SetsAutomaticTaxEnabled() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.FamiliesAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "US", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + Assert.True(result.IsT0); + } + + [Fact] + public async Task Run_USBased_BusinessUse_SetsAutomaticTaxEnabled() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.EnterpriseAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "US", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + Assert.True(result.IsT0); + } + + [Fact] + public async Task Run_NonUSBased_PersonalUse_SetsAutomaticTaxEnabled() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.FamiliesAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "CA", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + Assert.True(result.IsT0); + } + + [Fact] + public async Task Run_NonUSBased_BusinessUse_SetsAutomaticTaxEnabled() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.EnterpriseAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "CA", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + Assert.True(result.IsT0); + } + + [Fact] + public async Task Run_USBased_PersonalUse_DoesNotSetTaxExempt() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.FamiliesAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "US", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == null + )); + Assert.True(result.IsT0); + } + + [Fact] + public async Task Run_USBased_BusinessUse_DoesNotSetTaxExempt() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.EnterpriseAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "US", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == null + )); + Assert.True(result.IsT0); + } + + [Fact] + public async Task Run_NonUSBased_PersonalUse_DoesNotSetTaxExempt() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.FamiliesAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "CA", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == null + )); + Assert.True(result.IsT0); + + } + + [Fact] + public async Task Run_NonUSBased_BusinessUse_SetsTaxExemptReverse() + { + // Arrange + var parameters = new OrganizationTrialParameters + { + PlanType = PlanType.EnterpriseAnnually, + ProductType = ProductType.PasswordManager, + TaxInformation = new TaxInformationDTO + { + Country = "CA", + PostalCode = "12345" + } + }; + + var plan = StaticStore.GetPlan(parameters.PlanType); + + _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); + + var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents + _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(expectedInvoice); + + // Act + var result = await _command.Run(parameters); + + // Assert + await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == StripeConstants.TaxExempt.Reverse + )); + Assert.True(result.IsT0); + } } diff --git a/test/Core.Test/Billing/Tax/Services/AutomaticTaxFactoryTests.cs b/test/Core.Test/Billing/Tax/Services/AutomaticTaxFactoryTests.cs deleted file mode 100644 index d9d2679bca..0000000000 --- a/test/Core.Test/Billing/Tax/Services/AutomaticTaxFactoryTests.cs +++ /dev/null @@ -1,105 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; -using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Models; -using Bit.Core.Billing.Tax.Services.Implementations; -using Bit.Core.Entities; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Xunit; - -namespace Bit.Core.Test.Billing.Tax.Services; - -[SutProviderCustomize] -public class AutomaticTaxFactoryTests -{ - [BitAutoData] - [Theory] - public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsUser(SutProvider sut) - { - var parameters = new AutomaticTaxFactoryParameters(new User(), []); - - var actual = await sut.Sut.CreateAsync(parameters); - - Assert.IsType(actual); - } - - [BitAutoData] - [Theory] - public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsOrganizationWithFamiliesAnnuallyPrice( - SutProvider sut) - { - var familiesPlan = new FamiliesPlan(); - var parameters = new AutomaticTaxFactoryParameters(new Organization(), [familiesPlan.PasswordManager.StripePlanId]); - - sut.GetDependency() - .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) - .Returns(new FamiliesPlan()); - - sut.GetDependency() - .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually2019)) - .Returns(new Families2019Plan()); - - var actual = await sut.Sut.CreateAsync(parameters); - - Assert.IsType(actual); - } - - [Theory] - [BitAutoData] - public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenSubscriberIsOrganizationWithBusinessUsePrice( - EnterpriseAnnually plan, - SutProvider sut) - { - var parameters = new AutomaticTaxFactoryParameters(new Organization(), [plan.PasswordManager.StripePlanId]); - - sut.GetDependency() - .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) - .Returns(new FamiliesPlan()); - - sut.GetDependency() - .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually2019)) - .Returns(new Families2019Plan()); - - var actual = await sut.Sut.CreateAsync(parameters); - - Assert.IsType(actual); - } - - [Theory] - [BitAutoData] - public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenPlanIsMeantForPersonalUse(SutProvider sut) - { - var parameters = new AutomaticTaxFactoryParameters(PlanType.FamiliesAnnually); - sut.GetDependency() - .GetPlanOrThrow(Arg.Is(p => p == parameters.PlanType.Value)) - .Returns(new FamiliesPlan()); - - var actual = await sut.Sut.CreateAsync(parameters); - - Assert.IsType(actual); - } - - [Theory] - [BitAutoData] - public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenPlanIsMeantForBusinessUse(SutProvider sut) - { - var parameters = new AutomaticTaxFactoryParameters(PlanType.EnterpriseAnnually); - sut.GetDependency() - .GetPlanOrThrow(Arg.Is(p => p == parameters.PlanType.Value)) - .Returns(new EnterprisePlan(true)); - - var actual = await sut.Sut.CreateAsync(parameters); - - Assert.IsType(actual); - } - - public record EnterpriseAnnually : EnterprisePlan - { - public EnterpriseAnnually() : base(true) - { - } - } -} diff --git a/test/Core.Test/Billing/Tax/Services/BusinessUseAutomaticTaxStrategyTests.cs b/test/Core.Test/Billing/Tax/Services/BusinessUseAutomaticTaxStrategyTests.cs deleted file mode 100644 index dc10d222f1..0000000000 --- a/test/Core.Test/Billing/Tax/Services/BusinessUseAutomaticTaxStrategyTests.cs +++ /dev/null @@ -1,492 +0,0 @@ -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Tax.Services.Implementations; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Stripe; -using Xunit; - -namespace Bit.Core.Test.Billing.Tax.Services; - -[SutProviderCustomize] -public class BusinessUseAutomaticTaxStrategyTests -{ - [Theory] - [BitAutoData] - public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(false); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.Null(actual); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating( - SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Address = new Address - { - Country = "US", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.Null(actual); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( - SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.False(actual.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers( - SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = false - }, - Customer = new Customer - { - Address = new Address - { - Country = "US", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.True(actual.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( - SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = false - }, - Customer = new Customer - { - Address = new Address - { - Country = "ES", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = new StripeList - { - Data = new List - { - new() - { - Country = "ES", - Type = "eu_vat", - Value = "ESZ8880999Z" - } - } - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.True(actual.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull( - SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Address = new Address - { - Country = "ES", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = null - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - Assert.Throws(() => sutProvider.Sut.GetUpdateOptions(subscription)); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( - SutProvider sutProvider) - { - - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Address = new Address - { - Country = "ES", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = new StripeList - { - Data = new List() - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.False(actual.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData] - public void SetUpdateOptions_SetsNothing_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( - SutProvider sutProvider) - { - var options = new SubscriptionUpdateOptions(); - - var subscription = new Subscription - { - Customer = new Customer - { - Address = new() - { - Country = "US" - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(false); - - sutProvider.Sut.SetUpdateOptions(options, subscription); - - Assert.Null(options.AutomaticTax); - } - - [Theory] - [BitAutoData] - public void SetUpdateOptions_SetsNothing_WhenSubscriptionDoesNotNeedUpdating( - SutProvider sutProvider) - { - var options = new SubscriptionUpdateOptions(); - - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Address = new Address - { - Country = "US", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - sutProvider.Sut.SetUpdateOptions(options, subscription); - - Assert.Null(options.AutomaticTax); - } - - [Theory] - [BitAutoData] - public void SetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( - SutProvider sutProvider) - { - var options = new SubscriptionUpdateOptions(); - - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - sutProvider.Sut.SetUpdateOptions(options, subscription); - - Assert.False(options.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData] - public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers( - SutProvider sutProvider) - { - var options = new SubscriptionUpdateOptions(); - - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = false - }, - Customer = new Customer - { - Address = new Address - { - Country = "US", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - sutProvider.Sut.SetUpdateOptions(options, subscription); - - Assert.True(options.AutomaticTax!.Enabled); - } - - [Theory] - [BitAutoData] - public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( - SutProvider sutProvider) - { - var options = new SubscriptionUpdateOptions(); - - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = false - }, - Customer = new Customer - { - Address = new Address - { - Country = "ES", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = new StripeList - { - Data = new List - { - new() - { - Country = "ES", - Type = "eu_vat", - Value = "ESZ8880999Z" - } - } - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - sutProvider.Sut.SetUpdateOptions(options, subscription); - - Assert.True(options.AutomaticTax!.Enabled); - } - - [Theory] - [BitAutoData] - public void SetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull( - SutProvider sutProvider) - { - var options = new SubscriptionUpdateOptions(); - - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Address = new Address - { - Country = "ES", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = null - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - Assert.Throws(() => sutProvider.Sut.SetUpdateOptions(options, subscription)); - } - - [Theory] - [BitAutoData] - public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( - SutProvider sutProvider) - { - var options = new SubscriptionUpdateOptions(); - - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Address = new Address - { - Country = "ES", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = new StripeList - { - Data = new List() - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - sutProvider.Sut.SetUpdateOptions(options, subscription); - - Assert.False(options.AutomaticTax!.Enabled); - } -} diff --git a/test/Core.Test/Billing/Tax/Services/FakeAutomaticTaxStrategy.cs b/test/Core.Test/Billing/Tax/Services/FakeAutomaticTaxStrategy.cs deleted file mode 100644 index 2f3cbc98ee..0000000000 --- a/test/Core.Test/Billing/Tax/Services/FakeAutomaticTaxStrategy.cs +++ /dev/null @@ -1,35 +0,0 @@ -using Bit.Core.Billing.Tax.Services; -using Stripe; - -namespace Bit.Core.Test.Billing.Tax.Services; - -/// -/// Whether the subscription options will have automatic tax enabled or not. -/// -public class FakeAutomaticTaxStrategy( - bool isAutomaticTaxEnabled) : IAutomaticTaxStrategy -{ - public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) - { - return new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled } - }; - } - - public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) - { - options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; - } - - public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) - { - options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; - } - - public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) - { - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; - - } -} diff --git a/test/Core.Test/Billing/Tax/Services/PersonalUseAutomaticTaxStrategyTests.cs b/test/Core.Test/Billing/Tax/Services/PersonalUseAutomaticTaxStrategyTests.cs deleted file mode 100644 index 30614b94ba..0000000000 --- a/test/Core.Test/Billing/Tax/Services/PersonalUseAutomaticTaxStrategyTests.cs +++ /dev/null @@ -1,217 +0,0 @@ -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Tax.Services.Implementations; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Stripe; -using Xunit; - -namespace Bit.Core.Test.Billing.Tax.Services; - -[SutProviderCustomize] -public class PersonalUseAutomaticTaxStrategyTests -{ - [Theory] - [BitAutoData] - public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(false); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.Null(actual); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating( - SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Address = new Address - { - Country = "US", - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.Null(actual); - } - - [Theory] - [BitAutoData] - public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( - SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = true - }, - Customer = new Customer - { - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.False(actual.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData("CA")] - [BitAutoData("ES")] - [BitAutoData("US")] - public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAllCountries( - string country, SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = false - }, - Customer = new Customer - { - Address = new Address - { - Country = country - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.True(actual.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData("CA")] - [BitAutoData("ES")] - [BitAutoData("US")] - public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( - string country, SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = false - }, - Customer = new Customer - { - Address = new Address - { - Country = country, - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = new StripeList - { - Data = new List - { - new() - { - Country = "ES", - Type = "eu_vat", - Value = "ESZ8880999Z" - } - } - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.True(actual.AutomaticTax.Enabled); - } - - [Theory] - [BitAutoData("CA")] - [BitAutoData("ES")] - [BitAutoData("US")] - public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( - string country, SutProvider sutProvider) - { - var subscription = new Subscription - { - AutomaticTax = new SubscriptionAutomaticTax - { - Enabled = false - }, - Customer = new Customer - { - Address = new Address - { - Country = country - }, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - }, - TaxIds = new StripeList - { - Data = new List() - } - } - }; - - sutProvider.GetDependency() - .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) - .Returns(true); - - var actual = sutProvider.Sut.GetUpdateOptions(subscription); - - Assert.NotNull(actual); - Assert.True(actual.AutomaticTax.Enabled); - } -} diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index 7d8a059d76..609437b8d1 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -1,12 +1,10 @@ -using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Requests; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Services; -using Bit.Core.Test.Billing.Tax.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -23,10 +21,6 @@ public class StripePaymentServiceTests public async Task PreviewInvoiceAsync_ForOrganization_CalculatesSalesTaxCorrectlyForFamiliesWithoutAdditionalStorage( SutProvider sutProvider) { - sutProvider.GetDependency() - .CreateAsync(Arg.Is(p => p.PlanType == PlanType.FamiliesAnnually)) - .Returns(new FakeAutomaticTaxStrategy(true)); - var familiesPlan = new FamiliesPlan(); sutProvider.GetDependency() .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) @@ -74,10 +68,6 @@ public class StripePaymentServiceTests public async Task PreviewInvoiceAsync_ForOrganization_CalculatesSalesTaxCorrectlyForFamiliesWithAdditionalStorage( SutProvider sutProvider) { - sutProvider.GetDependency() - .CreateAsync(Arg.Is(p => p.PlanType == PlanType.FamiliesAnnually)) - .Returns(new FakeAutomaticTaxStrategy(true)); - var familiesPlan = new FamiliesPlan(); sutProvider.GetDependency() .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) @@ -125,10 +115,6 @@ public class StripePaymentServiceTests public async Task PreviewInvoiceAsync_ForOrganization_CalculatesSalesTaxCorrectlyForFamiliesForEnterpriseWithoutAdditionalStorage( SutProvider sutProvider) { - sutProvider.GetDependency() - .CreateAsync(Arg.Is(p => p.PlanType == PlanType.FamiliesAnnually)) - .Returns(new FakeAutomaticTaxStrategy(true)); - var familiesPlan = new FamiliesPlan(); sutProvider.GetDependency() .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) @@ -177,10 +163,6 @@ public class StripePaymentServiceTests public async Task PreviewInvoiceAsync_ForOrganization_CalculatesSalesTaxCorrectlyForFamiliesForEnterpriseWithAdditionalStorage( SutProvider sutProvider) { - sutProvider.GetDependency() - .CreateAsync(Arg.Is(p => p.PlanType == PlanType.FamiliesAnnually)) - .Returns(new FakeAutomaticTaxStrategy(true)); - var familiesPlan = new FamiliesPlan(); sutProvider.GetDependency() .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) @@ -223,4 +205,340 @@ public class StripePaymentServiceTests Assert.Equal(4.08M, actual.TotalAmount); Assert.Equal(4M, actual.TaxableBaseAmount); } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_USBased_PersonalUse_SetsAutomaticTaxEnabled(SutProvider sutProvider) + { + // Arrange + var familiesPlan = new FamiliesPlan(); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(familiesPlan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.FamiliesAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "US", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_USBased_BusinessUse_SetsAutomaticTaxEnabled(SutProvider sutProvider) + { + // Arrange + var plan = new EnterprisePlan(true); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.EnterpriseAnnually)) + .Returns(plan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.EnterpriseAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "US", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_NonUSBased_PersonalUse_SetsAutomaticTaxEnabled(SutProvider sutProvider) + { + // Arrange + var familiesPlan = new FamiliesPlan(); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(familiesPlan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.FamiliesAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "FR", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_NonUSBased_BusinessUse_SetsAutomaticTaxEnabled(SutProvider sutProvider) + { + // Arrange + var plan = new EnterprisePlan(true); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.EnterpriseAnnually)) + .Returns(plan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.EnterpriseAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "FR", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true + )); + } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_USBased_PersonalUse_DoesNotSetTaxExempt(SutProvider sutProvider) + { + // Arrange + var familiesPlan = new FamiliesPlan(); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(familiesPlan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.FamiliesAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "US", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == null + )); + } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_USBased_BusinessUse_DoesNotSetTaxExempt(SutProvider sutProvider) + { + // Arrange + var plan = new EnterprisePlan(true); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.EnterpriseAnnually)) + .Returns(plan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.EnterpriseAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "US", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == null + )); + } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_NonUSBased_PersonalUse_DoesNotSetTaxExempt(SutProvider sutProvider) + { + // Arrange + var familiesPlan = new FamiliesPlan(); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(familiesPlan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.FamiliesAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "FR", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == null + )); + } + + [Theory] + [BitAutoData] + public async Task PreviewInvoiceAsync_NonUSBased_BusinessUse_SetsTaxExemptReverse(SutProvider sutProvider) + { + // Arrange + var plan = new EnterprisePlan(true); + sutProvider.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.EnterpriseAnnually)) + .Returns(plan); + + var parameters = new PreviewOrganizationInvoiceRequestBody + { + PasswordManager = new OrganizationPasswordManagerRequestModel + { + Plan = PlanType.EnterpriseAnnually + }, + TaxInformation = new TaxInformationRequestModel + { + Country = "FR", + PostalCode = "12345" + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .InvoiceCreatePreviewAsync(Arg.Any()) + .Returns(new Invoice + { + TotalExcludingTax = 400, + Tax = 8, + Total = 408 + }); + + // Act + await sutProvider.Sut.PreviewInvoiceAsync(parameters, null, null); + + // Assert + await stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + options.CustomerDetails.TaxExempt == StripeConstants.TaxExempt.Reverse + )); + } } From 3731c7c40c3e7da515328318e2b64983cd96d4f6 Mon Sep 17 00:00:00 2001 From: Graham Walker Date: Wed, 3 Sep 2025 10:39:12 -0500 Subject: [PATCH 13/13] PM-24436 Add logging to backend for Member Access Report (#6159) * pm-24436 inital commit * PM-24436 updating logsto bypass event filter --- src/Api/Dirt/Controllers/ReportsController.cs | 32 ++++++++----------- .../ReportFeatures/MemberAccessReportQuery.cs | 21 +++++++++++- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/Api/Dirt/Controllers/ReportsController.cs b/src/Api/Dirt/Controllers/ReportsController.cs index e7c7e4a9bf..d643d68661 100644 --- a/src/Api/Dirt/Controllers/ReportsController.cs +++ b/src/Api/Dirt/Controllers/ReportsController.cs @@ -1,6 +1,7 @@ using Bit.Api.Dirt.Models; using Bit.Api.Dirt.Models.Response; using Bit.Api.Tools.Models.Response; +using Bit.Core; using Bit.Core.Context; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Reports.Models.Data; @@ -26,6 +27,7 @@ public class ReportsController : Controller private readonly IAddOrganizationReportCommand _addOrganizationReportCommand; private readonly IDropOrganizationReportCommand _dropOrganizationReportCommand; private readonly IGetOrganizationReportQuery _getOrganizationReportQuery; + private readonly ILogger _logger; public ReportsController( ICurrentContext currentContext, @@ -36,7 +38,8 @@ public class ReportsController : Controller IDropPasswordHealthReportApplicationCommand dropPwdHealthReportAppCommand, IGetOrganizationReportQuery getOrganizationReportQuery, IAddOrganizationReportCommand addOrganizationReportCommand, - IDropOrganizationReportCommand dropOrganizationReportCommand + IDropOrganizationReportCommand dropOrganizationReportCommand, + ILogger logger ) { _currentContext = currentContext; @@ -48,6 +51,7 @@ public class ReportsController : Controller _getOrganizationReportQuery = getOrganizationReportQuery; _addOrganizationReportCommand = addOrganizationReportCommand; _dropOrganizationReportCommand = dropOrganizationReportCommand; + _logger = logger; } /// @@ -86,32 +90,24 @@ public class ReportsController : Controller { if (!await _currentContext.AccessReports(orgId)) { + _logger.LogInformation(Constants.BypassFiltersEventId, + "AccessReports Check - UserId: {userId} OrgId: {orgId} DeviceType: {deviceType}", + _currentContext.UserId, orgId, _currentContext.DeviceType); throw new NotFoundException(); } - var accessDetails = await GetMemberAccessDetails(new MemberAccessReportRequest { OrganizationId = orgId }); + _logger.LogInformation(Constants.BypassFiltersEventId, + "MemberAccessReportQuery starts - UserId: {userId} OrgId: {orgId} DeviceType: {deviceType}", + _currentContext.UserId, orgId, _currentContext.DeviceType); + + var accessDetails = await _memberAccessReportQuery + .GetMemberAccessReportsAsync(new MemberAccessReportRequest { OrganizationId = orgId }); var responses = accessDetails.Select(x => new MemberAccessDetailReportResponseModel(x)); return responses; } - /// - /// Contains the organization member info, the cipher ids associated with the member, - /// and details on their collections, groups, and permissions - /// - /// Request parameters - /// - /// List of a user's permissions at a group and collection level as well as the number of ciphers - /// associated with that group/collection - /// - private async Task> GetMemberAccessDetails( - MemberAccessReportRequest request) - { - var accessDetails = await _memberAccessReportQuery.GetMemberAccessReportsAsync(request); - return accessDetails; - } - /// /// Gets the risk insights report details from the risk insights query. Associates a user to their cipher ids /// diff --git a/src/Core/Dirt/Reports/ReportFeatures/MemberAccessReportQuery.cs b/src/Core/Dirt/Reports/ReportFeatures/MemberAccessReportQuery.cs index 33acd73d14..83d074454d 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/MemberAccessReportQuery.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/MemberAccessReportQuery.cs @@ -7,25 +7,40 @@ using Bit.Core.Dirt.Reports.ReportFeatures.OrganizationReportMembers.Interfaces; using Bit.Core.Dirt.Reports.ReportFeatures.Requests; using Bit.Core.Dirt.Reports.Repositories; using Bit.Core.Services; +using Microsoft.Extensions.Logging; namespace Bit.Core.Dirt.Reports.ReportFeatures; public class MemberAccessReportQuery( IOrganizationMemberBaseDetailRepository organizationMemberBaseDetailRepository, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, - IApplicationCacheService applicationCacheService) : IMemberAccessReportQuery + IApplicationCacheService applicationCacheService, + ILogger logger) : IMemberAccessReportQuery { public async Task> GetMemberAccessReportsAsync( MemberAccessReportRequest request) { + logger.LogInformation(Constants.BypassFiltersEventId, "Starting MemberAccessReport generation for OrganizationId: {OrganizationId}", request.OrganizationId); + var baseDetails = await organizationMemberBaseDetailRepository.GetOrganizationMemberBaseDetailsByOrganizationId( request.OrganizationId); + logger.LogInformation(Constants.BypassFiltersEventId, "Retrieved {BaseDetailsCount} base details for OrganizationId: {OrganizationId}", + baseDetails.Count(), request.OrganizationId); + var orgUsers = baseDetails.Select(x => x.UserGuid.GetValueOrDefault()).Distinct(); + var orgUsersCount = orgUsers.Count(); + logger.LogInformation(Constants.BypassFiltersEventId, "Found {UniqueUsersCount} unique users for OrganizationId: {OrganizationId}", + orgUsersCount, request.OrganizationId); + var orgUsersTwoFactorEnabled = await twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(orgUsers); + logger.LogInformation(Constants.BypassFiltersEventId, "Retrieved two-factor status for {UsersCount} users for OrganizationId: {OrganizationId}", + orgUsersTwoFactorEnabled.Count(), request.OrganizationId); var orgAbility = await applicationCacheService.GetOrganizationAbilityAsync(request.OrganizationId); + logger.LogInformation(Constants.BypassFiltersEventId, "Retrieved organization ability (UseResetPassword: {UseResetPassword}) for OrganizationId: {OrganizationId}", + orgAbility?.UseResetPassword, request.OrganizationId); var accessDetails = baseDetails .GroupBy(b => new @@ -62,6 +77,10 @@ public class MemberAccessReportQuery( CipherIds = g.Select(c => c.CipherId) }); + var accessDetailsCount = accessDetails.Count(); + logger.LogInformation(Constants.BypassFiltersEventId, "Completed MemberAccessReport generation for OrganizationId: {OrganizationId}. Generated {AccessDetailsCount} access detail records", + request.OrganizationId, accessDetailsCount); + return accessDetails; } }