1
0
mirror of https://github.com/bitwarden/server synced 2026-01-27 23:03:31 +00:00

Merge branch 'main' into billing/pm-30908/correct-premium-subscription-status-handling

This commit is contained in:
cyprain-okeke
2026-01-21 12:49:03 +01:00
committed by GitHub
19 changed files with 519 additions and 477 deletions

View File

@@ -93,7 +93,7 @@ public class RestoreOrganizationUserCommand(
.twoFactorIsEnabled;
}
if (organization.PlanType == PlanType.Free)
if (organization.PlanType == PlanType.Free && organizationUser.UserId.HasValue)
{
await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser);
}

View File

@@ -2,6 +2,7 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Entities;
using Bit.Core.Services;
using Bit.Core.Utilities;
@@ -29,6 +30,7 @@ public interface IUpdatePremiumStorageCommand
}
public class UpdatePremiumStorageCommand(
IBraintreeService braintreeService,
IStripeAdapter stripeAdapter,
IUserService userService,
IPricingClient pricingClient,
@@ -49,7 +51,10 @@ public class UpdatePremiumStorageCommand(
// Fetch all premium plans and the user's subscription to find which plan they're on
var premiumPlans = await pricingClient.ListPremiumPlans();
var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId);
var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, new SubscriptionGetOptions
{
Expand = ["customer"]
});
// Find the password manager subscription item (seat, not storage) and match it to a plan
var passwordManagerItem = subscription.Items.Data.FirstOrDefault(i =>
@@ -127,13 +132,41 @@ public class UpdatePremiumStorageCommand(
});
}
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
Items = subscriptionItemOptions,
ProrationBehavior = ProrationBehavior.AlwaysInvoice
};
var usingPayPal = subscription.Customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId);
await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, subscriptionUpdateOptions);
if (usingPayPal)
{
var options = new SubscriptionUpdateOptions
{
Items = subscriptionItemOptions,
ProrationBehavior = ProrationBehavior.CreateProrations
};
await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options);
var draftInvoice = await stripeAdapter.CreateInvoiceAsync(new InvoiceCreateOptions
{
Customer = subscription.CustomerId,
Subscription = subscription.Id,
AutoAdvance = false,
CollectionMethod = CollectionMethod.ChargeAutomatically
});
var finalizedInvoice = await stripeAdapter.FinalizeInvoiceAsync(draftInvoice.Id,
new InvoiceFinalizeOptions { AutoAdvance = false, Expand = ["customer"] });
await braintreeService.PayInvoice(new UserId(user.Id), finalizedInvoice);
}
else
{
var options = new SubscriptionUpdateOptions
{
Items = subscriptionItemOptions,
ProrationBehavior = ProrationBehavior.AlwaysInvoice
};
await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options);
}
// Update the user's max storage
user.MaxStorageGb = maxStorageGb;

View File

@@ -24,6 +24,7 @@ public interface IStripeAdapter
Task<Subscription> CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null);
Task<Invoice> GetInvoiceAsync(string id, InvoiceGetOptions options);
Task<List<Invoice>> ListInvoicesAsync(StripeInvoiceListOptions options);
Task<Invoice> CreateInvoiceAsync(InvoiceCreateOptions options);
Task<Invoice> CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options);
Task<List<Invoice>> SearchInvoiceAsync(InvoiceSearchOptions options);
Task<Invoice> UpdateInvoiceAsync(string id, InvoiceUpdateOptions options);

View File

@@ -116,6 +116,9 @@ public class StripeAdapter : IStripeAdapter
return invoices;
}
public Task<Invoice> CreateInvoiceAsync(InvoiceCreateOptions options) =>
_invoiceService.CreateAsync(options);
public Task<Invoice> CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options) =>
_invoiceService.CreatePreviewAsync(options);

View File

@@ -162,7 +162,6 @@ public static class FeatureFlagKeys
public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email";
public const string OrganizationConfirmationEmail = "pm-28402-update-confirmed-to-org-email-template";
public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow";
public const string RedirectOnSsoRequired = "pm-1632-redirect-on-sso-required";
public const string PrefetchPasswordPrelogin = "pm-23801-prefetch-password-prelogin";
public const string PM27086_UpdateAuthenticationApisForInputPassword = "pm-27086-update-authentication-apis-for-input-password";

