1
0
mirror of https://github.com/bitwarden/server synced 2026-02-28 10:23:24 +00:00

updated proration preview and upgrade to be consistent

also using the correct proration behavior and making the upgrade flow start a trial
This commit is contained in:
Kyle Denney
2026-01-21 16:43:06 -06:00
parent db456eb4b5
commit d330fd2082
9 changed files with 259 additions and 147 deletions

View File

@@ -36,7 +36,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
user.Premium = false;
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
@@ -52,7 +52,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
user.GatewaySubscriptionId = null;
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
@@ -136,7 +136,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
@@ -212,7 +212,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
@@ -279,7 +279,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify that the subscription item quantity is always 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -289,82 +289,6 @@ public class PreviewPremiumUpgradeProrationCommandTests
item.Quantity == 1)));
}
[Theory]
[InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)]
[InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)]
[InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)]
public async Task Run_ProductTierTypeConversion_MapsToCorrectPlanType(
ProductTierType productTierType,
PlanType expectedPlanType)
{
// Arrange
var user = new User
{
Premium = true,
GatewaySubscriptionId = "sub_123",
GatewayCustomerId = "cus_123"
};
var billingAddress = new BillingAddress
{
Country = "US",
PostalCode = "12345"
};
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" } }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } }
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(expectedPlanType).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
await _command.Run(user, productTierType, billingAddress);
// Assert - Verify that the correct PlanType was used
await _pricingClient.Received(1).GetPlanOrThrow(expectedPlanType);
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_DeletesPremiumSubscriptionItems(User user, BillingAddress billingAddress)
{
@@ -423,7 +347,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify both password manager and storage items are marked as deleted
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -491,7 +415,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Families, billingAddress);
await _command.Run(user, PlanType.FamiliesAnnually, billingAddress);
// Assert - Verify non-seat-based plan uses StripePlanId with quantity 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -560,7 +484,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify all invoice preview options are correct
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -570,24 +494,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
options.Subscription == "sub_123" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.ProrationBehavior == "always_invoice"));
}
[Theory, BitAutoData]
public async Task Run_TeamsStarterTierType_ReturnsBadRequest(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
// Act
var result = await _command.Run(user, ProductTierType.TeamsStarter, billingAddress);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Cannot upgrade Premium subscription to TeamsStarter plan.", badRequest.Response);
options.SubscriptionDetails.ProrationBehavior == "create_prorations"));
}
[Theory, BitAutoData]
@@ -648,7 +555,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify seat-based plan uses StripeSeatPlanId with quantity 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(

View File

@@ -28,7 +28,8 @@ public class UpgradePremiumToOrganizationCommandTests
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
string? stripeStoragePlanId = null)
string? stripeStoragePlanId = null,
int? trialPeriodDays = null)
{
Type = planType;
ProductTier = ProductTierType.Teams;
@@ -37,7 +38,7 @@ public class UpgradePremiumToOrganizationCommandTests
NameLocalizationKey = "";
DescriptionLocalizationKey = "";
CanBeUsedByBusiness = true;
TrialPeriodDays = null;
TrialPeriodDays = trialPeriodDays;
HasSelfHost = false;
HasPolicies = false;
HasGroups = false;
@@ -86,10 +87,9 @@ public class UpgradePremiumToOrganizationCommandTests
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
string? stripeStoragePlanId = null)
{
return new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId);
}
string? stripeStoragePlanId = null,
int? trialPeriodDays = null) =>
new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId, trialPeriodDays);
private static PremiumPlan CreateTestPremiumPlan(
string seatPriceId = "premium-annually",
@@ -643,4 +643,75 @@ public class UpgradePremiumToOrganizationCommandTests
var badRequest = result.AsT1;
Assert.Equal("Premium subscription item not found.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_PlanWithTrialPeriod_SetsTrialEnd(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = currentPeriodEnd
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
// Create a plan with a trial period
var mockPlan = CreateTestPlan(
PlanType.TeamsAnnually,
stripeSeatPlanId: "teams-seat-annually",
trialPeriodDays: 7
);
// Capture the subscription update options to verify TrialEnd is set
SubscriptionUpdateOptions capturedOptions = null;
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Do<SubscriptionUpdateOptions>(opts => capturedOptions = opts))
.Returns(Task.FromResult(mockSubscription));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var testStartTime = DateTime.UtcNow;
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var testEndTime = DateTime.UtcNow;
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateSubscriptionAsync("sub_123", Arg.Any<SubscriptionUpdateOptions>());
Assert.NotNull(capturedOptions);
Assert.NotNull(capturedOptions.TrialEnd);
// TrialEnd is AnyOf<DateTime?, SubscriptionTrialEnd> - verify it's a DateTime
var trialEndDateTime = capturedOptions.TrialEnd.Value as DateTime?;
Assert.NotNull(trialEndDateTime);
Assert.True(trialEndDateTime.Value >= testStartTime.AddDays(7));
Assert.True(trialEndDateTime.Value <= testEndTime.AddDays(7));
}
}