From d330fd2082fc5b38bf8f8ee8551ac45ac79a2c92 Mon Sep 17 00:00:00 2001
From: Kyle Denney <4227399+kdenney@users.noreply.github.com>
Date: Wed, 21 Jan 2026 16:43:06 -0600
Subject: [PATCH] updated proration preview and upgrade to be consistent
also using the correct proration behavior and making the upgrade flow start a trial
---
.../Controllers/PreviewInvoiceController.cs | 4 +-
.../UpgradePremiumToOrganizationRequest.cs | 34 +++---
.../PreviewPremiumUpgradeProrationRequest.cs | 25 +++-
.../PreviewPremiumUpgradeProrationCommand.cs | 22 +---
.../UpgradePremiumToOrganizationCommand.cs | 9 +-
...viewPremiumUpgradeProrationRequestTests.cs | 56 +++++++++
...pgradePremiumToOrganizationRequestTests.cs | 60 ++++++++++
...viewPremiumUpgradeProrationCommandTests.cs | 113 ++----------------
...pgradePremiumToOrganizationCommandTests.cs | 83 ++++++++++++-
9 files changed, 259 insertions(+), 147 deletions(-)
create mode 100644 test/Api.Test/Billing/Models/Requests/PreviewPremiumUpgradeProrationRequestTests.cs
create mode 100644 test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs
diff --git a/src/Api/Billing/Controllers/PreviewInvoiceController.cs b/src/Api/Billing/Controllers/PreviewInvoiceController.cs
index d55e7299d6..b2e2dcbad9 100644
--- a/src/Api/Billing/Controllers/PreviewInvoiceController.cs
+++ b/src/Api/Billing/Controllers/PreviewInvoiceController.cs
@@ -63,11 +63,11 @@ public class PreviewInvoiceController(
[BindNever] User user,
[FromBody] PreviewPremiumUpgradeProrationRequest request)
{
- var (tierType, billingAddress) = request.ToDomain();
+ var (planType, billingAddress) = request.ToDomain();
var result = await previewPremiumUpgradeProrationCommand.Run(
user,
- tierType,
+ planType,
billingAddress);
return Handle(result.Map(pair => new { pair.Tax, pair.Total, pair.Credit }));
diff --git a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs
index 14375efc78..7e7bf3ef3a 100644
--- a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs
+++ b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs
@@ -1,5 +1,6 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
+using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Enums;
namespace Bit.Api.Billing.Models.Requests.Premium;
@@ -14,24 +15,29 @@ public class UpgradePremiumToOrganizationRequest
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
- public ProductTierType Tier { get; set; }
+ public required ProductTierType TargetProductTierType { get; set; }
[Required]
- [JsonConverter(typeof(JsonStringEnumConverter))]
- public PlanCadenceType Cadence { get; set; }
+ public required MinimalBillingAddressRequest BillingAddress { get; set; }
- private PlanType PlanType =>
- Tier switch
+ private PlanType PlanType
+ {
+ get
{
- ProductTierType.Families => PlanType.FamiliesAnnually,
- ProductTierType.Teams => Cadence == PlanCadenceType.Monthly
- ? PlanType.TeamsMonthly
- : PlanType.TeamsAnnually,
- ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly
- ? PlanType.EnterpriseMonthly
- : PlanType.EnterpriseAnnually,
- _ => throw new InvalidOperationException("Cannot upgrade to an Organization subscription that isn't Families, Teams or Enterprise.")
- };
+ if (TargetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise))
+ {
+ throw new InvalidOperationException($"Cannot upgrade Premium subscription to {TargetProductTierType} plan.");
+ }
+
+ return TargetProductTierType switch
+ {
+ ProductTierType.Families => PlanType.FamiliesAnnually,
+ ProductTierType.Teams => PlanType.TeamsAnnually,
+ ProductTierType.Enterprise => PlanType.EnterpriseAnnually,
+ _ => throw new InvalidOperationException($"Unexpected ProductTierType: {TargetProductTierType}")
+ };
+ }
+ }
public (string OrganizationName, string Key, PlanType PlanType) ToDomain() => (OrganizationName, Key, PlanType);
}
diff --git a/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumUpgradeProrationRequest.cs b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumUpgradeProrationRequest.cs
index e3c109a155..68d7a8d002 100644
--- a/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumUpgradeProrationRequest.cs
+++ b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumUpgradeProrationRequest.cs
@@ -1,4 +1,5 @@
using System.ComponentModel.DataAnnotations;
+using System.Text.Json.Serialization;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Payment.Models;
@@ -8,11 +9,31 @@ namespace Bit.Api.Billing.Models.Requests.PreviewInvoice;
public record PreviewPremiumUpgradeProrationRequest
{
[Required]
+ [JsonConverter(typeof(JsonStringEnumConverter))]
public required ProductTierType TargetProductTierType { get; set; }
[Required]
public required MinimalBillingAddressRequest BillingAddress { get; set; }
- public (ProductTierType, BillingAddress) ToDomain() =>
- (TargetProductTierType, BillingAddress.ToDomain());
+ private PlanType PlanType
+ {
+ get
+ {
+ if (TargetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise))
+ {
+ throw new InvalidOperationException($"Cannot upgrade Premium subscription to {TargetProductTierType} plan.");
+ }
+
+ return TargetProductTierType switch
+ {
+ ProductTierType.Families => PlanType.FamiliesAnnually,
+ ProductTierType.Teams => PlanType.TeamsAnnually,
+ ProductTierType.Enterprise => PlanType.EnterpriseAnnually,
+ _ => throw new InvalidOperationException($"Unexpected ProductTierType: {TargetProductTierType}")
+ };
+ }
+ }
+
+ public (PlanType, BillingAddress) ToDomain() =>
+ (PlanType, BillingAddress.ToDomain());
}
diff --git a/src/Core/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommand.cs b/src/Core/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommand.cs
index db5b907f8c..4c864677df 100644
--- a/src/Core/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommand.cs
+++ b/src/Core/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommand.cs
@@ -20,12 +20,12 @@ public interface IPreviewPremiumUpgradeProrationCommand
/// Calculates the tax, total cost, and proration credit for upgrading a Premium subscription to an Organization plan.
///
/// The user with an active Premium subscription.
- /// The target organization tier (Families, Teams, or Enterprise).
+ /// The target organization plan type.
/// The billing address for tax calculation.
/// A tuple containing the tax amount, total cost, and proration credit from unused Premium time.
Task> Run(
User user,
- ProductTierType targetProductTierType,
+ PlanType targetPlanType,
BillingAddress billingAddress);
}
@@ -38,7 +38,7 @@ public class PreviewPremiumUpgradeProrationCommand(
{
public Task> Run(
User user,
- ProductTierType targetProductTierType,
+ PlanType targetPlanType,
BillingAddress billingAddress) => HandleAsync<(decimal, decimal, decimal)>(async () =>
{
if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" })
@@ -46,20 +46,6 @@ public class PreviewPremiumUpgradeProrationCommand(
return new BadRequest("User does not have an active Premium subscription.");
}
- if (targetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise))
- {
- return new BadRequest($"Cannot upgrade Premium subscription to {targetProductTierType} plan.");
- }
-
- // Convert ProductTierType to PlanType (for premium upgrade, the only choice is annual plans so we can assume that cadence)
- var targetPlanType = targetProductTierType switch
- {
- ProductTierType.Families => PlanType.FamiliesAnnually,
- ProductTierType.Teams => PlanType.TeamsAnnually,
- ProductTierType.Enterprise => PlanType.EnterpriseAnnually,
- _ => throw new InvalidOperationException($"Unexpected ProductTierType: {targetProductTierType}")
- };
-
// Hardcode seats to 1 for upgrade flow
const int seats = 1;
@@ -128,7 +114,7 @@ public class PreviewPremiumUpgradeProrationCommand(
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items = subscriptionItems,
- ProrationBehavior = StripeConstants.ProrationBehavior.AlwaysInvoice
+ ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations
}
};
diff --git a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs
index 81bc5c9e2c..50f90808e2 100644
--- a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs
+++ b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs
@@ -74,7 +74,7 @@ public class UpgradePremiumToOrganizationCommand(
if (passwordManagerItem == null)
{
- return new BadRequest("Premium subscription item not found.");
+ return new BadRequest("Premium subscription password manager item not found.");
}
var usersPremiumPlan = premiumPlans.First(p => p.Seat.StripePriceId == passwordManagerItem.Price.Id);
@@ -133,7 +133,7 @@ public class UpgradePremiumToOrganizationCommand(
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
Items = subscriptionItemOptions,
- ProrationBehavior = StripeConstants.ProrationBehavior.None,
+ ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
Metadata = new Dictionary
{
[StripeConstants.MetadataKeys.OrganizationId] = organizationId.ToString(),
@@ -145,6 +145,11 @@ public class UpgradePremiumToOrganizationCommand(
}
};
+ if (targetPlan.TrialPeriodDays.HasValue)
+ {
+ subscriptionUpdateOptions.TrialEnd = DateTime.UtcNow.AddDays((double)targetPlan.TrialPeriodDays);
+ }
+
// Create the Organization entity
var organization = new Organization
{
diff --git a/test/Api.Test/Billing/Models/Requests/PreviewPremiumUpgradeProrationRequestTests.cs b/test/Api.Test/Billing/Models/Requests/PreviewPremiumUpgradeProrationRequestTests.cs
new file mode 100644
index 0000000000..5ed4182a5d
--- /dev/null
+++ b/test/Api.Test/Billing/Models/Requests/PreviewPremiumUpgradeProrationRequestTests.cs
@@ -0,0 +1,56 @@
+using Bit.Api.Billing.Models.Requests.Payment;
+using Bit.Api.Billing.Models.Requests.PreviewInvoice;
+using Bit.Core.Billing.Enums;
+using Xunit;
+
+namespace Bit.Api.Test.Billing.Models.Requests;
+
+public class PreviewPremiumUpgradeProrationRequestTests
+{
+ [Theory]
+ [InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)]
+ [InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)]
+ [InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)]
+ public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, PlanType expectedPlanType)
+ {
+ // Arrange
+ var sut = new PreviewPremiumUpgradeProrationRequest
+ {
+ TargetProductTierType = tierType,
+ BillingAddress = new MinimalBillingAddressRequest
+ {
+ Country = "US",
+ PostalCode = "12345"
+ }
+ };
+
+ // Act
+ var (planType, billingAddress) = sut.ToDomain();
+
+ // Assert
+ Assert.Equal(expectedPlanType, planType);
+ Assert.Equal("US", billingAddress.Country);
+ Assert.Equal("12345", billingAddress.PostalCode);
+ }
+
+ [Theory]
+ [InlineData(ProductTierType.Free)]
+ [InlineData(ProductTierType.TeamsStarter)]
+ public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTierType tierType)
+ {
+ // Arrange
+ var sut = new PreviewPremiumUpgradeProrationRequest
+ {
+ TargetProductTierType = tierType,
+ BillingAddress = new MinimalBillingAddressRequest
+ {
+ Country = "US",
+ PostalCode = "12345"
+ }
+ };
+
+ // Act & Assert
+ var exception = Assert.Throws(() => sut.ToDomain());
+ Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message);
+ }
+}
diff --git a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs
new file mode 100644
index 0000000000..6a8e328bf3
--- /dev/null
+++ b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs
@@ -0,0 +1,60 @@
+using Bit.Api.Billing.Models.Requests.Payment;
+using Bit.Api.Billing.Models.Requests.Premium;
+using Bit.Core.Billing.Enums;
+using Xunit;
+
+namespace Bit.Api.Test.Billing.Models.Requests;
+
+public class UpgradePremiumToOrganizationRequestTests
+{
+ [Theory]
+ [InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)]
+ [InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)]
+ [InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)]
+ public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, PlanType expectedPlanType)
+ {
+ // Arrange
+ var sut = new UpgradePremiumToOrganizationRequest
+ {
+ OrganizationName = "Test Organization",
+ Key = "encrypted-key",
+ TargetProductTierType = tierType,
+ BillingAddress = new MinimalBillingAddressRequest
+ {
+ Country = "US",
+ PostalCode = "12345"
+ }
+ };
+
+ // Act
+ var (organizationName, key, planType) = sut.ToDomain();
+
+ // Assert
+ Assert.Equal("Test Organization", organizationName);
+ Assert.Equal("encrypted-key", key);
+ Assert.Equal(expectedPlanType, planType);
+ }
+
+ [Theory]
+ [InlineData(ProductTierType.Free)]
+ [InlineData(ProductTierType.TeamsStarter)]
+ public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTierType tierType)
+ {
+ // Arrange
+ var sut = new UpgradePremiumToOrganizationRequest
+ {
+ OrganizationName = "Test Organization",
+ Key = "encrypted-key",
+ TargetProductTierType = tierType,
+ BillingAddress = new MinimalBillingAddressRequest
+ {
+ Country = "US",
+ PostalCode = "12345"
+ }
+ };
+
+ // Act & Assert
+ var exception = Assert.Throws(() => sut.ToDomain());
+ Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message);
+ }
+}
diff --git a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommandTests.cs
index a681797a82..475c458361 100644
--- a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommandTests.cs
+++ b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommandTests.cs
@@ -36,7 +36,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
user.Premium = false;
// Act
- var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
+ var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
@@ -52,7 +52,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
user.GatewaySubscriptionId = null;
// Act
- var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
+ var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
@@ -136,7 +136,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
- var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
+ var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
@@ -212,7 +212,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
- var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
+ var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
@@ -279,7 +279,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
- await _command.Run(user, ProductTierType.Teams, billingAddress);
+ await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify that the subscription item quantity is always 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -289,82 +289,6 @@ public class PreviewPremiumUpgradeProrationCommandTests
item.Quantity == 1)));
}
- [Theory]
- [InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)]
- [InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)]
- [InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)]
- public async Task Run_ProductTierTypeConversion_MapsToCorrectPlanType(
- ProductTierType productTierType,
- PlanType expectedPlanType)
- {
- // Arrange
- var user = new User
- {
- Premium = true,
- GatewaySubscriptionId = "sub_123",
- GatewayCustomerId = "cus_123"
- };
- var billingAddress = new BillingAddress
- {
- Country = "US",
- PostalCode = "12345"
- };
-
- var premiumPlan = new PremiumPlan
- {
- Name = "Premium",
- Available = true,
- LegacyYear = null,
- Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
- {
- StripePriceId = "premium-annually",
- Price = 10m,
- Provided = 1
- },
- Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
- {
- StripePriceId = "storage-gb-annually",
- Price = 4m,
- Provided = 1
- }
- };
- var premiumPlans = new List { premiumPlan };
-
- var currentSubscription = new Subscription
- {
- Id = "sub_123",
- Customer = new Customer { Id = "cus_123", Discount = null },
- Items = new StripeList
- {
- Data = new List
- {
- new() { Id = "si_premium", Price = new Price { Id = "premium-annually" } }
- }
- }
- };
-
- var targetPlan = new TeamsPlan(isAnnual: true);
-
- var invoice = new Invoice
- {
- Total = 5000,
- TotalTaxes = new List { new() { Amount = 500 } }
- };
-
- _pricingClient.ListPremiumPlans().Returns(premiumPlans);
- _pricingClient.GetPlanOrThrow(expectedPlanType).Returns(targetPlan);
- _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any())
- .Returns(currentSubscription);
- _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any())
- .Returns(invoice);
-
- // Act
- await _command.Run(user, productTierType, billingAddress);
-
- // Assert - Verify that the correct PlanType was used
- await _pricingClient.Received(1).GetPlanOrThrow(expectedPlanType);
- }
-
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_DeletesPremiumSubscriptionItems(User user, BillingAddress billingAddress)
{
@@ -423,7 +347,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
- await _command.Run(user, ProductTierType.Teams, billingAddress);
+ await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify both password manager and storage items are marked as deleted
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -491,7 +415,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
- await _command.Run(user, ProductTierType.Families, billingAddress);
+ await _command.Run(user, PlanType.FamiliesAnnually, billingAddress);
// Assert - Verify non-seat-based plan uses StripePlanId with quantity 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -560,7 +484,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
- await _command.Run(user, ProductTierType.Teams, billingAddress);
+ await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify all invoice preview options are correct
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -570,24 +494,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
options.Subscription == "sub_123" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
- options.SubscriptionDetails.ProrationBehavior == "always_invoice"));
- }
-
- [Theory, BitAutoData]
- public async Task Run_TeamsStarterTierType_ReturnsBadRequest(User user, BillingAddress billingAddress)
- {
- // Arrange
- user.Premium = true;
- user.GatewaySubscriptionId = "sub_123";
- user.GatewayCustomerId = "cus_123";
-
- // Act
- var result = await _command.Run(user, ProductTierType.TeamsStarter, billingAddress);
-
- // Assert
- Assert.True(result.IsT1);
- var badRequest = result.AsT1;
- Assert.Equal("Cannot upgrade Premium subscription to TeamsStarter plan.", badRequest.Response);
+ options.SubscriptionDetails.ProrationBehavior == "create_prorations"));
}
[Theory, BitAutoData]
@@ -648,7 +555,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
- await _command.Run(user, ProductTierType.Teams, billingAddress);
+ await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify seat-based plan uses StripeSeatPlanId with quantity 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
diff --git a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs
index e686d04009..4ead5c12da 100644
--- a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs
+++ b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs
@@ -28,7 +28,8 @@ public class UpgradePremiumToOrganizationCommandTests
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
- string? stripeStoragePlanId = null)
+ string? stripeStoragePlanId = null,
+ int? trialPeriodDays = null)
{
Type = planType;
ProductTier = ProductTierType.Teams;
@@ -37,7 +38,7 @@ public class UpgradePremiumToOrganizationCommandTests
NameLocalizationKey = "";
DescriptionLocalizationKey = "";
CanBeUsedByBusiness = true;
- TrialPeriodDays = null;
+ TrialPeriodDays = trialPeriodDays;
HasSelfHost = false;
HasPolicies = false;
HasGroups = false;
@@ -86,10 +87,9 @@ public class UpgradePremiumToOrganizationCommandTests
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
- string? stripeStoragePlanId = null)
- {
- return new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId);
- }
+ string? stripeStoragePlanId = null,
+ int? trialPeriodDays = null) =>
+ new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId, trialPeriodDays);
private static PremiumPlan CreateTestPremiumPlan(
string seatPriceId = "premium-annually",
@@ -643,4 +643,75 @@ public class UpgradePremiumToOrganizationCommandTests
var badRequest = result.AsT1;
Assert.Equal("Premium subscription item not found.", badRequest.Response);
}
+
+ [Theory, BitAutoData]
+ public async Task Run_PlanWithTrialPeriod_SetsTrialEnd(User user)
+ {
+ // Arrange
+ user.Premium = true;
+ user.GatewaySubscriptionId = "sub_123";
+ user.GatewayCustomerId = "cus_123";
+
+ var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
+ var mockSubscription = new Subscription
+ {
+ Id = "sub_123",
+ Items = new StripeList
+ {
+ Data = new List
+ {
+ new SubscriptionItem
+ {
+ Id = "si_premium",
+ Price = new Price { Id = "premium-annually" },
+ CurrentPeriodEnd = currentPeriodEnd
+ }
+ }
+ },
+ Metadata = new Dictionary()
+ };
+
+ var mockPremiumPlans = CreateTestPremiumPlansList();
+
+ // Create a plan with a trial period
+ var mockPlan = CreateTestPlan(
+ PlanType.TeamsAnnually,
+ stripeSeatPlanId: "teams-seat-annually",
+ trialPeriodDays: 7
+ );
+
+ // Capture the subscription update options to verify TrialEnd is set
+ SubscriptionUpdateOptions capturedOptions = null;
+
+ _stripeAdapter.GetSubscriptionAsync("sub_123")
+ .Returns(mockSubscription);
+ _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
+ _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
+ _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Do(opts => capturedOptions = opts))
+ .Returns(Task.FromResult(mockSubscription));
+ _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg()));
+ _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask);
+ _userService.SaveUserAsync(user).Returns(Task.CompletedTask);
+
+ // Act
+ var testStartTime = DateTime.UtcNow;
+ var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
+ var testEndTime = DateTime.UtcNow;
+
+ // Assert
+ Assert.True(result.IsT0);
+
+ await _stripeAdapter.Received(1).UpdateSubscriptionAsync("sub_123", Arg.Any());
+
+ Assert.NotNull(capturedOptions);
+ Assert.NotNull(capturedOptions.TrialEnd);
+
+ // TrialEnd is AnyOf - verify it's a DateTime
+ var trialEndDateTime = capturedOptions.TrialEnd.Value as DateTime?;
+ Assert.NotNull(trialEndDateTime);
+ Assert.True(trialEndDateTime.Value >= testStartTime.AddDays(7));
+ Assert.True(trialEndDateTime.Value <= testEndTime.AddDays(7));
+ }
}