diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index 6c56d6db3a..241e595333 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -132,8 +132,8 @@ public class AccountBillingVNextController( [BindNever] User user, [FromBody] UpgradePremiumToOrganizationRequest request) { - var (organizationName, key, planType) = request.ToDomain(); - var result = await upgradePremiumToOrganizationCommand.Run(user, organizationName, key, planType); + var (organizationName, key, planType, billingAddress) = request.ToDomain(); + var result = await upgradePremiumToOrganizationCommand.Run(user, organizationName, key, planType, billingAddress); return Handle(result); } } diff --git a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs index 7e7bf3ef3a..00b1da4bba 100644 --- a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs +++ b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs @@ -39,5 +39,6 @@ public class UpgradePremiumToOrganizationRequest } } - public (string OrganizationName, string Key, PlanType PlanType) ToDomain() => (OrganizationName, Key, PlanType); + public (string OrganizationName, string Key, PlanType PlanType, Core.Billing.Payment.Models.BillingAddress BillingAddress) ToDomain() => + (OrganizationName, Key, PlanType, BillingAddress.ToDomain()); } diff --git a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs index 50f90808e2..d3e2eb899f 100644 --- a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs +++ b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs @@ -28,12 +28,14 @@ public interface IUpgradePremiumToOrganizationCommand /// The name for the new organization. /// The encrypted organization key for the owner. /// The target organization plan type to upgrade to. + /// The billing address for tax calculation. /// A billing command result indicating success or failure with appropriate error details. Task> Run( User user, string organizationName, string key, - PlanType targetPlanType); + PlanType targetPlanType, + Payment.Models.BillingAddress billingAddress); } public class UpgradePremiumToOrganizationCommand( @@ -51,7 +53,8 @@ public class UpgradePremiumToOrganizationCommand( User user, string organizationName, string key, - PlanType targetPlanType) => HandleAsync(async () => + PlanType targetPlanType, + Payment.Models.BillingAddress billingAddress) => HandleAsync(async () => { // Validate that the user has an active Premium subscription if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" }) @@ -134,6 +137,7 @@ public class UpgradePremiumToOrganizationCommand( { Items = subscriptionItemOptions, ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, Metadata = new Dictionary { [StripeConstants.MetadataKeys.OrganizationId] = organizationId.ToString(), @@ -187,6 +191,16 @@ public class UpgradePremiumToOrganizationCommand( GatewaySubscriptionId = currentSubscription.Id }; + // Update customer billing address for tax calculation + await stripeAdapter.UpdateCustomerAsync(user.GatewayCustomerId, new CustomerUpdateOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode + } + }); + // Update the subscription in Stripe await stripeAdapter.UpdateSubscriptionAsync(currentSubscription.Id, subscriptionUpdateOptions); diff --git a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs index 6a8e328bf3..2d3bdb7b14 100644 --- a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs +++ b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs @@ -27,12 +27,14 @@ public class UpgradePremiumToOrganizationRequestTests }; // Act - var (organizationName, key, planType) = sut.ToDomain(); + var (organizationName, key, planType, billingAddress) = sut.ToDomain(); // Assert Assert.Equal("Test Organization", organizationName); Assert.Equal("encrypted-key", key); Assert.Equal(expectedPlanType, planType); + Assert.Equal("US", billingAddress.Country); + Assert.Equal("12345", billingAddress.PostalCode); } [Theory] diff --git a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs index 4ead5c12da..702e449746 100644 --- a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs @@ -151,6 +151,9 @@ public class UpgradePremiumToOrganizationCommandTests _applicationCacheService); } + private static Core.Billing.Payment.Models.BillingAddress CreateTestBillingAddress() => + new() { Country = "US", PostalCode = "12345" }; + [Theory, BitAutoData] public async Task Run_UserNotPremium_ReturnsBadRequest(User user) { @@ -158,7 +161,7 @@ public class UpgradePremiumToOrganizationCommandTests user.Premium = false; // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); @@ -174,7 +177,7 @@ public class UpgradePremiumToOrganizationCommandTests user.GatewaySubscriptionId = null; // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); @@ -190,7 +193,7 @@ public class UpgradePremiumToOrganizationCommandTests user.GatewaySubscriptionId = ""; // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); @@ -245,7 +248,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -320,7 +323,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually); + var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -383,7 +386,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -453,7 +456,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -520,7 +523,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -589,7 +592,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -636,12 +639,12 @@ public class UpgradePremiumToOrganizationCommandTests _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); var badRequest = result.AsT1; - Assert.Equal("Premium subscription item not found.", badRequest.Response); + Assert.Equal("Premium subscription password manager item not found.", badRequest.Response); } [Theory, BitAutoData] @@ -697,7 +700,7 @@ public class UpgradePremiumToOrganizationCommandTests // Act var testStartTime = DateTime.UtcNow; - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); var testEndTime = DateTime.UtcNow; // Assert @@ -714,4 +717,110 @@ public class UpgradePremiumToOrganizationCommandTests Assert.True(trialEndDateTime.Value >= testStartTime.AddDays(7)); Assert.True(trialEndDateTime.Value <= testEndTime.AddDays(7)); } + + [Theory, BitAutoData] + public async Task Run_UpdatesCustomerBillingAddress(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + var billingAddress = new Core.Billing.Payment.Models.BillingAddress { Country = "US", PostalCode = "12345" }; + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).UpdateCustomerAsync( + "cus_123", + Arg.Is(opts => + opts.Address.Country == "US" && + opts.Address.PostalCode == "12345")); + } + + [Theory, BitAutoData] + public async Task Run_EnablesAutomaticTaxOnSubscription(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); + + // Assert + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.AutomaticTax != null && + opts.AutomaticTax.Enabled == true)); + } }