View File

@@ -19,7 +19,7 @@
</mj-text>
<mj-text font-size="16px" line-height="24px" padding="10px 15px 15px 15px">
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal.
This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
</mj-text>
<mj-text font-size="16px" line-height="24px" padding="10px 15px">
Questions? Contact

View File

@@ -18,7 +18,7 @@
</mj-text>
<mj-text font-size="16px" line-height="24px" padding="10px 15px 15px 15px">
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal.
This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
</mj-text>
<mj-text font-size="16px" line-height="24px" padding="10px 15px">
Questions? Contact

View File

@@ -203,7 +203,7 @@
<td align="left" style="font-size:0px;padding:10px 15px 15px 15px;word-break:break-word;">
<div style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;font-size:16px;line-height:24px;text-align:left;color:#1B2029;">As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal.
This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.</div>
This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.</div>
</td>
</tr>
@@ -271,12 +271,12 @@
<tbody>
<tr>
<td style="direction:ltr;font-size:0px;padding:0px 20px 10px 20px;text-align:center;">
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" width="660px" ><table align="center" border="0" cellpadding="0" cellspacing="0" class="" role="presentation" style="width:620px;" width="620" bgcolor="#f6f6f6" ><tr><td style="line-height:0px;font-size:0px;mso-line-height-rule:exactly;"><![endif]-->
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" width="660px" ><table align="center" border="0" cellpadding="0" cellspacing="0" class="" role="presentation" style="width:620px;" width="620" bgcolor="#F3F6F9" ><tr><td style="line-height:0px;font-size:0px;mso-line-height-rule:exactly;"><![endif]-->
<div style="background:#f6f6f6;background-color:#f6f6f6;margin:0px auto;border-radius:0px 0px 4px 4px;max-width:620px;">
<div style="background:#F3F6F9;background-color:#F3F6F9;margin:0px auto;border-radius:0px 0px 4px 4px;max-width:620px;">
<table align="center" border="0" cellpadding="0" cellspacing="0" role="presentation" style="background:#f6f6f6;background-color:#f6f6f6;width:100%;border-radius:0px 0px 4px 4px;">
<table align="center" border="0" cellpadding="0" cellspacing="0" role="presentation" style="background:#F3F6F9;background-color:#F3F6F9;width:100%;border-radius:0px 0px 4px 4px;">
<tbody>
<tr>
<td style="direction:ltr;font-size:0px;padding:5px 10px 10px 10px;text-align:center;">
@@ -364,8 +364,8 @@
<table align="center" border="0" cellpadding="0" cellspacing="0" role="presentation" style="width:100%;">
<tbody>
<tr>
<td style="direction:ltr;font-size:0px;padding:20px 0;text-align:center;">
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" style="vertical-align:top;width:660px;" ><![endif]-->
<td style="direction:ltr;font-size:0px;padding:5px 20px 10px 20px;text-align:center;">
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" style="vertical-align:top;width:620px;" ><![endif]-->
<div class="mj-column-per-100 mj-outlook-group-fix" style="font-size:0px;text-align:left;direction:ltr;display:inline-block;vertical-align:top;width:100%;">
@@ -381,13 +381,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://x.com/bitwarden" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-x.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-x-twitter.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -404,13 +404,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.reddit.com/r/Bitwarden/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-reddit.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-reddit.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -427,13 +427,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://community.bitwarden.com/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-discourse.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-discourse.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -450,13 +450,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://github.com/bitwarden" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-github.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-github.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -473,13 +473,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.youtube.com/channel/UCId9a_jQqvJre0_dE2lE_Rw" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-youtube.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-youtube.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -496,13 +496,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.linkedin.com/company/bitwarden1/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-linkedin.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-linkedin.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -519,13 +519,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.facebook.com/bitwarden/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-facebook.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-facebook.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -546,15 +546,15 @@
<tr>
<td align="center" style="font-size:0px;padding:10px 25px;word-break:break-word;">
<div style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;font-size:12px;line-height:16px;text-align:center;color:#5A6D91;"><p style="margin-bottom: 5px">
<div style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;font-size:12px;line-height:16px;text-align:center;color:#5A6D91;"><p style="margin-bottom: 5px; margin-top: 5px">
© 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa
Barbara, CA, USA
</p>
<p style="margin-top: 5px">
Always confirm you are on a trusted Bitwarden domain before logging
in:<br>
<a href="https://bitwarden.com/">bitwarden.com</a> |
<a href="https://bitwarden.com/help/emails-from-bitwarden/">Learn why we include this</a>
<a href="https://bitwarden.com/" style="text-decoration:none;color:#175ddc; font-weight:400">bitwarden.com</a> |
<a href="https://bitwarden.com/help/emails-from-bitwarden/" style="text-decoration:none; color:#175ddc; font-weight:400">Learn why we include this</a>
</p></div>
</td>

