1
0
mirror of https://github.com/bitwarden/server synced 2026-02-21 03:43:44 +00:00

[PM-31040] Replace ISetupIntentCache with customer-based approach (#6954)

* docs(billing): add design document for replacing SetupIntent cache

* docs(billing): add implementation plan for replacing SetupIntent cache

* feat(db): add gateway lookup stored procedures for Organization, Provider, and User

* feat(db): add gateway lookup indexes to Organization, Provider, and User table definitions

* chore(db): add SQL Server migration for gateway lookup indexes and stored procedures

* feat(repos): add gateway lookup methods to IOrganizationRepository and Dapper implementation

* feat(repos): add gateway lookup methods to IProviderRepository and Dapper implementation

* feat(repos): add gateway lookup methods to IUserRepository and Dapper implementation

* feat(repos): add EF OrganizationRepository gateway lookup methods and index configuration

* feat(repos): add EF ProviderRepository gateway lookup methods and index configuration

* feat(repos): add EF UserRepository gateway lookup methods and index configuration

* chore(db): add EF migrations for gateway lookup indexes

* refactor(billing): update SetupIntentSucceededHandler to use repository instead of cache

* refactor(billing): simplify StripeEventService by expanding customer on SetupIntent

* refactor(billing): query Stripe for SetupIntents by customer ID in GetPaymentMethodQuery

* refactor(billing): query Stripe for SetupIntents by customer ID in HasPaymentMethodQuery

* refactor(billing): update OrganizationBillingService to set customer on SetupIntent

* refactor(billing): update ProviderBillingService to set customer on SetupIntent and query by customer

* refactor(billing): update UpdatePaymentMethodCommand to set customer on SetupIntent

* refactor(billing): remove bank account support from CreatePremiumCloudHostedSubscriptionCommand

* refactor(billing): remove OrganizationBillingService.UpdatePaymentMethod dead code

* refactor(billing): remove ProviderBillingService.UpdatePaymentMethod

* refactor(billing): remove PremiumUserBillingService.UpdatePaymentMethod and UserService.ReplacePaymentMethodAsync

* refactor(billing): remove SubscriberService.UpdatePaymentSource and related dead code

* refactor(billing): update SubscriberService.GetPaymentSourceAsync to query Stripe by customer ID

Add Task 15a to plan - this was a missed requirement for updating
GetPaymentSourceAsync which still used the cache.

* refactor(billing): complete removal of PremiumUserBillingService.Finalize and UserService.SignUpPremiumAsync

* refactor(billing): remove ISetupIntentCache and SetupIntentDistributedCache

* chore: remove temporary planning documents

* chore: run dotnet format

* fix(billing): add MaxLength(50) to Provider gateway ID properties

* chore(db): add EF migrations for Provider gateway column lengths

* chore: run dotnet format

* chore: rename SQL migration for chronological order
This commit is contained in:
Alex Morask
2026-02-18 13:20:25 -06:00
committed by GitHub
parent 2ce98277b4
commit cfd5bedae0
69 changed files with 22548 additions and 1892 deletions

View File

@@ -1,7 +1,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Services;
using Bit.Core.Repositories;
using OneOf;
@@ -11,10 +10,10 @@ using Event = Stripe.Event;
namespace Bit.Billing.Services.Implementations;
public class SetupIntentSucceededHandler(
ILogger<SetupIntentSucceededHandler> logger,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IPushNotificationAdapter pushNotificationAdapter,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
IStripeEventService stripeEventService) : ISetupIntentSucceededHandler
{
@@ -27,23 +26,29 @@ public class SetupIntentSucceededHandler(
if (setupIntent is not
{
CustomerId: not null,
PaymentMethod.UsBankAccount: not null
})
{
logger.LogWarning("SetupIntent {SetupIntentId} has no customer ID or is not a US bank account", setupIntent.Id);
return;
}
var subscriberId = await setupIntentCache.GetSubscriberIdForSetupIntent(setupIntent.Id);
if (subscriberId == null)
var organization = await organizationRepository.GetByGatewayCustomerIdAsync(setupIntent.CustomerId);
if (organization != null)
{
await SetPaymentMethodAsync(organization, setupIntent.PaymentMethod);
return;
}
var organization = await organizationRepository.GetByIdAsync(subscriberId.Value);
var provider = await providerRepository.GetByIdAsync(subscriberId.Value);
var provider = await providerRepository.GetByGatewayCustomerIdAsync(setupIntent.CustomerId);
if (provider != null)
{
await SetPaymentMethodAsync(provider, setupIntent.PaymentMethod);
return;
}
OneOf<Organization, Provider> entity = organization != null ? organization : provider!;
await SetPaymentMethodAsync(entity, setupIntent.PaymentMethod);
logger.LogError("No organization or provider found for customer {CustomerId}", setupIntent.CustomerId);
}
private async Task SetPaymentMethodAsync(

View File

@@ -1,7 +1,4 @@
using Bit.Billing.Constants;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Repositories;
using Bit.Core.Settings;
using Stripe;
@@ -9,10 +6,6 @@ namespace Bit.Billing.Services.Implementations;
public class StripeEventService(
GlobalSettings globalSettings,
ILogger<StripeEventService> logger,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
ISetupIntentCache setupIntentCache,
IStripeFacade stripeFacade)
: IStripeEventService
{
@@ -117,7 +110,7 @@ public class StripeEventService(
(await GetCustomer(stripeEvent, true)).Metadata,
HandledStripeWebhook.SetupIntentSucceeded =>
await GetCustomerMetadataFromSetupIntentSucceededEvent(stripeEvent),
(await GetSetupIntent(stripeEvent, true, customerExpansion)).Customer?.Metadata,
_ => null
};
@@ -144,43 +137,6 @@ public class StripeEventService(
return customer?.Metadata;
}
async Task<Dictionary<string, string>?> GetCustomerMetadataFromSetupIntentSucceededEvent(Event localStripeEvent)
{
var setupIntent = await GetSetupIntent(localStripeEvent);
logger.LogInformation("Extracted Setup Intent ({SetupIntentId}) from Stripe 'setup_intent.succeeded' event", setupIntent.Id);
var subscriberId = await setupIntentCache.GetSubscriberIdForSetupIntent(setupIntent.Id);
logger.LogInformation("Retrieved subscriber ID ({SubscriberId}) from cache for Setup Intent ({SetupIntentId})", subscriberId, setupIntent.Id);
if (subscriberId == null)
{
logger.LogError("Cached subscriber ID for Setup Intent ({SetupIntentId}) is null", setupIntent.Id);
return null;
}
var organization = await organizationRepository.GetByIdAsync(subscriberId.Value);
logger.LogInformation("Retrieved organization ({OrganizationId}) via subscriber ID for Setup Intent ({SetupIntentId})", organization?.Id, setupIntent.Id);
if (organization is { GatewayCustomerId: not null })
{
var organizationCustomer = await stripeFacade.GetCustomer(organization.GatewayCustomerId);
logger.LogInformation("Retrieved customer ({CustomerId}) via organization ID for Setup Intent ({SetupIntentId})", organization.Id, setupIntent.Id);
return organizationCustomer.Metadata;
}
var provider = await providerRepository.GetByIdAsync(subscriberId.Value);
logger.LogInformation("Retrieved provider ({ProviderId}) via subscriber ID for Setup Intent ({SetupIntentId})", provider?.Id, setupIntent.Id);
if (provider is not { GatewayCustomerId: not null })
{
return null;
}
var providerCustomer = await stripeFacade.GetCustomer(provider.GatewayCustomerId);
logger.LogInformation("Retrieved customer ({CustomerId}) via provider ID for Setup Intent ({SetupIntentId})", provider.Id, setupIntent.Id);
return providerCustomer.Metadata;
}
}
private static T Extract<T>(Event stripeEvent)

View File

@@ -1,4 +1,5 @@
using System.Net;
using System.ComponentModel.DataAnnotations;
using System.Net;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -33,7 +34,9 @@ public class Provider : ITableObject<Guid>, ISubscriber
public DateTime CreationDate { get; internal set; } = DateTime.UtcNow;
public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow;
public GatewayType? Gateway { get; set; }
[MaxLength(50)]
public string? GatewayCustomerId { get; set; }
[MaxLength(50)]
public string? GatewaySubscriptionId { get; set; }
public string? DiscountId { get; set; }

View File

@@ -9,6 +9,8 @@ namespace Bit.Core.Repositories;
public interface IOrganizationRepository : IRepository<Organization, Guid>
{
Task<Organization?> GetByGatewayCustomerIdAsync(string gatewayCustomerId);
Task<Organization?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId);
Task<Organization?> GetByIdentifierAsync(string identifier);
Task<ICollection<Organization>> GetManyByEnabledAsync();
Task<ICollection<Organization>> GetManyByUserIdAsync(Guid userId);

View File

@@ -8,6 +8,8 @@ namespace Bit.Core.AdminConsole.Repositories;
public interface IProviderRepository : IRepository<Provider, Guid>
{
Task<Provider?> GetByGatewayCustomerIdAsync(string gatewayCustomerId);
Task<Provider?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId);
Task<Provider?> GetByOrganizationIdAsync(Guid organizationId);
Task<ICollection<Provider>> SearchAsync(string name, string userEmail, int skip, int take);
Task<ICollection<ProviderAbility>> GetManyAbilitiesAsync();

View File

@@ -1,9 +0,0 @@
namespace Bit.Core.Billing.Caches;
public interface ISetupIntentCache
{
Task<string?> GetSetupIntentIdForSubscriber(Guid subscriberId);
Task<Guid?> GetSubscriberIdForSetupIntent(string setupIntentId);
Task RemoveSetupIntentForSubscriber(Guid subscriberId);
Task Set(Guid subscriberId, string setupIntentId);
}

View File

@@ -1,50 +0,0 @@
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Billing.Caches.Implementations;
public class SetupIntentDistributedCache(
[FromKeyedServices("persistent")]
IDistributedCache distributedCache,
ILogger<SetupIntentDistributedCache> logger) : ISetupIntentCache
{
public async Task<string?> GetSetupIntentIdForSubscriber(Guid subscriberId)
{
var cacheKey = GetCacheKeyBySubscriberId(subscriberId);
return await distributedCache.GetStringAsync(cacheKey);
}
public async Task<Guid?> GetSubscriberIdForSetupIntent(string setupIntentId)
{
var cacheKey = GetCacheKeyBySetupIntentId(setupIntentId);
var value = await distributedCache.GetStringAsync(cacheKey);
if (!string.IsNullOrEmpty(value) && Guid.TryParse(value, out var subscriberId))
{
return subscriberId;
}
logger.LogError("Subscriber ID value ({Value}) cached for Setup Intent ({SetupIntentId}) is null or not a valid Guid", value, setupIntentId);
return null;
}
public async Task RemoveSetupIntentForSubscriber(Guid subscriberId)
{
var cacheKey = GetCacheKeyBySubscriberId(subscriberId);
await distributedCache.RemoveAsync(cacheKey);
}
public async Task Set(Guid subscriberId, string setupIntentId)
{
var bySubscriberIdCacheKey = GetCacheKeyBySubscriberId(subscriberId);
var bySetupIntentIdCacheKey = GetCacheKeyBySetupIntentId(setupIntentId);
await Task.WhenAll(
distributedCache.SetStringAsync(bySubscriberIdCacheKey, setupIntentId),
distributedCache.SetStringAsync(bySetupIntentIdCacheKey, subscriberId.ToString()));
}
private static string GetCacheKeyBySetupIntentId(string setupIntentId) =>
$"subscriber_id_for_setup_intent_id_{setupIntentId}";
private static string GetCacheKeyBySubscriberId(Guid subscriberId) =>
$"setup_intent_id_for_subscriber_id_{subscriberId}";
}

View File

@@ -1,6 +1,4 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches.Implementations;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Licenses.Extensions;
using Bit.Core.Billing.Organizations.Commands;
using Bit.Core.Billing.Organizations.Queries;
@@ -29,7 +27,6 @@ public static class ServiceCollectionExtensions
services.AddSingleton<ITaxService, TaxService>();
services.AddTransient<IOrganizationBillingService, OrganizationBillingService>();
services.AddTransient<IPremiumUserBillingService, PremiumUserBillingService>();
services.AddTransient<ISetupIntentCache, SetupIntentDistributedCache>();
services.AddTransient<ISubscriberService, SubscriberService>();
services.AddLicenseServices();
services.AddLicenseOperations();

View File

@@ -1,8 +1,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Tax.Models;
namespace Bit.Core.Billing.Organizations.Services;
@@ -26,19 +24,6 @@ public interface IOrganizationBillingService
/// </example>
Task Finalize(OrganizationSale sale);
/// <summary>
/// Updates the provided <paramref name="organization"/>'s payment source and tax information.
/// If the <paramref name="organization"/> does not have a Stripe <see cref="Stripe.Customer"/>, this method will create one using the provided
/// <paramref name="tokenizedPaymentSource"/> and <paramref name="taxInformation"/>.
/// </summary>
/// <param name="organization">The <paramref name="organization"/> to update the payment source and tax information for.</param>
/// <param name="tokenizedPaymentSource">The tokenized payment source (ex. Credit Card) to attach to the <paramref name="organization"/>.</param>
/// <param name="taxInformation">The <paramref name="organization"/>'s updated tax information.</param>
Task UpdatePaymentMethod(
Organization organization,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation);
/// <summary>
/// Updates the subscription with new plan frequencies and changes the collection method to charge_automatically if a valid payment method exists.
/// Validates that the customer has a payment method attached before switching to automatic charging.

View File

@@ -1,15 +1,12 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@@ -31,7 +28,6 @@ public class OrganizationBillingService(
ILogger<OrganizationBillingService> logger,
IOrganizationRepository organizationRepository,
IPricingClient pricingClient,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
ITaxService taxService) : IOrganizationBillingService
@@ -54,33 +50,6 @@ public class OrganizationBillingService(
}
}
public async Task UpdatePaymentMethod(
Organization organization,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation)
{
if (string.IsNullOrEmpty(organization.GatewayCustomerId))
{
var customer = await CreateCustomerAsync(organization,
new CustomerSetup
{
TokenizedPaymentSource = tokenizedPaymentSource,
TaxInformation = taxInformation
});
organization.Gateway = GatewayType.Stripe;
organization.GatewayCustomerId = customer.Id;
await organizationRepository.ReplaceAsync(organization);
}
else
{
await subscriberService.UpdatePaymentSource(organization, tokenizedPaymentSource);
await subscriberService.UpdateTaxInformation(organization, taxInformation);
await UpdateMissingPaymentMethodBehaviourAsync(organization);
}
}
public async Task UpdateSubscriptionPlanFrequency(
Organization organization, PlanType newPlanType)
{
@@ -203,6 +172,7 @@ public class OrganizationBillingService(
};
var braintreeCustomerId = "";
var setupIntentId = "";
if (customerSetup.IsBillable)
{
@@ -296,7 +266,7 @@ public class OrganizationBillingService(
throw new BillingException();
}
await setupIntentCache.Set(organization.Id, setupIntent.Id);
setupIntentId = setupIntent.Id;
break;
}
case PaymentMethodType.Card:
@@ -323,6 +293,12 @@ public class OrganizationBillingService(
{
var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions);
if (!string.IsNullOrEmpty(setupIntentId))
{
await stripeAdapter.UpdateSetupIntentAsync(setupIntentId,
new SetupIntentUpdateOptions { Customer = customer.Id });
}
organization.Gateway = GatewayType.Stripe;
organization.GatewayCustomerId = customer.Id;
await organizationRepository.ReplaceAsync(organization);
@@ -356,11 +332,6 @@ public class OrganizationBillingService(
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (customerSetup.TokenizedPaymentSource!.Type)
{
case PaymentMethodType.BankAccount:
{
await setupIntentCache.RemoveSetupIntentForSubscriber(organization.Id);
break;
}
case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId):
{
await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId);
@@ -519,24 +490,6 @@ public class OrganizationBillingService(
return customer;
}
private async Task UpdateMissingPaymentMethodBehaviourAsync(Organization organization)
{
var subscription = await subscriberService.GetSubscriptionOrThrow(organization);
if (subscription.TrialSettings?.EndBehavior?.MissingPaymentMethod == StripeConstants.MissingPaymentMethodBehaviorOptions.Cancel)
{
var options = new SubscriptionUpdateOptions
{
TrialSettings = new SubscriptionTrialSettingsOptions
{
EndBehavior = new SubscriptionTrialSettingsEndBehaviorOptions
{
MissingPaymentMethod = StripeConstants.MissingPaymentMethodBehaviorOptions.CreateInvoice
}
}
};
await stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, options);
}
}
#endregion
}

View File

@@ -1,5 +1,4 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
@@ -28,7 +27,6 @@ public class UpdatePaymentMethodCommand(
IBraintreeService braintreeService,
IGlobalSettings globalSettings,
ILogger<UpdatePaymentMethodCommand> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : BaseBillingCommand<UpdatePaymentMethodCommand>(logger), IUpdatePaymentMethodCommand
{
@@ -95,9 +93,10 @@ public class UpdatePaymentMethodCommand(
var setupIntent = setupIntents.First();
await setupIntentCache.Set(subscriber.Id, setupIntent.Id);
await stripeAdapter.UpdateSetupIntentAsync(setupIntent.Id,
new SetupIntentUpdateOptions { Customer = customer.Id });
_logger.LogInformation("{Command}: Successfully cached Setup Intent ({SetupIntentId}) for subscriber ({SubscriberID})", CommandName, setupIntent.Id, subscriber.Id);
_logger.LogInformation("{Command}: Successfully linked Setup Intent ({SetupIntentId}) to customer ({CustomerId}) for subscriber ({SubscriberID})", CommandName, setupIntent.Id, customer.Id, subscriber.Id);
await UnlinkBraintreeCustomerAsync(customer);

View File

@@ -1,5 +1,4 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
@@ -16,7 +15,6 @@ public interface IGetPaymentMethodQuery
public class GetPaymentMethodQuery(
IBraintreeService braintreeService,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : IGetPaymentMethodQuery
{
@@ -39,19 +37,17 @@ public class GetPaymentMethodQuery(
}
// Then check for a bank account pending verification
var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(subscriber.Id);
if (!string.IsNullOrEmpty(setupIntentId))
var setupIntents = await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions
{
var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions
{
Expand = ["payment_method"]
});
Customer = customer.Id,
Expand = ["data.payment_method"]
});
if (setupIntent.IsUnverifiedBankAccount())
{
return MaskedPaymentMethod.From(setupIntent);
}
var unverifiedBankAccount = setupIntents?.FirstOrDefault(si => si.IsUnverifiedBankAccount());
if (unverifiedBankAccount != null)
{
return MaskedPaymentMethod.From(unverifiedBankAccount);
}
// Then check the default payment method

View File

@@ -1,5 +1,4 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
@@ -15,21 +14,20 @@ public interface IHasPaymentMethodQuery
}
public class HasPaymentMethodQuery(
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : IHasPaymentMethodQuery
{
public async Task<bool> Run(ISubscriber subscriber)
{
var hasUnverifiedBankAccount = await HasUnverifiedBankAccountAsync(subscriber);
var customer = await subscriberService.GetCustomer(subscriber);
if (customer == null)
{
return hasUnverifiedBankAccount;
return false;
}
var hasUnverifiedBankAccount = await HasUnverifiedBankAccountAsync(customer.Id);
return
!string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) ||
!string.IsNullOrEmpty(customer.DefaultSourceId) ||
@@ -37,21 +35,14 @@ public class HasPaymentMethodQuery(
customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId);
}
private async Task<bool> HasUnverifiedBankAccountAsync(
ISubscriber subscriber)
private async Task<bool> HasUnverifiedBankAccountAsync(string customerId)
{
var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(subscriber.Id);
if (string.IsNullOrEmpty(setupIntentId))
var setupIntents = await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions
{
return false;
}
var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions
{
Expand = ["payment_method"]
Customer = customerId,
Expand = ["data.payment_method"]
});
return setupIntent.IsUnverifiedBankAccount();
return setupIntents?.Any(si => si.IsUnverifiedBankAccount()) ?? false;
}
}

View File

@@ -1,5 +1,4 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Commands;
@@ -52,7 +51,6 @@ public class CreatePremiumCloudHostedSubscriptionCommand(
IBraintreeGateway braintreeGateway,
IBraintreeService braintreeService,
IGlobalSettings globalSettings,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserService userService,
@@ -218,21 +216,6 @@ public class CreatePremiumCloudHostedSubscriptionCommand(
var tokenizedPaymentMethod = paymentMethod.AsTokenized;
switch (tokenizedPaymentMethod.Type)
{
case TokenizablePaymentMethodType.BankAccount:
{
var setupIntent =
(await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token }))
.FirstOrDefault();
if (setupIntent == null)
{
_logger.LogError("Cannot create customer for user ({UserID}) without a setup intent for their bank account", user.Id);
throw new BillingException();
}
await setupIntentCache.Set(user.Id, setupIntent.Id);
break;
}
case TokenizablePaymentMethodType.Card:
{
customerCreateOptions.PaymentMethod = tokenizedPaymentMethod.Token;
@@ -267,11 +250,6 @@ public class CreatePremiumCloudHostedSubscriptionCommand(
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (tokenizedPaymentMethod.Type)
{
case TokenizablePaymentMethodType.BankAccount:
{
await setupIntentCache.RemoveSetupIntentForSubscriber(user.Id);
break;
}
case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId):
{
await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId);

View File

@@ -4,11 +4,9 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models;
using Bit.Core.Billing.Tax.Models;
using Stripe;
namespace Bit.Core.Billing.Providers.Services;
@@ -101,17 +99,6 @@ public interface IProviderBillingService
Task<Subscription> SetupSubscription(
Provider provider);
/// <summary>
/// Updates the <paramref name="provider"/>'s payment source and tax information and then sets their subscription's collection_method to be "charge_automatically".
/// </summary>
/// <param name="provider">The <paramref name="provider"/> to update the payment source and tax information for.</param>
/// <param name="tokenizedPaymentSource">The tokenized payment source (ex. Credit Card) to attach to the <paramref name="provider"/>.</param>
/// <param name="taxInformation">The <paramref name="provider"/>'s updated tax information.</param>
Task UpdatePaymentMethod(
Provider provider,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation);
Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command);
/// <summary>

View File

@@ -1,39 +1,8 @@
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Entities;
namespace Bit.Core.Billing.Services;
public interface IPremiumUserBillingService
{
Task Credit(User user, decimal amount);
/// <summary>
/// <para>Establishes the Stripe entities necessary for a Bitwarden <see cref="User"/> using the provided <paramref name="sale"/>.</para>
/// <para>
/// The method first checks to see if the
/// provided <see cref="PremiumUserSale.User"/> already has a Stripe <see cref="Stripe.Customer"/> using the <see cref="User.GatewayCustomerId"/>.
/// If it doesn't, the method creates one using the <paramref name="sale"/>'s <see cref="PremiumUserSale.CustomerSetup"/>. The method then creates a Stripe <see cref="Stripe.Subscription"/>
/// for the created or existing customer while appending the provided <paramref name="sale"/>'s <see cref="PremiumUserSale.Storage"/>.
/// </para>
/// </summary>
/// <param name="sale">The data required to establish the Stripe entities responsible for billing the premium user.</param>
/// <example>
/// <code>
/// var sale = PremiumUserSale.From(
/// user,
/// paymentMethodType,
/// paymentMethodToken,
/// taxInfo,
/// storage);
/// await premiumUserBillingService.Finalize(sale);
/// </code>
/// </example>
Task Finalize(PremiumUserSale sale);
Task UpdatePaymentMethod(
User user,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation);
}

View File

@@ -47,6 +47,7 @@ public interface IStripeAdapter
Task<List<SetupIntent>> ListSetupIntentsAsync(SetupIntentListOptions options);
Task CancelSetupIntentAsync(string id, SetupIntentCancelOptions options = null);
Task<SetupIntent> GetSetupIntentAsync(string id, SetupIntentGetOptions options = null);
Task<SetupIntent> UpdateSetupIntentAsync(string id, SetupIntentUpdateOptions options = null);
Task<Price> GetPriceAsync(string id, PriceGetOptions options = null);
Task<Coupon> GetCouponAsync(string couponId, CouponGetOptions options = null);
Task<List<Product>> ListProductsAsync(ProductListOptions options = null);

View File

@@ -4,7 +4,6 @@
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Stripe;
namespace Bit.Core.Billing.Services;
@@ -105,18 +104,6 @@ public interface ISubscriberService
/// <param name="subscriber">The subscriber to remove the saved payment source for.</param>
Task RemovePaymentSource(ISubscriber subscriber);
/// <summary>
/// Updates the payment source for the provided <paramref name="subscriber"/> using the <paramref name="tokenizedPaymentSource"/>.
/// The following types are supported: [<see cref="PaymentMethodType.Card"/>, <see cref="PaymentMethodType.BankAccount"/>, <see cref="PaymentMethodType.PayPal"/>].
/// For each type, updating the payment source will attempt to establish a new payment source using the token in the <see cref="TokenizedPaymentSource"/>. Then, it will
/// remove the exising payment source(s) linked to the subscriber's customer.
/// </summary>
/// <param name="subscriber">The subscriber to update the payment method for.</param>
/// <param name="tokenizedPaymentSource">A DTO representing a tokenized payment method.</param>
Task UpdatePaymentSource(
ISubscriber subscriber,
TokenizedPaymentSource tokenizedPaymentSource);
/// <summary>
/// Updates the tax information for the provided <paramref name="subscriber"/>.
/// </summary>

View File

@@ -1,37 +1,16 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Settings;
using Braintree;
using Microsoft.Extensions.Logging;
using Stripe;
using Customer = Stripe.Customer;
using Subscription = Stripe.Subscription;
namespace Bit.Core.Billing.Services.Implementations;
using static Utilities;
public class PremiumUserBillingService(
IBraintreeGateway braintreeGateway,
IGlobalSettings globalSettings,
ILogger<PremiumUserBillingService> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserRepository userRepository,
IPricingClient pricingClient) : IPremiumUserBillingService
IUserRepository userRepository) : IPremiumUserBillingService
{
public async Task Credit(User user, decimal amount)
{
@@ -83,309 +62,4 @@ public class PremiumUserBillingService(
await stripeAdapter.UpdateCustomerAsync(customer.Id, options);
}
}
public async Task Finalize(PremiumUserSale sale)
{
var (user, customerSetup, storage) = sale;
List<string> expand = ["tax"];
var customer = string.IsNullOrEmpty(user.GatewayCustomerId)
? await CreateCustomerAsync(user, customerSetup)
: await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = expand });
/*
* If the customer was previously set up with credit, which does not require a billing location,
* we need to update the customer on the fly before we start the subscription.
*/
customer = await ReconcileBillingLocationAsync(customer, customerSetup.TaxInformation);
var premiumPlan = await pricingClient.GetAvailablePremiumPlan();
var subscription = await CreateSubscriptionAsync(user.Id, customer, premiumPlan, storage);
switch (customerSetup.TokenizedPaymentSource)
{
case { Type: PaymentMethodType.PayPal }
when subscription.Status == StripeConstants.SubscriptionStatus.Incomplete:
case { Type: not PaymentMethodType.PayPal }
when subscription.Status == StripeConstants.SubscriptionStatus.Active:
{
user.Premium = true;
user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd();
break;
}
}
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
user.GatewaySubscriptionId = subscription.Id;
user.MaxStorageGb = (short)(premiumPlan.Storage.Provided + (storage ?? 0));
await userRepository.ReplaceAsync(user);
}
public async Task UpdatePaymentMethod(
User user,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation)
{
if (string.IsNullOrEmpty(user.GatewayCustomerId))
{
var customer = await CreateCustomerAsync(user,
new CustomerSetup { TokenizedPaymentSource = tokenizedPaymentSource, TaxInformation = taxInformation });
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
await userRepository.ReplaceAsync(user);
}
else
{
await subscriberService.UpdatePaymentSource(user, tokenizedPaymentSource);
await subscriberService.UpdateTaxInformation(user, taxInformation);
}
}
private async Task<Customer> CreateCustomerAsync(
User user,
CustomerSetup customerSetup)
{
/*
* Creating a Customer via the adding of a payment method or the purchasing of a subscription requires
* an actual payment source. The only time this is not the case is when the Customer is created when the
* User purchases credit.
*/
if (customerSetup.TokenizedPaymentSource is not
{
Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal,
Token: not null and not ""
})
{
logger.LogError(
"Cannot create customer for user ({UserID}) without a valid payment source", user.Id);
throw new BillingException();
}
if (customerSetup.TaxInformation is not { Country: not null and not "", PostalCode: not null and not "" })
{
logger.LogError(
"Cannot create customer for user ({UserID}) without valid tax information", user.Id);
throw new BillingException();
}
var subscriberName = user.SubscriberName();
var customerCreateOptions = new CustomerCreateOptions
{
Address = new AddressOptions
{
Line1 = customerSetup.TaxInformation.Line1,
Line2 = customerSetup.TaxInformation.Line2,
City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country
},
Description = user.Name,
Email = user.Email,
Expand = ["tax"],
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = user.SubscriberType(),
Value = subscriberName.Length <= 30
? subscriberName
: subscriberName[..30]
}
]
},
Metadata = new Dictionary<string, string>
{
["region"] = globalSettings.BaseServiceUri.CloudRegion,
["userId"] = user.Id.ToString()
},
Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}
};
var (paymentMethodType, paymentMethodToken) = customerSetup.TokenizedPaymentSource;
var braintreeCustomerId = "";
// ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault
switch (paymentMethodType)
{
case PaymentMethodType.BankAccount:
{
var setupIntent =
(await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethodToken }))
.FirstOrDefault();
if (setupIntent == null)
{
logger.LogError("Cannot create customer for user ({UserID}) without a setup intent for their bank account", user.Id);
throw new BillingException();
}
await setupIntentCache.Set(user.Id, setupIntent.Id);
break;
}
case PaymentMethodType.Card:
{
customerCreateOptions.PaymentMethod = paymentMethodToken;
customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = paymentMethodToken;
break;
}
case PaymentMethodType.PayPal:
{
braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, paymentMethodToken);
customerCreateOptions.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId;
break;
}
default:
{
logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, paymentMethodType.ToString());
throw new BillingException();
}
}
try
{
return await stripeAdapter.CreateCustomerAsync(customerCreateOptions);
}
catch (StripeException stripeException) when (stripeException.StripeError?.Code ==
StripeConstants.ErrorCodes.CustomerTaxLocationInvalid)
{
await Revert();
throw new BadRequestException(
"Your location wasn't recognized. Please ensure your country and postal code are valid.");
}
catch (StripeException stripeException) when (stripeException.StripeError?.Code ==
StripeConstants.ErrorCodes.TaxIdInvalid)
{
await Revert();
throw new BadRequestException(
"Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid.");
}
catch
{
await Revert();
throw;
}
async Task Revert()
{
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (customerSetup.TokenizedPaymentSource!.Type)
{
case PaymentMethodType.BankAccount:
{
await setupIntentCache.RemoveSetupIntentForSubscriber(user.Id);
break;
}
case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId):
{
await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId);
break;
}
}
}
}
private async Task<Subscription> CreateSubscriptionAsync(
Guid userId,
Customer customer,
Pricing.Premium.Plan premiumPlan,
int? storage)
{
var subscriptionItemOptionsList = new List<SubscriptionItemOptions>
{
new ()
{
Price = premiumPlan.Seat.StripePriceId,
Quantity = 1
}
};
if (storage is > 0)
{
subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{
Price = premiumPlan.Storage.StripePriceId,
Quantity = storage
});
}
var usingPayPal = customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) ?? false;
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id,
Items = subscriptionItemOptionsList,
Metadata = new Dictionary<string, string>
{
["userId"] = userId.ToString()
},
PaymentBehavior = usingPayPal
? StripeConstants.PaymentBehavior.DefaultIncomplete
: null,
OffSession = true
};
var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions);
if (usingPayPal)
{
await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions
{
AutoAdvance = false
});
}
return subscription;
}
private async Task<Customer> ReconcileBillingLocationAsync(
Customer customer,
TaxInformation taxInformation)
{
if (customer is { Address: { Country: not null and not "", PostalCode: not null and not "" } })
{
return customer;
}
var options = new CustomerUpdateOptions
{
Address = new AddressOptions
{
Line1 = taxInformation.Line1,
Line2 = taxInformation.Line2,
City = taxInformation.City,
PostalCode = taxInformation.PostalCode,
State = taxInformation.State,
Country = taxInformation.Country
},
Expand = ["tax"],
Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}
};
return await stripeAdapter.UpdateCustomerAsync(customer.Id, options);
}
}

