1
0
mirror of https://github.com/bitwarden/server synced 2026-01-27 14:53:21 +00:00

updated proration preview and upgrade to be consistent

also using the correct proration behavior and making the upgrade flow start a trial
This commit is contained in:
Kyle Denney
2026-01-21 16:43:06 -06:00
parent db456eb4b5
commit d330fd2082
9 changed files with 259 additions and 147 deletions

View File

@@ -63,11 +63,11 @@ public class PreviewInvoiceController(
[BindNever] User user,
[FromBody] PreviewPremiumUpgradeProrationRequest request)
{
var (tierType, billingAddress) = request.ToDomain();
var (planType, billingAddress) = request.ToDomain();
var result = await previewPremiumUpgradeProrationCommand.Run(
user,
tierType,
planType,
billingAddress);
return Handle(result.Map(pair => new { pair.Tax, pair.Total, pair.Credit }));

View File

@@ -1,5 +1,6 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Enums;
namespace Bit.Api.Billing.Models.Requests.Premium;
@@ -14,24 +15,29 @@ public class UpgradePremiumToOrganizationRequest
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public ProductTierType Tier { get; set; }
public required ProductTierType TargetProductTierType { get; set; }
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public PlanCadenceType Cadence { get; set; }
public required MinimalBillingAddressRequest BillingAddress { get; set; }
private PlanType PlanType =>
Tier switch
private PlanType PlanType
{
get
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => Cadence == PlanCadenceType.Monthly
? PlanType.TeamsMonthly
: PlanType.TeamsAnnually,
ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly
? PlanType.EnterpriseMonthly
: PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException("Cannot upgrade to an Organization subscription that isn't Families, Teams or Enterprise.")
};
if (TargetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise))
{
throw new InvalidOperationException($"Cannot upgrade Premium subscription to {TargetProductTierType} plan.");
}
return TargetProductTierType switch
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => PlanType.TeamsAnnually,
ProductTierType.Enterprise => PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException($"Unexpected ProductTierType: {TargetProductTierType}")
};
}
}
public (string OrganizationName, string Key, PlanType PlanType) ToDomain() => (OrganizationName, Key, PlanType);
}

View File

@@ -1,4 +1,5 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Payment.Models;
@@ -8,11 +9,31 @@ namespace Bit.Api.Billing.Models.Requests.PreviewInvoice;
public record PreviewPremiumUpgradeProrationRequest
{
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public required ProductTierType TargetProductTierType { get; set; }
[Required]
public required MinimalBillingAddressRequest BillingAddress { get; set; }
public (ProductTierType, BillingAddress) ToDomain() =>
(TargetProductTierType, BillingAddress.ToDomain());
private PlanType PlanType
{
get
{
if (TargetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise))
{
throw new InvalidOperationException($"Cannot upgrade Premium subscription to {TargetProductTierType} plan.");
}
return TargetProductTierType switch
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => PlanType.TeamsAnnually,
ProductTierType.Enterprise => PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException($"Unexpected ProductTierType: {TargetProductTierType}")
};
}
}
public (PlanType, BillingAddress) ToDomain() =>
(PlanType, BillingAddress.ToDomain());
}

View File