View File

@@ -2,6 +2,6 @@
at {{BaseAnnualRenewalPrice}} + tax.
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal.
This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
Questions? Contact support@bitwarden.com

View File

@@ -202,7 +202,7 @@
<td align="left" style="font-size:0px;padding:10px 15px 15px 15px;word-break:break-word;">
<div style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;font-size:16px;line-height:24px;text-align:left;color:#1B2029;">As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal.
This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.</div>
This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.</div>
</td>
</tr>
@@ -270,12 +270,12 @@
<tbody>
<tr>
<td style="direction:ltr;font-size:0px;padding:0px 20px 10px 20px;text-align:center;">
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" width="660px" ><table align="center" border="0" cellpadding="0" cellspacing="0" class="" role="presentation" style="width:620px;" width="620" bgcolor="#f6f6f6" ><tr><td style="line-height:0px;font-size:0px;mso-line-height-rule:exactly;"><![endif]-->
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" width="660px" ><table align="center" border="0" cellpadding="0" cellspacing="0" class="" role="presentation" style="width:620px;" width="620" bgcolor="#F3F6F9" ><tr><td style="line-height:0px;font-size:0px;mso-line-height-rule:exactly;"><![endif]-->
<div style="background:#f6f6f6;background-color:#f6f6f6;margin:0px auto;border-radius:0px 0px 4px 4px;max-width:620px;">
<div style="background:#F3F6F9;background-color:#F3F6F9;margin:0px auto;border-radius:0px 0px 4px 4px;max-width:620px;">
<table align="center" border="0" cellpadding="0" cellspacing="0" role="presentation" style="background:#f6f6f6;background-color:#f6f6f6;width:100%;border-radius:0px 0px 4px 4px;">
<table align="center" border="0" cellpadding="0" cellspacing="0" role="presentation" style="background:#F3F6F9;background-color:#F3F6F9;width:100%;border-radius:0px 0px 4px 4px;">
<tbody>
<tr>
<td style="direction:ltr;font-size:0px;padding:5px 10px 10px 10px;text-align:center;">
@@ -363,8 +363,8 @@
<table align="center" border="0" cellpadding="0" cellspacing="0" role="presentation" style="width:100%;">
<tbody>
<tr>
<td style="direction:ltr;font-size:0px;padding:20px 0;text-align:center;">
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" style="vertical-align:top;width:660px;" ><![endif]-->
<td style="direction:ltr;font-size:0px;padding:5px 20px 10px 20px;text-align:center;">
<!--[if mso | IE]><table role="presentation" border="0" cellpadding="0" cellspacing="0"><tr><td class="" style="vertical-align:top;width:620px;" ><![endif]-->
<div class="mj-column-per-100 mj-outlook-group-fix" style="font-size:0px;text-align:left;direction:ltr;display:inline-block;vertical-align:top;width:100%;">
@@ -380,13 +380,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://x.com/bitwarden" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-x.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-x-twitter.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -403,13 +403,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.reddit.com/r/Bitwarden/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-reddit.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-reddit.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -426,13 +426,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://community.bitwarden.com/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-discourse.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-discourse.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -449,13 +449,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://github.com/bitwarden" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-github.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-github.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -472,13 +472,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.youtube.com/channel/UCId9a_jQqvJre0_dE2lE_Rw" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-youtube.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-youtube.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -495,13 +495,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.linkedin.com/company/bitwarden1/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-linkedin.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-linkedin.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -518,13 +518,13 @@
<tbody>
<tr>
<td style="padding:10px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:30px;">
<td style="padding:8px;vertical-align:middle;">
<table border="0" cellpadding="0" cellspacing="0" role="presentation" style="border-radius:3px;width:24px;">
<tbody>
<tr>
<td style="font-size:0;height:30px;vertical-align:middle;width:30px;">
<td style="font-size:0;height:24px;vertical-align:middle;width:24px;">
<a href="https://www.facebook.com/bitwarden/" target="_blank">
<img alt height="30" src="https://assets.bitwarden.com/email/v1/mail-facebook.png" style="border-radius:3px;display:block;" width="30">
<img alt height="24" src="https://assets.bitwarden.com/email/v1/social-icons-facebook.png" style="border-radius:3px;display:block;" width="24">
</a>
</td>
</tr>
@@ -545,15 +545,15 @@
<tr>
<td align="center" style="font-size:0px;padding:10px 25px;word-break:break-word;">
<div style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;font-size:12px;line-height:16px;text-align:center;color:#5A6D91;"><p style="margin-bottom: 5px">
<div style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;font-size:12px;line-height:16px;text-align:center;color:#5A6D91;"><p style="margin-bottom: 5px; margin-top: 5px">
© 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa
Barbara, CA, USA
</p>
<p style="margin-top: 5px">
Always confirm you are on a trusted Bitwarden domain before logging
in:<br>
<a href="https://bitwarden.com/">bitwarden.com</a> |
<a href="https://bitwarden.com/help/emails-from-bitwarden/">Learn why we include this</a>
<a href="https://bitwarden.com/" style="text-decoration:none;color:#175ddc; font-weight:400">bitwarden.com</a> |
<a href="https://bitwarden.com/help/emails-from-bitwarden/" style="text-decoration:none; color:#175ddc; font-weight:400">Learn why we include this</a>
</p></div>
</td>

