diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs
index 6c56d6db3a..241e595333 100644
--- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs
+++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs
@@ -132,8 +132,8 @@ public class AccountBillingVNextController(
[BindNever] User user,
[FromBody] UpgradePremiumToOrganizationRequest request)
{
- var (organizationName, key, planType) = request.ToDomain();
- var result = await upgradePremiumToOrganizationCommand.Run(user, organizationName, key, planType);
+ var (organizationName, key, planType, billingAddress) = request.ToDomain();
+ var result = await upgradePremiumToOrganizationCommand.Run(user, organizationName, key, planType, billingAddress);
return Handle(result);
}
}
diff --git a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs
index 7e7bf3ef3a..00b1da4bba 100644
--- a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs
+++ b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs
@@ -39,5 +39,6 @@ public class UpgradePremiumToOrganizationRequest
}
}
- public (string OrganizationName, string Key, PlanType PlanType) ToDomain() => (OrganizationName, Key, PlanType);
+ public (string OrganizationName, string Key, PlanType PlanType, Core.Billing.Payment.Models.BillingAddress BillingAddress) ToDomain() =>
+ (OrganizationName, Key, PlanType, BillingAddress.ToDomain());
}
diff --git a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs
index 50f90808e2..d3e2eb899f 100644
--- a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs
+++ b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs
@@ -28,12 +28,14 @@ public interface IUpgradePremiumToOrganizationCommand
/// The name for the new organization.
/// The encrypted organization key for the owner.
/// The target organization plan type to upgrade to.
+ /// The billing address for tax calculation.
/// A billing command result indicating success or failure with appropriate error details.
Task> Run(
User user,
string organizationName,
string key,
- PlanType targetPlanType);
+ PlanType targetPlanType,
+ Payment.Models.BillingAddress billingAddress);
}
public class UpgradePremiumToOrganizationCommand(
@@ -51,7 +53,8 @@ public class UpgradePremiumToOrganizationCommand(
User user,
string organizationName,
string key,
- PlanType targetPlanType) => HandleAsync(async () =>
+ PlanType targetPlanType,
+ Payment.Models.BillingAddress billingAddress) => HandleAsync(async () =>
{
// Validate that the user has an active Premium subscription
if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" })
@@ -134,6 +137,7 @@ public class UpgradePremiumToOrganizationCommand(
{
Items = subscriptionItemOptions,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
+ AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
Metadata = new Dictionary
{
[StripeConstants.MetadataKeys.OrganizationId] = organizationId.ToString(),
@@ -187,6 +191,16 @@ public class UpgradePremiumToOrganizationCommand(
GatewaySubscriptionId = currentSubscription.Id
};
+ // Update customer billing address for tax calculation
+ await stripeAdapter.UpdateCustomerAsync(user.GatewayCustomerId, new CustomerUpdateOptions
+ {
+ Address = new AddressOptions
+ {
+ Country = billingAddress.Country,
+ PostalCode = billingAddress.PostalCode
+ }
+ });
+
// Update the subscription in Stripe
await stripeAdapter.UpdateSubscriptionAsync(currentSubscription.Id, subscriptionUpdateOptions);
diff --git a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs
index 6a8e328bf3..2d3bdb7b14 100644
--- a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs
+++ b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs
@@ -27,12 +27,14 @@ public class UpgradePremiumToOrganizationRequestTests
};
// Act
- var (organizationName, key, planType) = sut.ToDomain();
+ var (organizationName, key, planType, billingAddress) = sut.ToDomain();
// Assert
Assert.Equal("Test Organization", organizationName);
Assert.Equal("encrypted-key", key);
Assert.Equal(expectedPlanType, planType);
+ Assert.Equal("US", billingAddress.Country);
+ Assert.Equal("12345", billingAddress.PostalCode);
}
[Theory]
diff --git a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs
index 4ead5c12da..702e449746 100644
--- a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs
+++ b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs
@@ -151,6 +151,9 @@ public class UpgradePremiumToOrganizationCommandTests
_applicationCacheService);
}
+ private static Core.Billing.Payment.Models.BillingAddress CreateTestBillingAddress() =>
+ new() { Country = "US", PostalCode = "12345" };
+
[Theory, BitAutoData]
public async Task Run_UserNotPremium_ReturnsBadRequest(User user)
{
@@ -158,7 +161,7 @@ public class UpgradePremiumToOrganizationCommandTests
user.Premium = false;
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
@@ -174,7 +177,7 @@ public class UpgradePremiumToOrganizationCommandTests
user.GatewaySubscriptionId = null;
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
@@ -190,7 +193,7 @@ public class UpgradePremiumToOrganizationCommandTests
user.GatewaySubscriptionId = "";
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
@@ -245,7 +248,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -320,7 +323,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
- var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually);
+ var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -383,7 +386,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -453,7 +456,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -520,7 +523,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -589,7 +592,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -636,12 +639,12 @@ public class UpgradePremiumToOrganizationCommandTests
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
// Act
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
- Assert.Equal("Premium subscription item not found.", badRequest.Response);
+ Assert.Equal("Premium subscription password manager item not found.", badRequest.Response);
}
[Theory, BitAutoData]
@@ -697,7 +700,7 @@ public class UpgradePremiumToOrganizationCommandTests
// Act
var testStartTime = DateTime.UtcNow;
- var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
var testEndTime = DateTime.UtcNow;
// Assert
@@ -714,4 +717,110 @@ public class UpgradePremiumToOrganizationCommandTests
Assert.True(trialEndDateTime.Value >= testStartTime.AddDays(7));
Assert.True(trialEndDateTime.Value <= testEndTime.AddDays(7));
}
+
+ [Theory, BitAutoData]
+ public async Task Run_UpdatesCustomerBillingAddress(User user)
+ {
+ // Arrange
+ user.Premium = true;
+ user.GatewaySubscriptionId = "sub_123";
+ user.GatewayCustomerId = "cus_123";
+
+ var mockSubscription = new Subscription
+ {
+ Id = "sub_123",
+ Items = new StripeList
+ {
+ Data = new List
+ {
+ new SubscriptionItem
+ {
+ Id = "si_premium",
+ Price = new Price { Id = "premium-annually" }
+ }
+ }
+ },
+ Metadata = new Dictionary()
+ };
+
+ var mockPremiumPlans = CreateTestPremiumPlansList();
+ var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
+
+ _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
+ _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
+ _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
+ _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription);
+ _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer()));
+ _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask);
+ _userService.SaveUserAsync(user).Returns(Task.CompletedTask);
+
+ var billingAddress = new Core.Billing.Payment.Models.BillingAddress { Country = "US", PostalCode = "12345" };
+
+ // Act
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, billingAddress);
+
+ // Assert
+ Assert.True(result.IsT0);
+
+ await _stripeAdapter.Received(1).UpdateCustomerAsync(
+ "cus_123",
+ Arg.Is(opts =>
+ opts.Address.Country == "US" &&
+ opts.Address.PostalCode == "12345"));
+ }
+
+ [Theory, BitAutoData]
+ public async Task Run_EnablesAutomaticTaxOnSubscription(User user)
+ {
+ // Arrange
+ user.Premium = true;
+ user.GatewaySubscriptionId = "sub_123";
+ user.GatewayCustomerId = "cus_123";
+
+ var mockSubscription = new Subscription
+ {
+ Id = "sub_123",
+ Items = new StripeList
+ {
+ Data = new List
+ {
+ new SubscriptionItem
+ {
+ Id = "si_premium",
+ Price = new Price { Id = "premium-annually" }
+ }
+ }
+ },
+ Metadata = new Dictionary()
+ };
+
+ var mockPremiumPlans = CreateTestPremiumPlansList();
+ var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
+
+ _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
+ _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
+ _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
+ _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription);
+ _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer()));
+ _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask);
+ _userService.SaveUserAsync(user).Returns(Task.CompletedTask);
+
+ // Act
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
+
+ // Assert
+ Assert.True(result.IsT0);
+
+ await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
+ "sub_123",
+ Arg.Is(opts =>
+ opts.AutomaticTax != null &&
+ opts.AutomaticTax.Enabled == true));
+ }
}