View File

@@ -199,6 +199,9 @@ public class StripeAdapter : IStripeAdapter
public Task<SetupIntent> GetSetupIntentAsync(string id, SetupIntentGetOptions options = null) =>
_setupIntentService.GetAsync(id, options);
public Task<SetupIntent> UpdateSetupIntentAsync(string id, SetupIntentUpdateOptions options = null) =>
_setupIntentService.UpdateAsync(id, options);
/*******************
** MISCELLANEOUS **
*******************/

View File

@@ -4,7 +4,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
@@ -35,7 +34,6 @@ public class SubscriberService(
ILogger<SubscriberService> logger,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ITaxService taxService,
IUserRepository userRepository) : ISubscriberService
@@ -338,7 +336,7 @@ public class SubscriberService(
Expand = ["default_source", "invoice_settings.default_payment_method"]
});
return await GetPaymentSourceAsync(subscriber.Id, customer);
return await GetPaymentSourceAsync(customer);
}
public async Task<Subscription> GetSubscription(
@@ -507,130 +505,6 @@ public class SubscriberService(
}
}
public async Task UpdatePaymentSource(
ISubscriber subscriber,
TokenizedPaymentSource tokenizedPaymentSource)
{
ArgumentNullException.ThrowIfNull(subscriber);
ArgumentNullException.ThrowIfNull(tokenizedPaymentSource);
var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] };
var customer = await GetCustomerOrThrow(subscriber, customerGetOptions);
var (type, token) = tokenizedPaymentSource;
if (string.IsNullOrEmpty(token))
{
logger.LogError("Updated payment method for ({SubscriberID}) must contain a token", subscriber.Id);
throw new BillingException();
}
// ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault
switch (type)
{
case PaymentMethodType.BankAccount:
{
var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions
{
PaymentMethod = token
});
// Find the setup intent for the incoming payment method token.
var setupIntentsForUpdatedPaymentMethod = await getSetupIntentsForUpdatedPaymentMethod;
if (setupIntentsForUpdatedPaymentMethod.Count != 1)
{
logger.LogError("There were more than 1 setup intents for subscriber's ({SubscriberID}) updated payment method", subscriber.Id);
throw new BillingException();
}
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
// Store the incoming payment method's setup intent ID in the cache for the subscriber so it can be verified later.
await setupIntentCache.Set(subscriber.Id, matchingSetupIntent.Id);
// Remove the customer's other attached Stripe payment methods.
var postProcessing = new List<Task>
{
RemoveStripePaymentMethodsAsync(customer),
RemoveBraintreeCustomerIdAsync(customer)
};
await Task.WhenAll(postProcessing);
break;
}
case PaymentMethodType.Card:
{
// Remove the customer's other attached Stripe payment methods.
await RemoveStripePaymentMethodsAsync(customer);
// Attach the incoming payment method.
await stripeAdapter.AttachPaymentMethodAsync(token,
new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
var metadata = customer.Metadata;
if (metadata.TryGetValue(BraintreeCustomerIdKey, out var value))
{
metadata[BraintreeCustomerIdOldKey] = value;
metadata[BraintreeCustomerIdKey] = null;
}
// Set the customer's default payment method in Stripe and remove their Braintree customer ID.
await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions
{
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = token
},
Metadata = metadata
});
break;
}
case PaymentMethodType.PayPal:
{
string braintreeCustomerId;
if (customer.Metadata != null)
{
var hasBraintreeCustomerId = customer.Metadata.TryGetValue(BraintreeCustomerIdKey, out braintreeCustomerId);
if (hasBraintreeCustomerId)
{
var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);
if (braintreeCustomer == null)
{
logger.LogError("Failed to retrieve Braintree customer ({BraintreeCustomerId}) when updating payment method for subscriber ({SubscriberID})", braintreeCustomerId, subscriber.Id);
throw new BillingException();
}
await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token);
return;
}
}
braintreeCustomerId = await CreateBraintreeCustomer(subscriber, token);
await AddBraintreeCustomerIdAsync(customer, braintreeCustomerId);
break;
}
default:
{
logger.LogError("Cannot update subscriber's ({SubscriberID}) payment method to type ({PaymentMethodType}) as it is not supported", subscriber.Id, type.ToString());
throw new BillingException();
}
}
}
public async Task UpdateTaxInformation(
ISubscriber subscriber,
TaxInformation taxInformation)
@@ -819,23 +693,7 @@ public class SubscriberService(
#region Shared Utilities
private async Task AddBraintreeCustomerIdAsync(
Customer customer,
string braintreeCustomerId)
{
var metadata = customer.Metadata ?? new Dictionary<string, string>();
metadata[BraintreeCustomerIdKey] = braintreeCustomerId;
await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions
{
Metadata = metadata
});
}
private async Task<PaymentSource> GetPaymentSourceAsync(
Guid subscriberId,
Customer customer)
private async Task<PaymentSource> GetPaymentSourceAsync(Customer customer)
{
if (customer.Metadata != null)
{
@@ -858,108 +716,17 @@ public class SubscriberService(
/*
* attachedPaymentMethodDTO being null represents a case where we could be looking for the SetupIntent for an unverified "us_bank_account".
* We store the ID of this SetupIntent in the cache when we originally update the payment method.
* Query Stripe for SetupIntents associated with this customer.
*/
var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(subscriberId);
if (string.IsNullOrEmpty(setupIntentId))
var setupIntents = await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions
{
return null;
}
var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions
{
Expand = ["payment_method"]
Customer = customer.Id,
Expand = ["data.payment_method"]
});
return PaymentSource.From(setupIntent);
}
var unverifiedBankAccount = setupIntents?.FirstOrDefault(si => si.IsUnverifiedBankAccount());
private async Task RemoveBraintreeCustomerIdAsync(
Customer customer)
{
var metadata = customer.Metadata ?? new Dictionary<string, string>();
if (metadata.TryGetValue(BraintreeCustomerIdKey, out var value))
{
metadata[BraintreeCustomerIdOldKey] = value;
metadata[BraintreeCustomerIdKey] = null;
await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions
{
Metadata = metadata
});
}
}
private async Task RemoveStripePaymentMethodsAsync(
Customer customer)
{
if (customer.Sources != null && customer.Sources.Any())
{
foreach (var source in customer.Sources)
{
switch (source)
{
case BankAccount:
await stripeAdapter.DeleteBankAccountAsync(customer.Id, source.Id);
break;
case Card:
await stripeAdapter.DeleteCardAsync(customer.Id, source.Id);
break;
}
}
}
var paymentMethods = await stripeAdapter.ListCustomerPaymentMethodsAsync(customer.Id);
await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.DetachPaymentMethodAsync(pm.Id)));
}
private async Task ReplaceBraintreePaymentMethodAsync(
Braintree.Customer customer,
string defaultPaymentMethodToken)
{
var existingDefaultPaymentMethod = customer.DefaultPaymentMethod;
var createPaymentMethodResult = await braintreeGateway.PaymentMethod.CreateAsync(new PaymentMethodRequest
{
CustomerId = customer.Id,
PaymentMethodNonce = defaultPaymentMethodToken
});
if (!createPaymentMethodResult.IsSuccess())
{
logger.LogError("Failed to replace payment method for Braintree customer ({ID}) - Creation of new payment method failed | Error: {Error}", customer.Id, createPaymentMethodResult.Message);
throw new BillingException();
}
var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync(
customer.Id,
new CustomerRequest { DefaultPaymentMethodToken = createPaymentMethodResult.Target.Token });
if (!updateCustomerResult.IsSuccess())
{
logger.LogError("Failed to replace payment method for Braintree customer ({ID}) - Customer update failed | Error: {Error}",
customer.Id, updateCustomerResult.Message);
await braintreeGateway.PaymentMethod.DeleteAsync(createPaymentMethodResult.Target.Token);
throw new BillingException();
}
if (existingDefaultPaymentMethod != null)
{
var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token);
if (!deletePaymentMethodResult.IsSuccess())
{
logger.LogWarning(
"Failed to delete replaced payment method for Braintree customer ({ID}) - outdated payment method still exists | Error: {Error}",
customer.Id, deletePaymentMethodResult.Message);
}
}
return unverifiedBankAccount != null ? PaymentSource.From(unverifiedBankAccount) : null;
}
#endregion