View File

@@ -1,6 +1,6 @@
Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually.
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal.
This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
Questions? Contact support@bitwarden.com

View File

@@ -4,7 +4,6 @@
using System.Security.Claims;
using Bit.Core;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Auth.Entities;
@@ -233,56 +232,14 @@ public abstract class BaseRequestValidator<T> where T : class
private async Task<bool> ValidateSsoAsync(T context, ValidatedTokenRequest request,
CustomValidatorRequestContext validatorContext)
{
// TODO: Clean up Feature Flag: Remove this if block: PM-28281
if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired))
var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext);
if (ssoValid)
{
validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType);
if (!validatorContext.SsoRequired)
{
return true;
}
// Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are
// presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and
// review their new recovery token if desired.
// SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery.
// As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been
// evaluated, and recovery will have been performed if requested.
// We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect
// to /login.
if (validatorContext.TwoFactorRequired &&
validatorContext.TwoFactorRecoveryRequested)
{
SetSsoResult(context,
new Dictionary<string, object>
{
{
"ErrorModel",
new ErrorResponseModel(
"Two-factor recovery has been performed. SSO authentication is required.")
}
});
return false;
}
SetSsoResult(context,
new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("SSO authentication is required.") }
});
return false;
return true;
}
else
{
var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext);
if (ssoValid)
{
return true;
}
SetValidationErrorResult(context, validatorContext);
return ssoValid;
}
SetValidationErrorResult(context, validatorContext);
return ssoValid;
}
/// <summary>
@@ -521,9 +478,6 @@ public abstract class BaseRequestValidator<T> where T : class
[Obsolete("Consider using SetValidationErrorResult instead.")]
protected abstract void SetTwoFactorResult(T context, Dictionary<string, object> customResponse);
[Obsolete("Consider using SetValidationErrorResult instead.")]
protected abstract void SetSsoResult(T context, Dictionary<string, object> customResponse);
[Obsolete("Consider using SetValidationErrorResult instead.")]
protected abstract void SetErrorResult(T context, Dictionary<string, object> customResponse);
@@ -540,41 +494,6 @@ public abstract class BaseRequestValidator<T> where T : class
protected abstract ClaimsPrincipal GetSubject(T context);
/// <summary>
/// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are
/// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement.
/// If the GrantType is authorization_code or client_credentials we know the user is trying to login
/// using the SSO flow so they are allowed to continue.
/// </summary>
/// <param name="user">user trying to login</param>
/// <param name="grantType">magic string identifying the grant type requested</param>
/// <returns>true if sso required; false if not required or already in process</returns>
[Obsolete(
"This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")]
private async Task<bool> RequireSsoLoginAsync(User user, string grantType)
{
if (grantType == "authorization_code" || grantType == "client_credentials")
{
// Already using SSO to authenticate, or logging-in via api key to skip SSO requirement
// allow to authenticate successfully
return false;
}
// Check if user belongs to any organization with an active SSO policy
var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements)
? (await PolicyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(user.Id))
.SsoRequired
: await PolicyService.AnyPoliciesApplicableToUserAsync(
user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
if (ssoRequired)
{
return true;
}
// Default - SSO is not required
return false;
}
private async Task ResetFailedAuthDetailsAsync(User user)
{
// Early escape if db hit not necessary

View File

@@ -194,17 +194,6 @@ public class CustomTokenRequestValidator : BaseRequestValidator<CustomTokenReque
context.Result.CustomResponse = customResponse;
}
[Obsolete("Consider using SetGrantValidationErrorResult instead.")]
protected override void SetSsoResult(CustomTokenRequestValidationContext context,
Dictionary<string, object> customResponse)
{
Debug.Assert(context.Result is not null);
context.Result.Error = "invalid_grant";
context.Result.ErrorDescription = "Sso authentication required.";
context.Result.IsError = true;
context.Result.CustomResponse = customResponse;
}
[Obsolete("Consider using SetGrantValidationErrorResult instead.")]
protected override void SetErrorResult(CustomTokenRequestValidationContext context,
Dictionary<string, object> customResponse)