@@ -20,12 +20,12 @@ public interface IPreviewPremiumUpgradeProrationCommand
/// Calculates the tax, total cost, and proration credit for upgrading a Premium subscription to an Organization plan.
/// </summary>
/// <param name="user">The user with an active Premium subscription.</param>
/// <param name="targetProductTierType">The target organization tier (Families, Teams, or Enterprise).</param>
/// <param name="targetPlanType">The target organization plan type.</param>
/// <param name="billingAddress">The billing address for tax calculation.</param>
/// <returns>A tuple containing the tax amount, total cost, and proration credit from unused Premium time.</returns>
Task<BillingCommandResult<(decimal Tax, decimal Total, decimal Credit)>> Run(
User user,
ProductTierType targetProductTierType,
PlanType targetPlanType,
BillingAddress billingAddress);
}
@@ -38,7 +38,7 @@ public class PreviewPremiumUpgradeProrationCommand(
{
public Task<BillingCommandResult<(decimal Tax, decimal Total, decimal Credit)>> Run(
User user,
ProductTierType targetProductTierType,
PlanType targetPlanType,
BillingAddress billingAddress) => HandleAsync<(decimal, decimal, decimal)>(async () =>
{
if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" })
@@ -46,20 +46,6 @@ public class PreviewPremiumUpgradeProrationCommand(
return new BadRequest("User does not have an active Premium subscription.");
}
if (targetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise))
{
return new BadRequest($"Cannot upgrade Premium subscription to {targetProductTierType} plan.");
}
// Convert ProductTierType to PlanType (for premium upgrade, the only choice is annual plans so we can assume that cadence)
var targetPlanType = targetProductTierType switch
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => PlanType.TeamsAnnually,
ProductTierType.Enterprise => PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException($"Unexpected ProductTierType: {targetProductTierType}")
};
// Hardcode seats to 1 for upgrade flow
const int seats = 1;
@@ -128,7 +114,7 @@ public class PreviewPremiumUpgradeProrationCommand(
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items = subscriptionItems,
ProrationBehavior = StripeConstants.ProrationBehavior.AlwaysInvoice
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations
}
};

View File

@@ -74,7 +74,7 @@ public class UpgradePremiumToOrganizationCommand(
if (passwordManagerItem == null)
{
return new BadRequest("Premium subscription item not found.");
return new BadRequest("Premium subscription password manager item not found.");
}
var usersPremiumPlan = premiumPlans.First(p => p.Seat.StripePriceId == passwordManagerItem.Price.Id);
@@ -133,7 +133,7 @@ public class UpgradePremiumToOrganizationCommand(
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
Items = subscriptionItemOptions,
ProrationBehavior = StripeConstants.ProrationBehavior.None,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
Metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.OrganizationId] = organizationId.ToString(),
@@ -145,6 +145,11 @@ public class UpgradePremiumToOrganizationCommand(
}
};
if (targetPlan.TrialPeriodDays.HasValue)
{
subscriptionUpdateOptions.TrialEnd = DateTime.UtcNow.AddDays((double)targetPlan.TrialPeriodDays);
}
// Create the Organization entity
var organization = new Organization
{

View File

@@ -0,0 +1,56 @@
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Billing.Models.Requests.PreviewInvoice;
using Bit.Core.Billing.Enums;
using Xunit;
namespace Bit.Api.Test.Billing.Models.Requests;
public class PreviewPremiumUpgradeProrationRequestTests
{
[Theory]
[InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)]
[InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)]
[InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)]
public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, PlanType expectedPlanType)
{
// Arrange
var sut = new PreviewPremiumUpgradeProrationRequest
{
TargetProductTierType = tierType,
BillingAddress = new MinimalBillingAddressRequest
{
Country = "US",
PostalCode = "12345"
}
};
// Act
var (planType, billingAddress) = sut.ToDomain();
// Assert
Assert.Equal(expectedPlanType, planType);
Assert.Equal("US", billingAddress.Country);
Assert.Equal("12345", billingAddress.PostalCode);
}
[Theory]
[InlineData(ProductTierType.Free)]
[InlineData(ProductTierType.TeamsStarter)]
public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTierType tierType)
{
// Arrange
var sut = new PreviewPremiumUpgradeProrationRequest
{
TargetProductTierType = tierType,
BillingAddress = new MinimalBillingAddressRequest
{
Country = "US",
PostalCode = "12345"
}
};
// Act & Assert
var exception = Assert.Throws<InvalidOperationException>(() => sut.ToDomain());
Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message);
}
}

View File

@@ -0,0 +1,60 @@
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Billing.Models.Requests.Premium;
using Bit.Core.Billing.Enums;
using Xunit;
namespace Bit.Api.Test.Billing.Models.Requests;
public class UpgradePremiumToOrganizationRequestTests
{
[Theory]
[InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)]
[InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)]
[InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)]
public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, PlanType expectedPlanType)
{
// Arrange
var sut = new UpgradePremiumToOrganizationRequest
{
OrganizationName = "Test Organization",
Key = "encrypted-key",
TargetProductTierType = tierType,
BillingAddress = new MinimalBillingAddressRequest
{
Country = "US",
PostalCode = "12345"
}
};
// Act
var (organizationName, key, planType) = sut.ToDomain();
// Assert
Assert.Equal("Test Organization", organizationName);
Assert.Equal("encrypted-key", key);
Assert.Equal(expectedPlanType, planType);
}
[Theory]
[InlineData(ProductTierType.Free)]
[InlineData(ProductTierType.TeamsStarter)]
public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTierType tierType)
{
// Arrange
var sut = new UpgradePremiumToOrganizationRequest
{
OrganizationName = "Test Organization",
Key = "encrypted-key",
TargetProductTierType = tierType,
BillingAddress = new MinimalBillingAddressRequest
{
Country = "US",
PostalCode = "12345"
}
};
// Act & Assert
var exception = Assert.Throws<InvalidOperationException>(() => sut.ToDomain());
Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message);
}
}

View File