View File

@@ -10,6 +10,8 @@ namespace Bit.Core.Repositories;
public interface IUserRepository : IRepository<User, Guid>
{
Task<User?> GetByGatewayCustomerIdAsync(string gatewayCustomerId);
Task<User?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId);
Task<User?> GetByEmailAsync(string email);
Task<IEnumerable<User>> GetManyByEmailsAsync(IEnumerable<string> emails);
Task<User?> GetBySsoUserAsync(string externalId, Guid? organizationId);

View File

@@ -41,12 +41,8 @@ public interface IUserService
Task<IdentityResult> DeleteAsync(User user);
Task<IdentityResult> DeleteAsync(User user, string token);
Task SendDeleteConfirmationAsync(string email);
Task<Tuple<bool, string>> SignUpPremiumAsync(User user, string paymentToken,
PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license,
TaxInfo taxInfo);
Task UpdateLicenseAsync(User user, UserLicense license);
Task<string> AdjustStorageAsync(User user, short storageAdjustmentGb);
Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo);
Task CancelPremiumAsync(User user, bool? endOfPeriod = null);
Task ReinstatePremiumAsync(User user);
Task EnablePremiumAsync(Guid userId, DateTime? expirationDate);

View File

@@ -15,13 +15,10 @@ using Bit.Core.Auth.Enums;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Licenses.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -68,7 +65,6 @@ public class UserService : UserManager<User>, IUserService
private readonly IProviderUserRepository _providerUserRepository;
private readonly IStripeSyncService _stripeSyncService;
private readonly IFeatureService _featureService;
private readonly IPremiumUserBillingService _premiumUserBillingService;
private readonly IRevokeNonCompliantOrganizationUserCommand _revokeNonCompliantOrganizationUserCommand;
private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
private readonly IDistributedCache _distributedCache;
@@ -105,7 +101,6 @@ public class UserService : UserManager<User>, IUserService
IProviderUserRepository providerUserRepository,
IStripeSyncService stripeSyncService,
IFeatureService featureService,
IPremiumUserBillingService premiumUserBillingService,
IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand,
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
IDistributedCache distributedCache,
@@ -146,7 +141,6 @@ public class UserService : UserManager<User>, IUserService
_providerUserRepository = providerUserRepository;
_stripeSyncService = stripeSyncService;
_featureService = featureService;
_premiumUserBillingService = premiumUserBillingService;
_revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand;
_twoFactorIsEnabledQuery = twoFactorIsEnabledQuery;
_distributedCache = distributedCache;
@@ -742,78 +736,6 @@ public class UserService : UserManager<User>, IUserService
return true;
}
public async Task<Tuple<bool, string>> SignUpPremiumAsync(User user, string paymentToken,
PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license,
TaxInfo taxInfo)
{
if (user.Premium)
{
throw new BadRequestException("Already a premium user.");
}
if (additionalStorageGb < 0)
{
throw new BadRequestException("You can't subtract storage!");
}
string paymentIntentClientSecret = null;
IStripePaymentService paymentService = null;
if (_globalSettings.SelfHosted)
{
if (license == null || !_licenseService.VerifyLicense(license))
{
throw new BadRequestException("Invalid license.");
}
var claimsPrincipal = _licenseService.GetClaimsPrincipalFromLicense(license);
if (!license.CanUse(user, claimsPrincipal, out var exceptionMessage))
{
throw new BadRequestException(exceptionMessage);
}
var dir = $"{_globalSettings.LicenseDirectory}/user";
Directory.CreateDirectory(dir);
using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json"));
await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented);
}
else
{
var sale = PremiumUserSale.From(user, paymentMethodType, paymentToken, taxInfo, additionalStorageGb);
await _premiumUserBillingService.Finalize(sale);
}
user.Premium = true;
user.RevisionDate = DateTime.UtcNow;
if (_globalSettings.SelfHosted)
{
user.MaxStorageGb = Constants.SelfHostedMaxStorageGb;
user.LicenseKey = license.LicenseKey;
user.PremiumExpirationDate = license.Expires;
}
else
{
user.LicenseKey = CoreHelpers.SecureRandomString(20);
}
try
{
await SaveUserAsync(user);
await _pushService.PushSyncVaultAsync(user.Id);
}
catch when (!_globalSettings.SelfHosted)
{
await paymentService.CancelAndRecoverChargesAsync(user);
throw;
}
return new Tuple<bool, string>(string.IsNullOrWhiteSpace(paymentIntentClientSecret),
paymentIntentClientSecret);
}
public async Task UpdateLicenseAsync(User user, UserLicense license)
{
if (!_globalSettings.SelfHosted)
@@ -883,20 +805,6 @@ public class UserService : UserManager<User>, IUserService
return secret;
}
public async Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo)
{
if (paymentToken.StartsWith("btok_"))
{
throw new BadRequestException("Invalid token.");
}
var tokenizedPaymentSource = new TokenizedPaymentSource(paymentMethodType, paymentToken);
var taxInformation = TaxInformation.From(taxInfo);
await _premiumUserBillingService.UpdatePaymentMethod(user, tokenizedPaymentSource, taxInformation);
await SaveUserAsync(user);
}
public async Task CancelPremiumAsync(User user, bool? endOfPeriod = null)
{
var eop = endOfPeriod.GetValueOrDefault(true);

View File

@@ -445,42 +445,42 @@ The persistent cache is accessed via keyed service injection and is optimized fo
The persistent `IDistributedCache` service is appropriate for workflow state that spans multiple requests and needs automatic TTL cleanup.
```csharp
public class SetupIntentDistributedCache(
[FromKeyedServices("persistent")] IDistributedCache distributedCache) : ISetupIntentCache
public class PaymentWorkflowCache(
[FromKeyedServices("persistent")] IDistributedCache distributedCache) : IPaymentWorkflowCache
{
public async Task Set(Guid subscriberId, string setupIntentId)
public async Task SetPaymentSessionAsync(Guid userId, string sessionId)
{
// Bidirectional mapping for payment flow
var bySubscriberIdCacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}";
var bySetupIntentIdCacheKey = $"subscriber_id_for_setup_intent_id_{setupIntentId}";
var byUserIdCacheKey = $"payment_session_for_user_{userId}";
var bySessionIdCacheKey = $"user_for_payment_session_{sessionId}";
// Note: No explicit TTL set here. Cosmos DB uses container-level TTL for automatic cleanup.
// In cloud, Cosmos TTL handles expiration. In self-hosted, the cache backend manages TTL.
await Task.WhenAll(
distributedCache.SetStringAsync(bySubscriberIdCacheKey, setupIntentId),
distributedCache.SetStringAsync(bySetupIntentIdCacheKey, subscriberId.ToString()));
distributedCache.SetStringAsync(byUserIdCacheKey, sessionId),
distributedCache.SetStringAsync(bySessionIdCacheKey, userId.ToString()));
}
public async Task<string?> GetSetupIntentIdForSubscriber(Guid subscriberId)
public async Task<string?> GetPaymentSessionForUserAsync(Guid userId)
{
var cacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}";
var cacheKey = $"payment_session_for_user_{userId}";
return await distributedCache.GetStringAsync(cacheKey);
}
public async Task<Guid?> GetSubscriberIdForSetupIntent(string setupIntentId)
public async Task<Guid?> GetUserForPaymentSessionAsync(string sessionId)
{
var cacheKey = $"subscriber_id_for_setup_intent_id_{setupIntentId}";
var cacheKey = $"user_for_payment_session_{sessionId}";
var value = await distributedCache.GetStringAsync(cacheKey);
if (string.IsNullOrEmpty(value) || !Guid.TryParse(value, out var subscriberId))
if (string.IsNullOrEmpty(value) || !Guid.TryParse(value, out var userId))
{
return null;
}
return subscriberId;
return userId;
}
public async Task RemoveSetupIntentForSubscriber(Guid subscriberId)
public async Task RemovePaymentSessionForUserAsync(Guid userId)
{
var cacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}";
var cacheKey = $"payment_session_for_user_{userId}";
await distributedCache.RemoveAsync(cacheKey);
}
}

View File

@@ -27,6 +27,32 @@ public class OrganizationRepository : Repository<Organization, Guid>, IOrganizat
_logger = logger;
}
public async Task<Organization?> GetByGatewayCustomerIdAsync(string gatewayCustomerId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Organization>(
"[dbo].[Organization_ReadByGatewayCustomerId]",
new { GatewayCustomerId = gatewayCustomerId },
commandType: CommandType.StoredProcedure);
return results.FirstOrDefault();
}
}
public async Task<Organization?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Organization>(
"[dbo].[Organization_ReadByGatewaySubscriptionId]",
new { GatewaySubscriptionId = gatewaySubscriptionId },
commandType: CommandType.StoredProcedure);
return results.FirstOrDefault();
}
}
public async Task<Organization?> GetByIdentifierAsync(string identifier)
{
using (var connection = new SqlConnection(ConnectionString))

View File

@@ -21,6 +21,32 @@ public class ProviderRepository : Repository<Provider, Guid>, IProviderRepositor
: base(connectionString, readOnlyConnectionString)
{ }
public async Task<Provider?> GetByGatewayCustomerIdAsync(string gatewayCustomerId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Provider>(
"[dbo].[Provider_ReadByGatewayCustomerId]",
new { GatewayCustomerId = gatewayCustomerId },
commandType: CommandType.StoredProcedure);
return results.FirstOrDefault();
}
}
public async Task<Provider?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Provider>(
"[dbo].[Provider_ReadByGatewaySubscriptionId]",
new { GatewaySubscriptionId = gatewaySubscriptionId },
commandType: CommandType.StoredProcedure);
return results.FirstOrDefault();
}
}
public async Task<Provider?> GetByOrganizationIdAsync(Guid organizationId)
{
using (var connection = new SqlConnection(ConnectionString))

View File

@@ -35,6 +35,34 @@ public class UserRepository : Repository<User, Guid>, IUserRepository
return user;
}
public async Task<User?> GetByGatewayCustomerIdAsync(string gatewayCustomerId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<User>(
"[dbo].[User_ReadByGatewayCustomerId]",
new { GatewayCustomerId = gatewayCustomerId },
commandType: CommandType.StoredProcedure);
UnprotectData(results);
return results.FirstOrDefault();
}
}
public async Task<User?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<User>(
"[dbo].[User_ReadByGatewaySubscriptionId]",
new { GatewaySubscriptionId = gatewaySubscriptionId },
commandType: CommandType.StoredProcedure);
UnprotectData(results);
return results.FirstOrDefault();
}
}
public async Task<User?> GetByEmailAsync(string email)
{
using (var connection = new SqlConnection(ConnectionString))

View File

@@ -20,6 +20,9 @@ public class OrganizationEntityTypeConfiguration : IEntityTypeConfiguration<Orga
builder.HasIndex(o => new { o.Id, o.Enabled }),
o => new { o.UseTotp, o.UsersGetPremium });
builder.HasIndex(o => o.GatewayCustomerId);
builder.HasIndex(o => o.GatewaySubscriptionId);
builder.ToTable(nameof(Organization));
}
}

View File

@@ -0,0 +1,20 @@
using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
namespace Bit.Infrastructure.EntityFramework.AdminConsole.Configurations;
public class ProviderEntityTypeConfiguration : IEntityTypeConfiguration<Provider>
{
public void Configure(EntityTypeBuilder<Provider> builder)
{
builder
.Property(p => p.Id)
.ValueGeneratedNever();
builder.HasIndex(p => p.GatewayCustomerId);
builder.HasIndex(p => p.GatewaySubscriptionId);
builder.ToTable(nameof(Provider));
}
}

View File

@@ -31,6 +31,30 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
_logger = logger;
}
public async Task<Core.AdminConsole.Entities.Organization> GetByGatewayCustomerIdAsync(string gatewayCustomerId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var organization = await GetDbSet(dbContext)
.Where(e => e.GatewayCustomerId == gatewayCustomerId)
.FirstOrDefaultAsync();
return organization;
}
}
public async Task<Core.AdminConsole.Entities.Organization> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var organization = await GetDbSet(dbContext)
.Where(e => e.GatewaySubscriptionId == gatewaySubscriptionId)
.FirstOrDefaultAsync();
return organization;
}
}
public async Task<Core.AdminConsole.Entities.Organization> GetByIdentifierAsync(string identifier)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@@ -29,6 +29,30 @@ public class ProviderRepository : Repository<Provider, Models.Provider.Provider,
await base.DeleteAsync(provider);
}
public async Task<Provider> GetByGatewayCustomerIdAsync(string gatewayCustomerId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var provider = await GetDbSet(dbContext)
.Where(e => e.GatewayCustomerId == gatewayCustomerId)
.FirstOrDefaultAsync();
return Mapper.Map<Provider>(provider);
}
}
public async Task<Provider> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var provider = await GetDbSet(dbContext)
.Where(e => e.GatewaySubscriptionId == gatewaySubscriptionId)
.FirstOrDefaultAsync();
return Mapper.Map<Provider>(provider);
}
}
public async Task<Provider> GetByOrganizationIdAsync(Guid organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@@ -21,6 +21,9 @@ public class UserEntityTypeConfiguration : IEntityTypeConfiguration<User>
.HasIndex(u => new { u.Premium, u.PremiumExpirationDate, u.RenewalReminderDate })
.IsClustered(false);
builder.HasIndex(u => u.GatewayCustomerId);
builder.HasIndex(u => u.GatewaySubscriptionId);
builder.ToTable(nameof(User));
}
}

View File

@@ -20,6 +20,28 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
: base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users)
{ }
public async Task<Core.Entities.User?> GetByGatewayCustomerIdAsync(string gatewayCustomerId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var entity = await GetDbSet(dbContext)
.FirstOrDefaultAsync(e => e.GatewayCustomerId == gatewayCustomerId);
return Mapper.Map<Core.Entities.User>(entity);
}
}
public async Task<Core.Entities.User?> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var entity = await GetDbSet(dbContext)
.FirstOrDefaultAsync(e => e.GatewaySubscriptionId == gatewaySubscriptionId);
return Mapper.Map<Core.Entities.User>(entity);
}
}
public async Task<Core.Entities.User?> GetByEmailAsync(string email)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[Organization_ReadByGatewayCustomerId]
@GatewayCustomerId VARCHAR(50)
AS
BEGIN
SET NOCOUNT ON
SELECT
*
FROM
[dbo].[OrganizationView]
WHERE
[GatewayCustomerId] = @GatewayCustomerId
END

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[Organization_ReadByGatewaySubscriptionId]
@GatewaySubscriptionId VARCHAR(50)
AS
BEGIN
SET NOCOUNT ON
SELECT
*
FROM
[dbo].[OrganizationView]
WHERE
[GatewaySubscriptionId] = @GatewaySubscriptionId
END

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[Provider_ReadByGatewayCustomerId]
@GatewayCustomerId VARCHAR(50)
AS
BEGIN
SET NOCOUNT ON
SELECT
*
FROM
[dbo].[ProviderView]
WHERE
[GatewayCustomerId] = @GatewayCustomerId
END

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[Provider_ReadByGatewaySubscriptionId]
@GatewaySubscriptionId VARCHAR(50)
AS
BEGIN
SET NOCOUNT ON
SELECT
*
FROM
[dbo].[ProviderView]
WHERE
[GatewaySubscriptionId] = @GatewaySubscriptionId
END

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[User_ReadByGatewayCustomerId]
@GatewayCustomerId VARCHAR(50)
AS
BEGIN
SET NOCOUNT ON
SELECT
*
FROM
[dbo].[UserView]
WHERE
[GatewayCustomerId] = @GatewayCustomerId
END

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[User_ReadByGatewaySubscriptionId]
@GatewaySubscriptionId VARCHAR(50)
AS
BEGIN
SET NOCOUNT ON
SELECT
*
FROM
[dbo].[UserView]
WHERE
[GatewaySubscriptionId] = @GatewaySubscriptionId
END

View File

@@ -76,3 +76,13 @@ GO
CREATE UNIQUE NONCLUSTERED INDEX [IX_Organization_Identifier]
ON [dbo].[Organization]([Identifier] ASC)
WHERE [Identifier] IS NOT NULL;
GO
CREATE NONCLUSTERED INDEX [IX_Organization_GatewayCustomerId]
ON [dbo].[Organization]([GatewayCustomerId])
WHERE [GatewayCustomerId] IS NOT NULL;
GO
CREATE NONCLUSTERED INDEX [IX_Organization_GatewaySubscriptionId]
ON [dbo].[Organization]([GatewaySubscriptionId])
WHERE [GatewaySubscriptionId] IS NOT NULL;

View File

@@ -21,3 +21,13 @@
[DiscountId] VARCHAR (50) NULL,
CONSTRAINT [PK_Provider] PRIMARY KEY CLUSTERED ([Id] ASC)
);
GO
CREATE NONCLUSTERED INDEX [IX_Provider_GatewayCustomerId]
ON [dbo].[Provider]([GatewayCustomerId])
WHERE [GatewayCustomerId] IS NOT NULL;
GO
CREATE NONCLUSTERED INDEX [IX_Provider_GatewaySubscriptionId]
ON [dbo].[Provider]([GatewaySubscriptionId])
WHERE [GatewaySubscriptionId] IS NOT NULL;

View File

@@ -62,3 +62,12 @@ GO
CREATE NONCLUSTERED INDEX [IX_User_Id_EmailDomain]
ON [dbo].[User]([Id] ASC, [Email] ASC);
GO
CREATE NONCLUSTERED INDEX [IX_User_GatewayCustomerId]
ON [dbo].[User]([GatewayCustomerId])
WHERE [GatewayCustomerId] IS NOT NULL;
GO
CREATE NONCLUSTERED INDEX [IX_User_GatewaySubscriptionId]
ON [dbo].[User]([GatewaySubscriptionId])
WHERE [GatewaySubscriptionId] IS NOT NULL;