View File

@@ -152,14 +152,6 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator<ResourceOwner
customResponse);
}
[Obsolete("Consider using SetGrantValidationErrorResult instead.")]
protected override void SetSsoResult(ResourceOwnerPasswordValidationContext context,
Dictionary<string, object> customResponse)
{
context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.",
customResponse);
}
[Obsolete("Consider using SetGrantValidationErrorResult instead.")]
protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context,
Dictionary<string, object> customResponse)

View File

@@ -142,14 +142,6 @@ public class WebAuthnGrantValidator : BaseRequestValidator<ExtensionGrantValidat
customResponse);
}
[Obsolete("Consider using SetValidationErrorResult instead.")]
protected override void SetSsoResult(ExtensionGrantValidationContext context,
Dictionary<string, object> customResponse)
{
context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.",
customResponse);
}
[Obsolete("Consider using SetValidationErrorResult instead.")]
protected override void SetErrorResult(ExtensionGrantValidationContext context, Dictionary<string, object> customResponse)
{

View File

@@ -715,6 +715,39 @@ public class RestoreOrganizationUserCommandTests
Arg.Is<OrganizationUserStatusType>(x => x != OrganizationUserStatusType.Revoked));
}
[Theory, BitAutoData]
public async Task RestoreUser_InvitedUserInFreeOrganization_Success(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
organization.PlanType = PlanType.Free;
organizationUser.UserId = null;
organizationUser.Key = null;
organizationUser.Status = OrganizationUserStatusType.Revoked;
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts
{
Sponsored = 0,
Users = 1
});
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.RestoreAsync(organizationUser.Id, OrganizationUserStatusType.Invited);
await sutProvider.GetDependency<IEventService>()
.Received(1)
.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored);
await sutProvider.GetDependency<IPushNotificationService>()
.DidNotReceiveWithAnyArgs()
.PushSyncOrgKeysAsync(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task RestoreUsers_Success(Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,

View File

@@ -1,6 +1,7 @@
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Entities;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Constants.StripeConstants;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable;
@@ -15,6 +17,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands;
public class UpdatePremiumStorageCommandTests
{
private readonly IBraintreeService _braintreeService = Substitute.For<IBraintreeService>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
@@ -33,13 +36,14 @@ public class UpdatePremiumStorageCommandTests
_pricingClient.ListPremiumPlans().Returns([premiumPlan]);
_command = new UpdatePremiumStorageCommand(
_braintreeService,
_stripeAdapter,
_userService,
_pricingClient,
Substitute.For<ILogger<UpdatePremiumStorageCommand>>());
}
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null)
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null, bool isPayPal = false)
{
var items = new List<SubscriptionItem>
{
@@ -63,9 +67,17 @@ public class UpdatePremiumStorageCommandTests
});
}
var customer = new Customer
{
Id = "cus_123",
Metadata = isPayPal ? new Dictionary<string, string> { { MetadataKeys.BraintreeCustomerId, "braintree_123" } } : new Dictionary<string, string>()
};
return new Subscription
{
Id = subscriptionId,
CustomerId = "cus_123",
Customer = customer,
Items = new StripeList<SubscriptionItem>
{
Data = items
@@ -97,7 +109,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, -5);
@@ -117,7 +129,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 100);
@@ -154,7 +166,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 0);
@@ -176,7 +188,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 4);
@@ -185,7 +197,7 @@ public class UpdatePremiumStorageCommandTests
Assert.True(result.IsT0);
// Verify subscription was fetched but NOT updated
await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123");
await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>());
await _stripeAdapter.DidNotReceive().UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
await _userService.DidNotReceive().SaveUserAsync(Arg.Any<User>());
}
@@ -200,7 +212,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 9);
@@ -233,7 +245,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 9);
@@ -262,7 +274,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 2);
@@ -291,7 +303,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 0);
@@ -320,7 +332,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 99);
@@ -335,4 +347,200 @@ public class UpdatePremiumStorageCommandTests
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 100));
}
[Theory, BitAutoData]
public async Task Run_IncreaseStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 5;
user.Storage = 2L * 1024 * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 9);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated with CreateProrations
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
// Verify draft invoice was created
await _stripeAdapter.Received(1).CreateInvoiceAsync(
Arg.Is<InvoiceCreateOptions>(opts =>
opts.Customer == "cus_123" &&
opts.Subscription == "sub_123" &&
opts.AutoAdvance == false &&
opts.CollectionMethod == "charge_automatically"));
// Verify invoice was finalized
await _stripeAdapter.Received(1).FinalizeInvoiceAsync(
"in_draft",
Arg.Is<InvoiceFinalizeOptions>(opts =>
opts.AutoAdvance == false &&
opts.Expand.Contains("customer")));
// Verify Braintree payment was processed
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
// Verify user was saved
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u =>
u.Id == user.Id &&
u.MaxStorageGb == 10));
}
[Theory, BitAutoData]
public async Task Run_AddStorageFromZero_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 1;
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 9);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated with new storage item
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Price == "price_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 10));
}
[Theory, BitAutoData]
public async Task Run_DecreaseStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 10;
user.Storage = 2L * 1024 * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 2);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 2 &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 3));
}
[Theory, BitAutoData]
public async Task Run_RemoveAllAdditionalStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 10;
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 0);
// Assert
Assert.True(result.IsT0);
// Verify subscription item was deleted
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Deleted == true &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 1));
}
}