@@ -36,7 +36,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
user.Premium = false;
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
@@ -52,7 +52,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
user.GatewaySubscriptionId = null;
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
@@ -136,7 +136,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
@@ -212,7 +212,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
var result = await _command.Run(user, ProductTierType.Teams, billingAddress);
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
@@ -279,7 +279,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify that the subscription item quantity is always 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -289,82 +289,6 @@ public class PreviewPremiumUpgradeProrationCommandTests
item.Quantity == 1)));
}
[Theory]
[InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)]
[InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)]
[InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)]
public async Task Run_ProductTierTypeConversion_MapsToCorrectPlanType(
ProductTierType productTierType,
PlanType expectedPlanType)
{
// Arrange
var user = new User
{
Premium = true,
GatewaySubscriptionId = "sub_123",
GatewayCustomerId = "cus_123"
};
var billingAddress = new BillingAddress
{
Country = "US",
PostalCode = "12345"
};
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" } }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } }
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(expectedPlanType).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
await _command.Run(user, productTierType, billingAddress);
// Assert - Verify that the correct PlanType was used
await _pricingClient.Received(1).GetPlanOrThrow(expectedPlanType);
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_DeletesPremiumSubscriptionItems(User user, BillingAddress billingAddress)
{
@@ -423,7 +347,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify both password manager and storage items are marked as deleted
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -491,7 +415,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Families, billingAddress);
await _command.Run(user, PlanType.FamiliesAnnually, billingAddress);
// Assert - Verify non-seat-based plan uses StripePlanId with quantity 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -560,7 +484,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify all invoice preview options are correct
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
@@ -570,24 +494,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
options.Subscription == "sub_123" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.ProrationBehavior == "always_invoice"));
}
[Theory, BitAutoData]
public async Task Run_TeamsStarterTierType_ReturnsBadRequest(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
// Act
var result = await _command.Run(user, ProductTierType.TeamsStarter, billingAddress);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Cannot upgrade Premium subscription to TeamsStarter plan.", badRequest.Response);
options.SubscriptionDetails.ProrationBehavior == "create_prorations"));
}
[Theory, BitAutoData]
@@ -648,7 +555,7 @@ public class PreviewPremiumUpgradeProrationCommandTests
.Returns(invoice);
// Act
await _command.Run(user, ProductTierType.Teams, billingAddress);
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify seat-based plan uses StripeSeatPlanId with quantity 1
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(

View File

@@ -28,7 +28,8 @@ public class UpgradePremiumToOrganizationCommandTests
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
string? stripeStoragePlanId = null)
string? stripeStoragePlanId = null,
int? trialPeriodDays = null)
{
Type = planType;
ProductTier = ProductTierType.Teams;
@@ -37,7 +38,7 @@ public class UpgradePremiumToOrganizationCommandTests
NameLocalizationKey = "";
DescriptionLocalizationKey = "";
CanBeUsedByBusiness = true;
TrialPeriodDays = null;
TrialPeriodDays = trialPeriodDays;
HasSelfHost = false;
HasPolicies = false;
HasGroups = false;
@@ -86,10 +87,9 @@ public class UpgradePremiumToOrganizationCommandTests
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
string? stripeStoragePlanId = null)
{
return new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId);
}
string? stripeStoragePlanId = null,
int? trialPeriodDays = null) =>
new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId, trialPeriodDays);
private static PremiumPlan CreateTestPremiumPlan(
string seatPriceId = "premium-annually",
@@ -643,4 +643,75 @@ public class UpgradePremiumToOrganizationCommandTests
var badRequest = result.AsT1;
Assert.Equal("Premium subscription item not found.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_PlanWithTrialPeriod_SetsTrialEnd(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = currentPeriodEnd
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
// Create a plan with a trial period
var mockPlan = CreateTestPlan(
PlanType.TeamsAnnually,
stripeSeatPlanId: "teams-seat-annually",
trialPeriodDays: 7
);
// Capture the subscription update options to verify TrialEnd is set
SubscriptionUpdateOptions capturedOptions = null;
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Do<SubscriptionUpdateOptions>(opts => capturedOptions = opts))
.Returns(Task.FromResult(mockSubscription));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var testStartTime = DateTime.UtcNow;
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var testEndTime = DateTime.UtcNow;
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateSubscriptionAsync("sub_123", Arg.Any<SubscriptionUpdateOptions>());
Assert.NotNull(capturedOptions);
Assert.NotNull(capturedOptions.TrialEnd);
// TrialEnd is AnyOf<DateTime?, SubscriptionTrialEnd> - verify it's a DateTime
var trialEndDateTime = capturedOptions.TrialEnd.Value as DateTime?;
Assert.NotNull(trialEndDateTime);
Assert.True(trialEndDateTime.Value >= testStartTime.AddDays(7));
Assert.True(trialEndDateTime.Value <= testEndTime.AddDays(7));
}
}