View File

@@ -18,6 +18,7 @@ using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Identity.IdentityServer;
using Bit.Identity.IdentityServer.RequestValidationConstants;
using Bit.Identity.IdentityServer.RequestValidators;
using Bit.Identity.Test.Wrappers;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -130,7 +131,7 @@ public class BaseRequestValidatorTests
var logs = _logger.Collector.GetSnapshot(true);
Assert.Contains(logs,
l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: ");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message);
}
@@ -161,7 +162,11 @@ public class BaseRequestValidatorTests
.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(false));
// 5 -> not legacy user
// 5 -> SSO not required
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 6 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
@@ -203,6 +208,11 @@ public class BaseRequestValidatorTests
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 6 -> SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 7 -> setup user account keys
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -262,6 +272,11 @@ public class BaseRequestValidatorTests
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 6 -> SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 7 -> setup user account keys
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -326,6 +341,9 @@ public class BaseRequestValidatorTests
{ "TwoFactorProviders2", new Dictionary<string, object> { { "Email", null } } }
}));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -368,6 +386,10 @@ public class BaseRequestValidatorTests
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Email, "invalid_token")
.Returns(Task.FromResult(false));
// 5 -> set up SSO required verification to succeed
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -396,21 +418,25 @@ public class BaseRequestValidatorTests
// 1 -> initial validation passes
_sut.isValid = true;
// 2 -> set up 2FA as required
// 2 -> set up SSO required verification to succeed
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 3 -> set up 2FA as required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 3 -> provide invalid remember token (remember token expired)
// 4 -> provide invalid remember token (remember token expired)
tokenRequest.Raw["TwoFactorToken"] = "expired_remember_token";
tokenRequest.Raw["TwoFactorProvider"] = "5"; // Remember provider
// 4 -> set up remember token verification to fail
// 5 -> set up remember token verification to fail
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Remember, "expired_remember_token")
.Returns(Task.FromResult(false));
// 5 -> set up dummy BuildTwoFactorResultAsync
// 6 -> set up dummy BuildTwoFactorResultAsync
var twoFactorResultDict = new Dictionary<string, object>
{
{ "TwoFactorProviders", new[] { "0", "1" } },
@@ -446,6 +472,19 @@ public class BaseRequestValidatorTests
GrantValidationResult grantResult)
{
// Arrange
// SsoRequestValidator sets custom response
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -454,13 +493,17 @@ public class BaseRequestValidatorTests
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(false));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
Assert.NotNull(context.GrantResult.CustomResponse);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message);
}
// Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled
@@ -477,6 +520,20 @@ public class BaseRequestValidatorTests
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
// SsoRequestValidator sets custom response with organization identifier
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
{ CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" }
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -485,6 +542,10 @@ public class BaseRequestValidatorTests
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
_policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(Arg.Any<Guid>()).Returns(requirement);
// Mock the SSO validator to return false
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(false));
// Act
await _sut.ValidateAsync(context);
@@ -492,8 +553,9 @@ public class BaseRequestValidatorTests
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
Assert.NotNull(context.GrantResult.CustomResponse);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message);
}
[Theory]
@@ -519,6 +581,10 @@ public class BaseRequestValidatorTests
var requirement = new RequireSsoPolicyRequirement { SsoRequired = false };
_policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(Arg.Any<Guid>()).Returns(requirement);
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -561,6 +627,11 @@ public class BaseRequestValidatorTests
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -603,6 +674,10 @@ public class BaseRequestValidatorTests
context.ValidatedTokenRequest.GrantType = grantType;
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -652,13 +727,15 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
var expectedMessage =
"Legacy encryption without a userkey is no longer supported. To recover your account, please contact support";
Assert.Equal(expectedMessage, errorResponse.Message);
@@ -694,6 +771,10 @@ public class BaseRequestValidatorTests
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -760,6 +841,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -833,6 +916,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -877,6 +962,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -921,6 +1008,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -950,6 +1039,19 @@ public class BaseRequestValidatorTests
GrantValidationResult grantResult)
{
// Arrange
// SsoRequestValidator sets custom response
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
@@ -984,12 +1086,12 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - SSO required after recovery");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.NotNull(context.GrantResult.CustomResponse);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
// Recovery succeeds, then SSO blocks with descriptive message
Assert.Equal(
"Two-factor recovery has been performed. SSO authentication is required.",
SsoConstants.RequestErrors.SsoRequiredDescription,
errorResponse.Message);
// Verify recovery was marked
@@ -1050,7 +1152,7 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - invalid recovery code");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
// 2FA is checked first (due to recovery code request), fails with 2FA error
Assert.Equal(
@@ -1132,7 +1234,11 @@ public class BaseRequestValidatorTests
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 8. Setup user account keys for successful login response
// 8. SSO is not required
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 9. Setup user account keys for successful login response
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -1161,179 +1267,18 @@ public class BaseRequestValidatorTests
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is DISABLED, the legacy SSO validation path is used.
/// This validates the deprecated RequireSsoLoginAsync method is called and SSO requirement
/// is checked using the old PolicyService.AnyPoliciesApplicableToUserAsync approach.
/// Tests that when SSO validation returns a custom response, (e.g., with organization identifier),
/// that custom response is properly propagated to the result.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_UsesLegacySsoValidation(
public async Task ValidateAsync_SsoRequired_PropagatesCustomResponse(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
// SSO is required via legacy path
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
// Verify legacy path was used
await _policyService.Received(1).AnyPoliciesApplicableToUserAsync(
requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
// Verify new SsoRequestValidator was NOT called
await _ssoRequestValidator.DidNotReceive().ValidateAsync(
Arg.Any<User>(), Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>());
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED, the new ISsoRequestValidator is used
/// instead of the legacy RequireSsoLoginAsync method.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_UsesNewSsoRequestValidator(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
// Configure SsoRequestValidator to indicate SSO is required
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false)); // false = SSO required
// Set up the ValidationErrorResult that SsoRequestValidator would set
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "SSO authentication is required."
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("SSO authentication is required.") }
};
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
// Verify new SsoRequestValidator was called
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
requestContext);
// Verify legacy path was NOT used
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and SSO is NOT required,
/// authentication continues successfully through the new validation path.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_SsoNotRequired_SuccessfulLogin(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
tokenRequest.ClientId = "web";
// SsoRequestValidator returns true (SSO not required)
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
// No 2FA required
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// Device validation passes
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// User is not legacy
_userService.IsLegacyUser(Arg.Any<string>()).Returns(false);
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
"test-private-key",
"test-public-key"
)
});
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.False(context.GrantResult.IsError);
await _eventService.Received(1).LogUserEventAsync(requestContext.User.Id, EventType.User_LoggedIn);
// Verify new validator was used
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
requestContext);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and SSO validation returns a custom response
/// (e.g., with organization identifier), that custom response is properly propagated to the result.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_PropagatesCustomResponse(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
@@ -1342,13 +1287,13 @@ public class BaseRequestValidatorTests
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "SSO authentication is required."
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("SSO authentication is required.") },
{ "SsoOrganizationIdentifier", "test-org-identifier" }
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
{ CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" }
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
@@ -1365,77 +1310,24 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError);
Assert.NotNull(context.GrantResult.CustomResponse);
Assert.Contains("SsoOrganizationIdentifier", context.CustomValidatorRequestContext.CustomResponse);
Assert.Contains(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, context.CustomValidatorRequestContext.CustomResponse);
Assert.Equal("test-org-identifier",
context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]);
context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier]);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is DISABLED and a user with 2FA recovery completes recovery,
/// but SSO is required, the legacy error message is returned (without the recovery-specific message).
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_RecoveryWithSso_LegacyMessage(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
// Recovery code scenario
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code";
// 2FA with recovery
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code")
.Returns(Task.FromResult(true));
// SSO is required (legacy check)
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
// Legacy behavior: recovery-specific message IS shown even without RedirectOnSsoRequired
Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message);
// But legacy validation path was used
await _policyService.Received(1).AnyPoliciesApplicableToUserAsync(
requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and recovery code is used for SSO-required user,
/// Tests that when a recovery code is used for SSO-required user,
/// the SsoRequestValidator provides the recovery-specific error message.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_RecoveryWithSso_NewValidatorMessage(
public async Task ValidateAsync_RecoveryWithSso_CorrectValidatorMessage(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -1457,14 +1349,14 @@ public class BaseRequestValidatorTests
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "Two-factor recovery has been performed. SSO authentication is required."
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{
"ErrorModel",
new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.")
CustomResponseConstants.ResponseKeys.ErrorModel,
new ErrorResponseModel(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription)
}
};
@@ -1479,18 +1371,8 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse["ErrorModel"];
Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message);
// Verify new validator was used
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
Arg.Is<CustomValidatorRequestContext>(ctx => ctx.TwoFactorRecoveryRequested));
// Verify legacy path was NOT used
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription, errorResponse.Message);
}
private BaseRequestValidationContextFake CreateContext(

View File

@@ -111,15 +111,6 @@ IBaseRequestValidatorTestWrapper
context.GrantResult = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse);
}
[Obsolete]
protected override void SetSsoResult(
BaseRequestValidationContextFake context,
Dictionary<string, object> customResponse)
{
context.GrantResult = new GrantValidationResult(
TokenRequestErrors.InvalidGrant, "Sso authentication required.", customResponse);
}
protected override Task SetSuccessResult(
BaseRequestValidationContextFake context,
User user,