1
0
mirror of https://github.com/bitwarden/server synced 2025-12-15 07:43:54 +00:00

[PM-25463] Work towards complete usage of Payments domain (#6363)

* Use payment domain

* Run dotnet format and remove unused code

* Fix swagger

* Stephon's feedback

* Run dotnet format
This commit is contained in:
Alex Morask
2025-10-01 10:26:39 -05:00
committed by GitHub
parent 7cefca330b
commit 61265c7533
46 changed files with 2988 additions and 1350 deletions

View File

@@ -12,7 +12,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Providers.Services;
using Bit.Core.Context; using Bit.Core.Context;
@@ -90,7 +90,7 @@ public class ProviderService : IProviderService
_providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand; _providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand;
} }
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress)
{ {
var owner = await _userService.GetUserByIdAsync(ownerUserId); var owner = await _userService.GetUserByIdAsync(ownerUserId);
if (owner == null) if (owner == null)
@@ -115,21 +115,7 @@ public class ProviderService : IProviderService
throw new BadRequestException("Invalid owner."); throw new BadRequestException("Invalid owner.");
} }
if (taxInfo == null || string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) var customer = await _providerBillingService.SetupCustomer(provider, paymentMethod, billingAddress);
{
throw new BadRequestException("Both address and postal code are required to set up your provider.");
}
if (tokenizedPaymentSource is not
{
Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal,
Token: not null and not ""
})
{
throw new BadRequestException("A payment method is required to set up your provider.");
}
var customer = await _providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
provider.GatewayCustomerId = customer.Id; provider.GatewayCustomerId = customer.Id;
var subscription = await _providerBillingService.SetupSubscription(provider); var subscription = await _providerBillingService.SetupSubscription(provider);
provider.GatewaySubscriptionId = subscription.Id; provider.GatewaySubscriptionId = subscription.Id;

View File

@@ -14,6 +14,7 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Models;
@@ -21,10 +22,8 @@ using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
@@ -38,6 +37,9 @@ using Subscription = Stripe.Subscription;
namespace Bit.Commercial.Core.Billing.Providers.Services; namespace Bit.Commercial.Core.Billing.Providers.Services;
using static Constants;
using static StripeConstants;
public class ProviderBillingService( public class ProviderBillingService(
IBraintreeGateway braintreeGateway, IBraintreeGateway braintreeGateway,
IEventService eventService, IEventService eventService,
@@ -51,8 +53,7 @@ public class ProviderBillingService(
IProviderUserRepository providerUserRepository, IProviderUserRepository providerUserRepository,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService)
ITaxService taxService)
: IProviderBillingService : IProviderBillingService
{ {
public async Task AddExistingOrganization( public async Task AddExistingOrganization(
@@ -61,10 +62,7 @@ public class ProviderBillingService(
string key) string key)
{ {
await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
new SubscriptionUpdateOptions new SubscriptionUpdateOptions { CancelAtPeriodEnd = false });
{
CancelAtPeriodEnd = false
});
var subscription = var subscription =
await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId,
@@ -83,7 +81,7 @@ public class ProviderBillingService(
var wasTrialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now; var wasTrialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now;
if (!wasTrialing && subscription.LatestInvoice.Status == StripeConstants.InvoiceStatus.Draft) if (!wasTrialing && subscription.LatestInvoice.Status == InvoiceStatus.Draft)
{ {
await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId, await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId,
new InvoiceFinalizeOptions { AutoAdvance = true }); new InvoiceFinalizeOptions { AutoAdvance = true });
@@ -184,16 +182,8 @@ public class ProviderBillingService(
{ {
Items = Items =
[ [
new SubscriptionItemOptions new SubscriptionItemOptions { Price = newPriceId, Quantity = oldSubscriptionItem!.Quantity },
{ new SubscriptionItemOptions { Id = oldSubscriptionItem.Id, Deleted = true }
Price = newPriceId,
Quantity = oldSubscriptionItem!.Quantity
},
new SubscriptionItemOptions
{
Id = oldSubscriptionItem.Id,
Deleted = true
}
] ]
}; };
@@ -202,7 +192,8 @@ public class ProviderBillingService(
// Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId)
// 1. Retrieve PlanType and PlanName for ProviderPlan // 1. Retrieve PlanType and PlanName for ProviderPlan
// 2. Assign PlanType & PlanName to Organization // 2. Assign PlanType & PlanName to Organization
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId); var providerOrganizations =
await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId);
var newPlan = await pricingClient.GetPlanOrThrow(newPlanType); var newPlan = await pricingClient.GetPlanOrThrow(newPlanType);
@@ -213,6 +204,7 @@ public class ProviderBillingService(
{ {
throw new ConflictException($"Organization '{providerOrganization.Id}' not found."); throw new ConflictException($"Organization '{providerOrganization.Id}' not found.");
} }
organization.PlanType = newPlanType; organization.PlanType = newPlanType;
organization.Plan = newPlan.Name; organization.Plan = newPlan.Name;
await organizationRepository.ReplaceAsync(organization); await organizationRepository.ReplaceAsync(organization);
@@ -228,15 +220,15 @@ public class ProviderBillingService(
if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) if (!string.IsNullOrEmpty(organization.GatewayCustomerId))
{ {
logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId)); logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id,
nameof(organization.GatewayCustomerId));
return; return;
} }
var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions var providerCustomer =
{ await subscriberService.GetCustomerOrThrow(provider,
Expand = ["tax", "tax_ids"] new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
});
var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();
@@ -269,23 +261,18 @@ public class ProviderBillingService(
} }
] ]
}, },
Metadata = new Dictionary<string, string> Metadata = new Dictionary<string, string> { { "region", globalSettings.BaseServiceUri.CloudRegion } },
{ TaxIdData = providerTaxId == null
{ "region", globalSettings.BaseServiceUri.CloudRegion } ? null
}, :
TaxIdData = providerTaxId == null ? null :
[ [
new CustomerTaxIdDataOptions new CustomerTaxIdDataOptions { Type = providerTaxId.Type, Value = providerTaxId.Value }
{
Type = providerTaxId.Type,
Value = providerTaxId.Value
}
] ]
}; };
if (providerCustomer.Address is not { Country: Constants.CountryAbbreviations.UnitedStates }) if (providerCustomer.Address is not { Country: CountryAbbreviations.UnitedStates })
{ {
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; customerCreateOptions.TaxExempt = TaxExempt.Reverse;
} }
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
@@ -347,9 +334,9 @@ public class ProviderBillingService(
.Where(pair => pair.subscription is .Where(pair => pair.subscription is
{ {
Status: Status:
StripeConstants.SubscriptionStatus.Active or SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.Trialing or SubscriptionStatus.Trialing or
StripeConstants.SubscriptionStatus.PastDue SubscriptionStatus.PastDue
}).ToList(); }).ToList();
if (active.Count == 0) if (active.Count == 0)
@@ -474,37 +461,27 @@ public class ProviderBillingService(
// Below the limit to above the limit // Below the limit to above the limit
(currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) || (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) ||
// Above the limit to further above the limit // Above the limit to further above the limit
(currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > currentlyAssignedSeatTotal); (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal > currentlyAssignedSeatTotal);
} }
public async Task<Customer> SetupCustomer( public async Task<Customer> SetupCustomer(
Provider provider, Provider provider,
TaxInfo taxInfo, TokenizedPaymentMethod paymentMethod,
TokenizedPaymentSource tokenizedPaymentSource) BillingAddress billingAddress)
{ {
ArgumentNullException.ThrowIfNull(tokenizedPaymentSource);
if (taxInfo is not
{
BillingAddressCountry: not null and not "",
BillingAddressPostalCode: not null and not ""
})
{
logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id);
throw new BillingException();
}
var options = new CustomerCreateOptions var options = new CustomerCreateOptions
{ {
Address = new AddressOptions Address = new AddressOptions
{ {
Country = taxInfo.BillingAddressCountry, Country = billingAddress.Country,
PostalCode = taxInfo.BillingAddressPostalCode, PostalCode = billingAddress.PostalCode,
Line1 = taxInfo.BillingAddressLine1, Line1 = billingAddress.Line1,
Line2 = taxInfo.BillingAddressLine2, Line2 = billingAddress.Line2,
City = taxInfo.BillingAddressCity, City = billingAddress.City,
State = taxInfo.BillingAddressState State = billingAddress.State
}, },
Coupon = !string.IsNullOrEmpty(provider.DiscountId) ? provider.DiscountId : null,
Description = provider.DisplayBusinessName(), Description = provider.DisplayBusinessName(),
Email = provider.BillingEmail, Email = provider.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions InvoiceSettings = new CustomerInvoiceSettingsOptions
@@ -520,93 +497,61 @@ public class ProviderBillingService(
} }
] ]
}, },
Metadata = new Dictionary<string, string> Metadata = new Dictionary<string, string> { { "region", globalSettings.BaseServiceUri.CloudRegion } },
{ TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None
{ "region", globalSettings.BaseServiceUri.CloudRegion }
}
}; };
if (taxInfo.BillingAddressCountry is not Constants.CountryAbbreviations.UnitedStates) if (billingAddress.TaxId != null)
{ {
options.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber))
{
var taxIdType = taxService.GetStripeTaxCode(
taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber);
if (taxIdType == null)
{
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber);
throw new BadRequestException("billingTaxIdTypeInferenceError");
}
options.TaxIdData = options.TaxIdData =
[ [
new CustomerTaxIdDataOptions { Type = taxIdType, Value = taxInfo.TaxIdNumber } new CustomerTaxIdDataOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value }
]; ];
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) if (billingAddress.TaxId.Code == TaxIdType.SpanishNIF)
{ {
options.TaxIdData.Add(new CustomerTaxIdDataOptions options.TaxIdData.Add(new CustomerTaxIdDataOptions
{ {
Type = StripeConstants.TaxIdType.EUVAT, Type = TaxIdType.EUVAT,
Value = $"ES{taxInfo.TaxIdNumber}" Value = $"ES{billingAddress.TaxId.Value}"
}); });
} }
} }
if (!string.IsNullOrEmpty(provider.DiscountId))
{
options.Coupon = provider.DiscountId;
}
var braintreeCustomerId = ""; var braintreeCustomerId = "";
if (tokenizedPaymentSource is not
{
Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal,
Token: not null and not ""
})
{
logger.LogError("Cannot create customer for provider ({ProviderID}) with invalid payment method", provider.Id);
throw new BillingException();
}
var (type, token) = tokenizedPaymentSource;
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (type) switch (paymentMethod.Type)
{ {
case PaymentMethodType.BankAccount: case TokenizablePaymentMethodType.BankAccount:
{ {
var setupIntent = var setupIntent =
(await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = token })) (await stripeAdapter.SetupIntentList(new SetupIntentListOptions
{
PaymentMethod = paymentMethod.Token
}))
.FirstOrDefault(); .FirstOrDefault();
if (setupIntent == null) if (setupIntent == null)
{ {
logger.LogError("Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", provider.Id); logger.LogError(
"Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account",
provider.Id);
throw new BillingException(); throw new BillingException();
} }
await setupIntentCache.Set(provider.Id, setupIntent.Id); await setupIntentCache.Set(provider.Id, setupIntent.Id);
break; break;
} }
case PaymentMethodType.Card: case TokenizablePaymentMethodType.Card:
{ {
options.PaymentMethod = token; options.PaymentMethod = paymentMethod.Token;
options.InvoiceSettings.DefaultPaymentMethod = token; options.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token;
break; break;
} }
case PaymentMethodType.PayPal: case TokenizablePaymentMethodType.PayPal:
{ {
braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, token); braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, paymentMethod.Token);
options.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; options.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId;
break; break;
} }
@@ -616,8 +561,7 @@ public class ProviderBillingService(
{ {
return await stripeAdapter.CustomerCreateAsync(options); return await stripeAdapter.CustomerCreateAsync(options);
} }
catch (StripeException stripeException) when (stripeException.StripeError?.Code == catch (StripeException stripeException) when (stripeException.StripeError?.Code == ErrorCodes.TaxIdInvalid)
StripeConstants.ErrorCodes.TaxIdInvalid)
{ {
await Revert(); await Revert();
throw new BadRequestException( throw new BadRequestException(
@@ -632,9 +576,9 @@ public class ProviderBillingService(
async Task Revert() async Task Revert()
{ {
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (tokenizedPaymentSource.Type) switch (paymentMethod.Type)
{ {
case PaymentMethodType.BankAccount: case TokenizablePaymentMethodType.BankAccount:
{ {
var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id);
await stripeAdapter.SetupIntentCancel(setupIntentId, await stripeAdapter.SetupIntentCancel(setupIntentId,
@@ -642,7 +586,7 @@ public class ProviderBillingService(
await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id); await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id);
break; break;
} }
case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId):
{ {
await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId);
break; break;
@@ -661,9 +605,10 @@ public class ProviderBillingService(
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
if (providerPlans == null || providerPlans.Count == 0) if (providerPlans.Count == 0)
{ {
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id); logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans",
provider.Id);
throw new BillingException(); throw new BillingException();
} }
@@ -676,7 +621,9 @@ public class ProviderBillingService(
if (!providerPlan.IsConfigured()) if (!providerPlan.IsConfigured())
{ {
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan", provider.Id, plan.Name); logger.LogError(
"Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan",
provider.Id, plan.Name);
throw new BillingException(); throw new BillingException();
} }
@@ -692,16 +639,14 @@ public class ProviderBillingService(
var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id);
var setupIntent = !string.IsNullOrEmpty(setupIntentId) var setupIntent = !string.IsNullOrEmpty(setupIntentId)
? await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions ? await stripeAdapter.SetupIntentGet(setupIntentId,
{ new SetupIntentGetOptions { Expand = ["payment_method"] })
Expand = ["payment_method"]
})
: null; : null;
var usePaymentMethod = var usePaymentMethod =
!string.IsNullOrEmpty(customer.InvoiceSettings?.DefaultPaymentMethodId) || !string.IsNullOrEmpty(customer.InvoiceSettings?.DefaultPaymentMethodId) ||
(customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true) || customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true ||
(setupIntent?.IsUnverifiedBankAccount() == true); setupIntent?.IsUnverifiedBankAccount() == true;
int? trialPeriodDays = provider.Type switch int? trialPeriodDays = provider.Type switch
{ {
@@ -712,30 +657,28 @@ public class ProviderBillingService(
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
CollectionMethod = usePaymentMethod ? CollectionMethod =
StripeConstants.CollectionMethod.ChargeAutomatically : StripeConstants.CollectionMethod.SendInvoice, usePaymentMethod
? CollectionMethod.ChargeAutomatically
: CollectionMethod.SendInvoice,
Customer = customer.Id, Customer = customer.Id,
DaysUntilDue = usePaymentMethod ? null : 30, DaysUntilDue = usePaymentMethod ? null : 30,
Items = subscriptionItemOptionsList, Items = subscriptionItemOptionsList,
Metadata = new Dictionary<string, string> Metadata = new Dictionary<string, string> { { "providerId", provider.Id.ToString() } },
{
{ "providerId", provider.Id.ToString() }
},
OffSession = true, OffSession = true,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, ProrationBehavior = ProrationBehavior.CreateProrations,
TrialPeriodDays = trialPeriodDays TrialPeriodDays = trialPeriodDays,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
}; };
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
try try
{ {
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (subscription is if (subscription is
{ {
Status: StripeConstants.SubscriptionStatus.Active or StripeConstants.SubscriptionStatus.Trialing Status: SubscriptionStatus.Active or SubscriptionStatus.Trialing
}) })
{ {
return subscription; return subscription;
@@ -749,9 +692,11 @@ public class ProviderBillingService(
throw new BillingException(); throw new BillingException();
} }
catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) catch (StripeException stripeException) when (stripeException.StripeError?.Code ==
ErrorCodes.CustomerTaxLocationInvalid)
{ {
throw new BadRequestException("Your location wasn't recognized. Please ensure your country and postal code are valid."); throw new BadRequestException(
"Your location wasn't recognized. Please ensure your country and postal code are valid.");
} }
} }
@@ -765,7 +710,7 @@ public class ProviderBillingService(
subscriberService.UpdateTaxInformation(provider, taxInformation)); subscriberService.UpdateTaxInformation(provider, taxInformation));
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically }); new SubscriptionUpdateOptions { CollectionMethod = CollectionMethod.ChargeAutomatically });
} }
public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command)
@@ -865,13 +810,9 @@ public class ProviderBillingService(
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions
{ {
Items = [ Items =
new SubscriptionItemOptions [
{ new SubscriptionItemOptions { Id = item.Id, Price = priceId, Quantity = newlySubscribedSeats }
Id = item.Id,
Price = priceId,
Quantity = newlySubscribedSeats
}
] ]
}); });
@@ -894,7 +835,8 @@ public class ProviderBillingService(
var plan = await pricingClient.GetPlanOrThrow(planType); var plan = await pricingClient.GetPlanOrThrow(planType);
return providerOrganizations return providerOrganizations
.Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed) .Where(providerOrganization => providerOrganization.Plan == plan.Name &&
providerOrganization.Status == OrganizationStatusType.Managed)
.Sum(providerOrganization => providerOrganization.Seats ?? 0); .Sum(providerOrganization => providerOrganization.Seats ?? 0);
} }

View File

@@ -9,7 +9,7 @@ using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Providers.Services;
using Bit.Core.Context; using Bit.Core.Context;
@@ -41,7 +41,7 @@ public class ProviderServiceTests
public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider<ProviderService> sutProvider) public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider<ProviderService> sutProvider)
{ {
var exception = await Assert.ThrowsAsync<BadRequestException>( var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null)); () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null, null));
Assert.Contains("Invalid owner.", exception.Message); Assert.Contains("Invalid owner.", exception.Message);
} }
@@ -53,83 +53,12 @@ public class ProviderServiceTests
userService.GetUserByIdAsync(user.Id).Returns(user); userService.GetUserByIdAsync(user.Id).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>( var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null)); () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null, null));
Assert.Contains("Invalid token.", exception.Message); Assert.Contains("Invalid token.", exception.Message);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task CompleteSetupAsync_InvalidTaxInfo_ThrowsBadRequestException( public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TokenizedPaymentMethod tokenizedPaymentMethod, BillingAddress billingAddress,
User user,
Provider provider,
string key,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource,
[ProviderUser] ProviderUser providerUser,
SutProvider<ProviderService> sutProvider)
{
providerUser.ProviderId = provider.Id;
providerUser.UserId = user.Id;
var userService = sutProvider.GetDependency<IUserService>();
userService.GetUserByIdAsync(user.Id).Returns(user);
var providerUserRepository = sutProvider.GetDependency<IProviderUserRepository>();
providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser);
var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName");
var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
sutProvider.GetDependency<IDataProtectionProvider>().CreateProtector("ProviderServiceDataProtector")
.Returns(protector);
sutProvider.Create();
var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
taxInfo.BillingAddressCountry = null;
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource));
Assert.Equal("Both address and postal code are required to set up your provider.", exception.Message);
}
[Theory, BitAutoData]
public async Task CompleteSetupAsync_InvalidTokenizedPaymentSource_ThrowsBadRequestException(
User user,
Provider provider,
string key,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource,
[ProviderUser] ProviderUser providerUser,
SutProvider<ProviderService> sutProvider)
{
providerUser.ProviderId = provider.Id;
providerUser.UserId = user.Id;
var userService = sutProvider.GetDependency<IUserService>();
userService.GetUserByIdAsync(user.Id).Returns(user);
var providerUserRepository = sutProvider.GetDependency<IProviderUserRepository>();
providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser);
var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName");
var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
sutProvider.GetDependency<IDataProtectionProvider>().CreateProtector("ProviderServiceDataProtector")
.Returns(protector);
sutProvider.Create();
var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay };
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource));
Assert.Equal("A payment method is required to set up your provider.", exception.Message);
}
[Theory, BitAutoData]
public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource,
[ProviderUser] ProviderUser providerUser, [ProviderUser] ProviderUser providerUser,
SutProvider<ProviderService> sutProvider) SutProvider<ProviderService> sutProvider)
{ {
@@ -149,7 +78,7 @@ public class ProviderServiceTests
var providerBillingService = sutProvider.GetDependency<IProviderBillingService>(); var providerBillingService = sutProvider.GetDependency<IProviderBillingService>();
var customer = new Customer { Id = "customer_id" }; var customer = new Customer { Id = "customer_id" };
providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource).Returns(customer); providerBillingService.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress).Returns(customer);
var subscription = new Subscription { Id = "subscription_id" }; var subscription = new Subscription { Id = "subscription_id" };
providerBillingService.SetupSubscription(provider).Returns(subscription); providerBillingService.SetupSubscription(provider).Returns(subscription);
@@ -158,7 +87,7 @@ public class ProviderServiceTests
var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource); await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, tokenizedPaymentMethod, billingAddress);
await sutProvider.GetDependency<IProviderRepository>().Received().UpsertAsync(Arg.Is<Provider>( await sutProvider.GetDependency<IProviderRepository>().Received().UpsertAsync(Arg.Is<Provider>(
p => p =>

View File

@@ -1,5 +1,4 @@
using System.Globalization; using System.Globalization;
using System.Net;
using Bit.Commercial.Core.Billing.Providers.Models; using Bit.Commercial.Core.Billing.Providers.Models;
using Bit.Commercial.Core.Billing.Providers.Services; using Bit.Commercial.Core.Billing.Providers.Services;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
@@ -10,18 +9,16 @@ using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Models;
using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
@@ -895,118 +892,53 @@ public class ProviderBillingServiceTests
#region SetupCustomer #region SetupCustomer
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task SetupCustomer_MissingCountry_ContactSupport( public async Task SetupCustomer_NullPaymentMethod_ThrowsNullReferenceException(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo, BillingAddress billingAddress)
TokenizedPaymentSource tokenizedPaymentSource)
{ {
taxInfo.BillingAddressCountry = null; await Assert.ThrowsAsync<NullReferenceException>(() =>
sutProvider.Sut.SetupCustomer(provider, null, billingAddress));
await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
}
[Theory, BitAutoData]
public async Task SetupCustomer_MissingPostalCode_ContactSupport(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource)
{
taxInfo.BillingAddressCountry = null;
await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
}
[Theory, BitAutoData]
public async Task SetupCustomer_NullPaymentSource_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
await Assert.ThrowsAsync<ArgumentNullException>(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, null));
}
[Theory, BitAutoData]
public async Task SetupCustomer_InvalidRequiredPaymentMethod_ThrowsBillingException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay };
await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task SetupCustomer_WithBankAccount_Error_Reverts( public async Task SetupCustomer_WithBankAccount_Error_Reverts(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo) BillingAddress billingAddress)
{ {
provider.Name = "MSP"; provider.Name = "MSP";
billingAddress.Country = "AD";
sutProvider.GetDependency<ITaxService>() billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" };
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token");
stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options => stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options =>
options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([
new SetupIntent { Id = "setup_intent_id" } new SetupIntent { Id = "setup_intent_id" }
]); ]);
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o => stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry && o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode && o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 && o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 && o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == taxInfo.BillingAddressCity && o.Address.City == billingAddress.City &&
o.Address.State == taxInfo.BillingAddressState && o.Address.State == billingAddress.State &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) && o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail && o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" && o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Throws<StripeException>(); .Throws<StripeException>();
sutProvider.GetDependency<ISetupIntentCache>().GetSetupIntentIdForSubscriber(provider.Id).Returns("setup_intent_id"); sutProvider.GetDependency<ISetupIntentCache>().GetSetupIntentIdForSubscriber(provider.Id).Returns("setup_intent_id");
await Assert.ThrowsAsync<StripeException>(() => await Assert.ThrowsAsync<StripeException>(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress));
await sutProvider.GetDependency<ISetupIntentCache>().Received(1).Set(provider.Id, "setup_intent_id"); await sutProvider.GetDependency<ISetupIntentCache>().Received(1).Set(provider.Id, "setup_intent_id");
@@ -1020,45 +952,37 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithPayPal_Error_Reverts( public async Task SetupCustomer_WithPayPal_Error_Reverts(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo) BillingAddress billingAddress)
{ {
provider.Name = "MSP"; provider.Name = "MSP";
billingAddress.Country = "AD";
sutProvider.GetDependency<ITaxService>() billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" };
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token)
sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token)
.Returns("braintree_customer_id"); .Returns("braintree_customer_id");
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o => stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry && o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode && o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 && o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 && o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == taxInfo.BillingAddressCity && o.Address.City == billingAddress.City &&
o.Address.State == taxInfo.BillingAddressState && o.Address.State == billingAddress.State &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) && o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail && o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" && o.Metadata["region"] == "" &&
o.Metadata["btCustomerId"] == "braintree_customer_id" && o.Metadata["btCustomerId"] == "braintree_customer_id" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Throws<StripeException>(); .Throws<StripeException>();
await Assert.ThrowsAsync<StripeException>(() => await Assert.ThrowsAsync<StripeException>(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress));
await sutProvider.GetDependency<IBraintreeGateway>().Customer.Received(1).DeleteAsync("braintree_customer_id"); await sutProvider.GetDependency<IBraintreeGateway>().Customer.Received(1).DeleteAsync("braintree_customer_id");
} }
@@ -1067,17 +991,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithBankAccount_Success( public async Task SetupCustomer_WithBankAccount_Success(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo) BillingAddress billingAddress)
{ {
provider.Name = "MSP"; provider.Name = "MSP";
billingAddress.Country = "AD";
sutProvider.GetDependency<ITaxService>() billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1087,31 +1005,30 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" };
stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options => stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options =>
options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([
new SetupIntent { Id = "setup_intent_id" } new SetupIntent { Id = "setup_intent_id" }
]); ]);
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o => stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry && o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode && o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 && o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 && o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == taxInfo.BillingAddressCity && o.Address.City == billingAddress.City &&
o.Address.State == taxInfo.BillingAddressState && o.Address.State == billingAddress.State &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) && o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail && o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" && o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Returns(expected); .Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual); Assert.Equivalent(expected, actual);
@@ -1122,17 +1039,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithPayPal_Success( public async Task SetupCustomer_WithPayPal_Success(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo) BillingAddress billingAddress)
{ {
provider.Name = "MSP"; provider.Name = "MSP";
billingAddress.Country = "AD";
sutProvider.GetDependency<ITaxService>() billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1142,30 +1053,29 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" };
sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token)
sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token)
.Returns("braintree_customer_id"); .Returns("braintree_customer_id");
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o => stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry && o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode && o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 && o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 && o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == taxInfo.BillingAddressCity && o.Address.City == billingAddress.City &&
o.Address.State == taxInfo.BillingAddressState && o.Address.State == billingAddress.State &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) && o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail && o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" && o.Metadata["region"] == "" &&
o.Metadata["btCustomerId"] == "braintree_customer_id" && o.Metadata["btCustomerId"] == "braintree_customer_id" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Returns(expected); .Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual); Assert.Equivalent(expected, actual);
} }
@@ -1174,17 +1084,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithCard_Success( public async Task SetupCustomer_WithCard_Success(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo) BillingAddress billingAddress)
{ {
provider.Name = "MSP"; provider.Name = "MSP";
billingAddress.Country = "AD";
sutProvider.GetDependency<ITaxService>() billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1194,28 +1098,26 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" };
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o => stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry && o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode && o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 && o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 && o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == taxInfo.BillingAddressCity && o.Address.City == billingAddress.City &&
o.Address.State == taxInfo.BillingAddressState && o.Address.State == billingAddress.State &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) && o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail && o.Email == provider.BillingEmail &&
o.PaymentMethod == tokenizedPaymentSource.Token && o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" && o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Returns(expected); .Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual); Assert.Equivalent(expected, actual);
} }
@@ -1224,17 +1126,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithCard_ReverseCharge_Success( public async Task SetupCustomer_WithCard_ReverseCharge_Success(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo) BillingAddress billingAddress)
{ {
provider.Name = "MSP"; provider.Name = "MSP";
billingAddress.Country = "FR"; // Non-US country to trigger reverse charge
sutProvider.GetDependency<ITaxService>() billingAddress.TaxId = new TaxID("fr_siren", "123456789");
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1244,55 +1140,51 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" };
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o => stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry && o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode && o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 && o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 && o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == taxInfo.BillingAddressCity && o.Address.City == billingAddress.City &&
o.Address.State == taxInfo.BillingAddressState && o.Address.State == billingAddress.State &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) && o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail && o.Email == provider.BillingEmail &&
o.PaymentMethod == tokenizedPaymentSource.Token && o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" && o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber && o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value &&
o.TaxExempt == StripeConstants.TaxExempt.Reverse)) o.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(expected); .Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual); Assert.Equivalent(expected, actual);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( public async Task SetupCustomer_WithInvalidTaxId_ThrowsBadRequestException(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider, Provider provider,
TaxInfo taxInfo, BillingAddress billingAddress)
TokenizedPaymentSource tokenizedPaymentSource)
{ {
provider.Name = "MSP"; provider.Name = "MSP";
billingAddress.Country = "AD";
billingAddress.TaxId = new TaxID("es_nif", "invalid_tax_id");
taxInfo.BillingAddressCountry = "AD"; var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" };
sutProvider.GetDependency<ITaxService>() stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>())
.GetStripeTaxCode(Arg.Is<string>( .Throws(new StripeException("Invalid tax ID") { StripeError = new StripeError { Code = "tax_id_invalid" } });
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns((string)null);
var actual = await Assert.ThrowsAsync<BadRequestException>(async () => var actual = await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress));
Assert.IsType<BadRequestException>(actual); Assert.Equal("Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid.", actual.Message);
Assert.Equal("billingTaxIdTypeInferenceError", actual.Message);
} }
#endregion #endregion

View File

@@ -7,7 +7,6 @@ using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
@@ -93,22 +92,12 @@ public class ProvidersController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var taxInfo = new TaxInfo var paymentMethod = model.PaymentMethod.ToDomain();
{ var billingAddress = model.BillingAddress.ToDomain();
BillingAddressCountry = model.TaxInfo.Country,
BillingAddressPostalCode = model.TaxInfo.PostalCode,
TaxIdNumber = model.TaxInfo.TaxId,
BillingAddressLine1 = model.TaxInfo.Line1,
BillingAddressLine2 = model.TaxInfo.Line2,
BillingAddressCity = model.TaxInfo.City,
BillingAddressState = model.TaxInfo.State
};
var tokenizedPaymentSource = model.PaymentSource?.ToDomain();
var response = var response =
await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key, await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key,
taxInfo, tokenizedPaymentSource); paymentMethod, billingAddress);
return new ProviderResponseModel(response); return new ProviderResponseModel(response);
} }

View File

@@ -3,8 +3,7 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Models.Request;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Utilities; using Bit.Core.Utilities;
@@ -28,8 +27,9 @@ public class ProviderSetupRequestModel
[Required] [Required]
public string Key { get; set; } public string Key { get; set; }
[Required] [Required]
public ExpandedTaxInfoUpdateRequestModel TaxInfo { get; set; } public MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; }
public TokenizedPaymentSourceRequestBody PaymentSource { get; set; } [Required]
public BillingAddressRequest BillingAddress { get; set; }
public virtual Provider ToProvider(Provider provider) public virtual Provider ToProvider(Provider provider)
{ {

View File

@@ -1,16 +1,8 @@
#nullable enable using Bit.Api.Billing.Models.Requests;
using System.Diagnostics;
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses; using Bit.Api.Billing.Models.Responses;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@@ -28,10 +20,8 @@ public class OrganizationBillingController(
IOrganizationBillingService organizationBillingService, IOrganizationBillingService organizationBillingService,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IPaymentService paymentService, IPaymentService paymentService,
IPricingClient pricingClient,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IPaymentHistoryService paymentHistoryService, IPaymentHistoryService paymentHistoryService) : BaseBillingController
IUserService userService) : BaseBillingController
{ {
[HttpGet("metadata")] [HttpGet("metadata")]
public async Task<IResult> GetMetadataAsync([FromRoute] Guid organizationId) public async Task<IResult> GetMetadataAsync([FromRoute] Guid organizationId)
@@ -264,71 +254,6 @@ public class OrganizationBillingController(
return TypedResults.Ok(); return TypedResults.Ok();
} }
[HttpPost("restart-subscription")]
public async Task<IResult> RestartSubscriptionAsync([FromRoute] Guid organizationId,
[FromBody] OrganizationCreateRequestModel model)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
if (!await currentContext.EditPaymentMethods(organizationId))
{
return Error.Unauthorized();
}
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
return Error.NotFound();
}
var existingPlan = organization.PlanType;
var organizationSignup = model.ToOrganizationSignup(user);
var sale = OrganizationSale.From(organization, organizationSignup);
var plan = await pricingClient.GetPlanOrThrow(model.PlanType);
sale.Organization.PlanType = plan.Type;
sale.Organization.Plan = plan.Name;
sale.SubscriptionSetup.SkipTrial = true;
if (existingPlan == PlanType.Free && organization.GatewaySubscriptionId is not null)
{
sale.Organization.UseTotp = plan.HasTotp;
sale.Organization.UseGroups = plan.HasGroups;
sale.Organization.UseDirectory = plan.HasDirectory;
sale.Organization.SelfHost = plan.HasSelfHost;
sale.Organization.UsersGetPremium = plan.UsersGetPremium;
sale.Organization.UseEvents = plan.HasEvents;
sale.Organization.Use2fa = plan.Has2fa;
sale.Organization.UseApi = plan.HasApi;
sale.Organization.UsePolicies = plan.HasPolicies;
sale.Organization.UseSso = plan.HasSso;
sale.Organization.UseResetPassword = plan.HasResetPassword;
sale.Organization.UseKeyConnector = plan.HasKeyConnector ? organization.UseKeyConnector : false;
sale.Organization.UseScim = plan.HasScim;
sale.Organization.UseCustomPermissions = plan.HasCustomPermissions;
sale.Organization.UseOrganizationDomains = plan.HasOrganizationDomains;
sale.Organization.MaxCollections = plan.PasswordManager.MaxCollections;
}
if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken))
{
return Error.BadRequest("A payment method is required to restart the subscription.");
}
var org = await organizationRepository.GetByIdAsync(organizationId);
Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine.");
var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken);
var taxInformation = TaxInformation.From(organizationSignup.TaxInfo);
await organizationBillingService.Finalize(sale);
var updatedOrg = await organizationRepository.GetByIdAsync(organizationId);
if (updatedOrg != null)
{
await organizationBillingService.UpdatePaymentMethod(updatedOrg, paymentSource, taxInformation);
}
return TypedResults.Ok();
}
[HttpPost("setup-business-unit")] [HttpPost("setup-business-unit")]
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> SetupBusinessUnitAsync( public async Task<IResult> SetupBusinessUnitAsync(

View File

@@ -1,33 +1,73 @@
using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Attributes;
using Bit.Core.Billing.Tax.Commands; using Bit.Api.Billing.Models.Requests.Tax;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Organizations.Commands;
using Bit.Core.Billing.Premium.Commands;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ModelBinding;
namespace Bit.Api.Billing.Controllers; namespace Bit.Api.Billing.Controllers;
[Authorize("Application")] [Authorize("Application")]
[Route("tax")] [Route("billing/tax")]
public class TaxController( public class TaxController(
IPreviewTaxAmountCommand previewTaxAmountCommand) : BaseBillingController IPreviewOrganizationTaxCommand previewOrganizationTaxCommand,
IPreviewPremiumTaxCommand previewPremiumTaxCommand) : BaseBillingController
{ {
[HttpPost("preview-amount/organization-trial")] [HttpPost("organizations/subscriptions/purchase")]
public async Task<IResult> PreviewTaxAmountForOrganizationTrialAsync( public async Task<IResult> PreviewOrganizationSubscriptionPurchaseTaxAsync(
[FromBody] PreviewTaxAmountForOrganizationTrialRequestBody requestBody) [FromBody] PreviewOrganizationSubscriptionPurchaseTaxRequest request)
{ {
var parameters = new OrganizationTrialParameters var (purchase, billingAddress) = request.ToDomain();
var result = await previewOrganizationTaxCommand.Run(purchase, billingAddress);
return Handle(result.Map(pair => new
{ {
PlanType = requestBody.PlanType, pair.Tax,
ProductType = requestBody.ProductType, pair.Total
TaxInformation = new OrganizationTrialParameters.TaxInformationDTO }));
{
Country = requestBody.TaxInformation.Country,
PostalCode = requestBody.TaxInformation.PostalCode,
TaxId = requestBody.TaxInformation.TaxId
} }
};
var result = await previewTaxAmountCommand.Run(parameters); [HttpPost("organizations/{organizationId:guid}/subscription/plan-change")]
[InjectOrganization]
public async Task<IResult> PreviewOrganizationSubscriptionPlanChangeTaxAsync(
[BindNever] Organization organization,
[FromBody] PreviewOrganizationSubscriptionPlanChangeTaxRequest request)
{
var (planChange, billingAddress) = request.ToDomain();
var result = await previewOrganizationTaxCommand.Run(organization, planChange, billingAddress);
return Handle(result.Map(pair => new
{
pair.Tax,
pair.Total
}));
}
return Handle(result); [HttpPut("organizations/{organizationId:guid}/subscription/update")]
[InjectOrganization]
public async Task<IResult> PreviewOrganizationSubscriptionUpdateTaxAsync(
[BindNever] Organization organization,
[FromBody] PreviewOrganizationSubscriptionUpdateTaxRequest request)
{
var update = request.ToDomain();
var result = await previewOrganizationTaxCommand.Run(organization, update);
return Handle(result.Map(pair => new
{
pair.Tax,
pair.Total
}));
}
[HttpPost("premium/subscriptions/purchase")]
public async Task<IResult> PreviewPremiumSubscriptionPurchaseTaxAsync(
[FromBody] PreviewPremiumSubscriptionPurchaseTaxRequest request)
{
var (purchase, billingAddress) = request.ToDomain();
var result = await previewPremiumTaxCommand.Run(purchase, billingAddress);
return Handle(result.Map(pair => new
{
pair.Tax,
pair.Total
}));
} }
} }

View File

@@ -1,5 +1,4 @@
#nullable enable using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Models.Requests.Payment; using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Billing.Models.Requests.Premium; using Bit.Api.Billing.Models.Requests.Premium;
using Bit.Core; using Bit.Core;

View File

@@ -2,11 +2,14 @@
using Bit.Api.AdminConsole.Authorization.Requirements; using Bit.Api.AdminConsole.Authorization.Requirements;
using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Models.Requests.Payment; using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Billing.Models.Requests.Subscriptions;
using Bit.Api.Billing.Models.Requirements; using Bit.Api.Billing.Models.Requirements;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Organizations.Queries;
using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Subscriptions.Commands;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
@@ -24,6 +27,7 @@ public class OrganizationBillingVNextController(
IGetCreditQuery getCreditQuery, IGetCreditQuery getCreditQuery,
IGetOrganizationWarningsQuery getOrganizationWarningsQuery, IGetOrganizationWarningsQuery getOrganizationWarningsQuery,
IGetPaymentMethodQuery getPaymentMethodQuery, IGetPaymentMethodQuery getPaymentMethodQuery,
IRestartSubscriptionCommand restartSubscriptionCommand,
IUpdateBillingAddressCommand updateBillingAddressCommand, IUpdateBillingAddressCommand updateBillingAddressCommand,
IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController
{ {
@@ -95,6 +99,20 @@ public class OrganizationBillingVNextController(
return Handle(result); return Handle(result);
} }
[Authorize<ManageOrganizationBillingRequirement>]
[HttpPost("subscription/restart")]
[InjectOrganization]
public async Task<IResult> RestartSubscriptionAsync(
[BindNever] Organization organization,
[FromBody] RestartSubscriptionRequest request)
{
var (paymentMethod, billingAddress) = request.ToDomain();
var result = await updatePaymentMethodCommand.Run(organization, paymentMethod, null)
.AndThenAsync(_ => updateBillingAddressCommand.Run(organization, billingAddress))
.AndThenAsync(_ => restartSubscriptionCommand.Run(organization));
return Handle(result);
}
[Authorize<MemberOrProviderRequirement>] [Authorize<MemberOrProviderRequirement>]
[HttpGet("warnings")] [HttpGet("warnings")]
[InjectOrganization] [InjectOrganization]

View File

@@ -0,0 +1,31 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Organizations;
public record OrganizationSubscriptionPlanChangeRequest : IValidatableObject
{
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public ProductTierType Tier { get; set; }
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public PlanCadenceType Cadence { get; set; }
public OrganizationSubscriptionPlanChange ToDomain() => new()
{
Tier = Tier,
Cadence = Cadence
};
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (Tier == ProductTierType.Families && Cadence == PlanCadenceType.Monthly)
{
yield return new ValidationResult("Monthly billing cadence is not available for the Families plan.");
}
}
}

View File

@@ -0,0 +1,84 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Organizations;
public record OrganizationSubscriptionPurchaseRequest : IValidatableObject
{
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public ProductTierType Tier { get; set; }
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public PlanCadenceType Cadence { get; set; }
[Required]
public required PasswordManagerPurchaseSelections PasswordManager { get; set; }
public SecretsManagerPurchaseSelections? SecretsManager { get; set; }
public OrganizationSubscriptionPurchase ToDomain() => new()
{
Tier = Tier,
Cadence = Cadence,
PasswordManager = new OrganizationSubscriptionPurchase.PasswordManagerSelections
{
Seats = PasswordManager.Seats,
AdditionalStorage = PasswordManager.AdditionalStorage,
Sponsored = PasswordManager.Sponsored
},
SecretsManager = SecretsManager != null ? new OrganizationSubscriptionPurchase.SecretsManagerSelections
{
Seats = SecretsManager.Seats,
AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts,
Standalone = SecretsManager.Standalone
} : null
};
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (Tier != ProductTierType.Families)
{
yield break;
}
if (Cadence == PlanCadenceType.Monthly)
{
yield return new ValidationResult("Monthly cadence is not available on the Families plan.");
}
if (SecretsManager != null)
{
yield return new ValidationResult("Secrets Manager is not available on the Families plan.");
}
}
public record PasswordManagerPurchaseSelections
{
[Required]
[Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")]
public int Seats { get; set; }
[Required]
[Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")]
public int AdditionalStorage { get; set; }
public bool Sponsored { get; set; } = false;
}
public record SecretsManagerPurchaseSelections
{
[Required]
[Range(1, 100000, ErrorMessage = "Secrets Manager seats must be between 1 and 100,000")]
public int Seats { get; set; }
[Required]
[Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")]
public int AdditionalServiceAccounts { get; set; }
public bool Standalone { get; set; } = false;
}
}

View File

@@ -0,0 +1,48 @@
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Organizations;
public record OrganizationSubscriptionUpdateRequest
{
public PasswordManagerUpdateSelections? PasswordManager { get; set; }
public SecretsManagerUpdateSelections? SecretsManager { get; set; }
public OrganizationSubscriptionUpdate ToDomain() => new()
{
PasswordManager =
PasswordManager != null
? new OrganizationSubscriptionUpdate.PasswordManagerSelections
{
Seats = PasswordManager.Seats,
AdditionalStorage = PasswordManager.AdditionalStorage
}
: null,
SecretsManager =
SecretsManager != null
? new OrganizationSubscriptionUpdate.SecretsManagerSelections
{
Seats = SecretsManager.Seats,
AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts
}
: null
};
public record PasswordManagerUpdateSelections
{
[Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")]
public int? Seats { get; set; }
[Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")]
public int? AdditionalStorage { get; set; }
}
public record SecretsManagerUpdateSelections
{
[Range(0, 100000, ErrorMessage = "Secrets Manager seats must be between 0 and 100,000")]
public int? Seats { get; set; }
[Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")]
public int? AdditionalServiceAccounts { get; set; }
}
}

View File

@@ -1,5 +1,4 @@
#nullable enable using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment; namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests.Payment; namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment; namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment; namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Attributes;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
@@ -14,12 +13,9 @@ public class MinimalTokenizedPaymentMethodRequest
[Required] [Required]
public required string Token { get; set; } public required string Token { get; set; }
public TokenizedPaymentMethod ToDomain() public TokenizedPaymentMethod ToDomain() => new()
{
return new TokenizedPaymentMethod
{ {
Type = TokenizablePaymentMethodTypeExtensions.From(Type), Type = TokenizablePaymentMethodTypeExtensions.From(Type),
Token = Token Token = Token
}; };
}
} }

View File

@@ -1,31 +1,15 @@
#nullable enable using Bit.Core.Billing.Payment.Models;
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Attributes;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment; namespace Bit.Api.Billing.Models.Requests.Payment;
public class TokenizedPaymentMethodRequest public class TokenizedPaymentMethodRequest : MinimalTokenizedPaymentMethodRequest
{ {
[Required]
[PaymentMethodTypeValidation]
public required string Type { get; set; }
[Required]
public required string Token { get; set; }
public MinimalBillingAddressRequest? BillingAddress { get; set; } public MinimalBillingAddressRequest? BillingAddress { get; set; }
public (TokenizedPaymentMethod, BillingAddress?) ToDomain() public new (TokenizedPaymentMethod, BillingAddress?) ToDomain()
{ {
var paymentMethod = new TokenizedPaymentMethod var paymentMethod = base.ToDomain();
{
Type = TokenizablePaymentMethodTypeExtensions.From(Type),
Token = Token
};
var billingAddress = BillingAddress?.ToDomain(); var billingAddress = BillingAddress?.ToDomain();
return (paymentMethod, billingAddress); return (paymentMethod, billingAddress);
} }
} }

View File

@@ -1,5 +1,4 @@
#nullable enable using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Payment; using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;

View File

@@ -1,27 +0,0 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Enums;
namespace Bit.Api.Billing.Models.Requests;
public class PreviewTaxAmountForOrganizationTrialRequestBody
{
[Required]
public PlanType PlanType { get; set; }
[Required]
public ProductType ProductType { get; set; }
[Required] public TaxInformationDTO TaxInformation { get; set; } = null!;
public class TaxInformationDTO
{
[Required]
public string Country { get; set; } = null!;
[Required]
public string PostalCode { get; set; } = null!;
public string? TaxId { get; set; }
}
}

View File

@@ -0,0 +1,16 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Subscriptions;
public class RestartSubscriptionRequest
{
[Required]
public required MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; }
[Required]
public required CheckoutBillingAddressRequest BillingAddress { get; set; }
public (TokenizedPaymentMethod, BillingAddress) ToDomain()
=> (PaymentMethod.ToDomain(), BillingAddress.ToDomain());
}

View File

@@ -0,0 +1,19 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Organizations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public record PreviewOrganizationSubscriptionPlanChangeTaxRequest
{
[Required]
public required OrganizationSubscriptionPlanChangeRequest Plan { get; set; }
[Required]
public required CheckoutBillingAddressRequest BillingAddress { get; set; }
public (OrganizationSubscriptionPlanChange, BillingAddress) ToDomain() =>
(Plan.ToDomain(), BillingAddress.ToDomain());
}

View File

@@ -0,0 +1,19 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Organizations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public record PreviewOrganizationSubscriptionPurchaseTaxRequest
{
[Required]
public required OrganizationSubscriptionPurchaseRequest Purchase { get; set; }
[Required]
public required CheckoutBillingAddressRequest BillingAddress { get; set; }
public (OrganizationSubscriptionPurchase, BillingAddress) ToDomain() =>
(Purchase.ToDomain(), BillingAddress.ToDomain());
}

View File

@@ -0,0 +1,11 @@
using Bit.Api.Billing.Models.Requests.Organizations;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public class PreviewOrganizationSubscriptionUpdateTaxRequest
{
public required OrganizationSubscriptionUpdateRequest Update { get; set; }
public OrganizationSubscriptionUpdate ToDomain() => Update.ToDomain();
}

View File

@@ -0,0 +1,17 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public record PreviewPremiumSubscriptionPurchaseTaxRequest
{
[Required]
[Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB.")]
public short AdditionalStorage { get; set; }
[Required]
public required MinimalBillingAddressRequest BillingAddress { get; set; }
public (short, BillingAddress) ToDomain() => (AdditionalStorage, BillingAddress.ToDomain());
}

View File

@@ -3,7 +3,7 @@
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
@@ -11,8 +11,7 @@ namespace Bit.Core.AdminConsole.Services;
public interface IProviderService public interface IProviderService
{ {
Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress);
TokenizedPaymentSource tokenizedPaymentSource = null);
Task UpdateAsync(Provider provider, bool updateBilling = false); Task UpdateAsync(Provider provider, bool updateBilling = false);
Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite); Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite);

View File

@@ -3,7 +3,7 @@
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
@@ -11,7 +11,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations;
public class NoopProviderService : IProviderService public class NoopProviderService : IProviderService
{ {
public Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) => throw new NotImplementedException(); public Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) => throw new NotImplementedException();
public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException();

View File

@@ -1,5 +1,4 @@
#nullable enable using OneOf;
using OneOf;
namespace Bit.Core.Billing.Commands; namespace Bit.Core.Billing.Commands;
@@ -20,18 +19,38 @@ public record Unhandled(Exception? Exception = null, string Response = "Somethin
/// </remarks> /// </remarks>
/// </summary> /// </summary>
/// <typeparam name="T">The successful result type of the operation.</typeparam> /// <typeparam name="T">The successful result type of the operation.</typeparam>
public class BillingCommandResult<T> : OneOfBase<T, BadRequest, Conflict, Unhandled> public class BillingCommandResult<T>(OneOf<T, BadRequest, Conflict, Unhandled> input)
: OneOfBase<T, BadRequest, Conflict, Unhandled>(input)
{ {
private BillingCommandResult(OneOf<T, BadRequest, Conflict, Unhandled> input) : base(input) { }
public static implicit operator BillingCommandResult<T>(T output) => new(output); public static implicit operator BillingCommandResult<T>(T output) => new(output);
public static implicit operator BillingCommandResult<T>(BadRequest badRequest) => new(badRequest); public static implicit operator BillingCommandResult<T>(BadRequest badRequest) => new(badRequest);
public static implicit operator BillingCommandResult<T>(Conflict conflict) => new(conflict); public static implicit operator BillingCommandResult<T>(Conflict conflict) => new(conflict);
public static implicit operator BillingCommandResult<T>(Unhandled unhandled) => new(unhandled); public static implicit operator BillingCommandResult<T>(Unhandled unhandled) => new(unhandled);
public BillingCommandResult<TResult> Map<TResult>(Func<T, TResult> f)
=> Match(
value => new BillingCommandResult<TResult>(f(value)),
badRequest => new BillingCommandResult<TResult>(badRequest),
conflict => new BillingCommandResult<TResult>(conflict),
unhandled => new BillingCommandResult<TResult>(unhandled));
public Task TapAsync(Func<T, Task> f) => Match( public Task TapAsync(Func<T, Task> f) => Match(
f, f,
_ => Task.CompletedTask, _ => Task.CompletedTask,
_ => Task.CompletedTask, _ => Task.CompletedTask,
_ => Task.CompletedTask); _ => Task.CompletedTask);
} }
public static class BillingCommandResultExtensions
{
public static async Task<BillingCommandResult<TResult>> AndThenAsync<T, TResult>(
this Task<BillingCommandResult<T>> task, Func<T, Task<BillingCommandResult<TResult>>> binder)
{
var result = await task;
return await result.Match(
binder,
badRequest => Task.FromResult(new BillingCommandResult<TResult>(badRequest)),
conflict => Task.FromResult(new BillingCommandResult<TResult>(conflict)),
unhandled => Task.FromResult(new BillingCommandResult<TResult>(unhandled)));
}
}

View File

@@ -0,0 +1,7 @@
namespace Bit.Core.Billing.Enums;
public enum PlanCadenceType
{
Annually,
Monthly
}

View File

@@ -9,7 +9,7 @@ using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Tax.Commands; using Bit.Core.Billing.Subscriptions.Commands;
using Bit.Core.Billing.Tax.Services; using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Billing.Tax.Services.Implementations;
@@ -28,11 +28,12 @@ public static class ServiceCollectionExtensions
services.AddTransient<ISubscriberService, SubscriberService>(); services.AddTransient<ISubscriberService, SubscriberService>();
services.AddLicenseServices(); services.AddLicenseServices();
services.AddPricingClient(); services.AddPricingClient();
services.AddTransient<IPreviewTaxAmountCommand, PreviewTaxAmountCommand>();
services.AddPaymentOperations(); services.AddPaymentOperations();
services.AddOrganizationLicenseCommandsQueries(); services.AddOrganizationLicenseCommandsQueries();
services.AddPremiumCommands(); services.AddPremiumCommands();
services.AddTransient<IGetOrganizationWarningsQuery, GetOrganizationWarningsQuery>(); services.AddTransient<IGetOrganizationWarningsQuery, GetOrganizationWarningsQuery>();
services.AddTransient<IRestartSubscriptionCommand, RestartSubscriptionCommand>();
services.AddTransient<IPreviewOrganizationTaxCommand, PreviewOrganizationTaxCommand>();
} }
private static void AddOrganizationLicenseCommandsQueries(this IServiceCollection services) private static void AddOrganizationLicenseCommandsQueries(this IServiceCollection services)
@@ -46,5 +47,6 @@ public static class ServiceCollectionExtensions
{ {
services.AddScoped<ICreatePremiumCloudHostedSubscriptionCommand, CreatePremiumCloudHostedSubscriptionCommand>(); services.AddScoped<ICreatePremiumCloudHostedSubscriptionCommand, CreatePremiumCloudHostedSubscriptionCommand>();
services.AddScoped<ICreatePremiumSelfHostedSubscriptionCommand, CreatePremiumSelfHostedSubscriptionCommand>(); services.AddScoped<ICreatePremiumSelfHostedSubscriptionCommand, CreatePremiumSelfHostedSubscriptionCommand>();
services.AddTransient<IPreviewPremiumTaxCommand, PreviewPremiumTaxCommand>();
} }
} }

View File

@@ -0,0 +1,383 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using OneOf;
using Stripe;
namespace Bit.Core.Billing.Organizations.Commands;
using static Core.Constants;
using static StripeConstants;
public interface IPreviewOrganizationTaxCommand
{
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
OrganizationSubscriptionPurchase purchase,
BillingAddress billingAddress);
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionPlanChange planChange,
BillingAddress billingAddress);
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionUpdate update);
}
public class PreviewOrganizationTaxCommand(
ILogger<PreviewOrganizationTaxCommand> logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter)
: BaseBillingCommand<PreviewOrganizationTaxCommand>(logger), IPreviewOrganizationTaxCommand
{
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
OrganizationSubscriptionPurchase purchase,
BillingAddress billingAddress)
=> HandleAsync<(decimal, decimal)>(async () =>
{
var plan = await pricingClient.GetPlanOrThrow(purchase.PlanType);
var options = GetBaseOptions(billingAddress, purchase.Tier != ProductTierType.Families);
var items = new List<InvoiceSubscriptionDetailsItemOptions>();
switch (purchase)
{
case { PasswordManager.Sponsored: true }:
var sponsoredPlan = StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise);
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = sponsoredPlan.StripePlanId,
Quantity = 1
});
break;
case { SecretsManager.Standalone: true }:
items.AddRange([
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.PasswordManager.StripeSeatPlanId,
Quantity = purchase.PasswordManager.Seats
},
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = purchase.SecretsManager.Seats
}
]);
options.Coupon = CouponIDs.SecretsManagerStandalone;
break;
default:
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.HasNonSeatBasedPasswordManagerPlan()
? plan.PasswordManager.StripePlanId
: plan.PasswordManager.StripeSeatPlanId,
Quantity = purchase.PasswordManager.Seats
});
if (purchase.PasswordManager.AdditionalStorage > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.PasswordManager.StripeStoragePlanId,
Quantity = purchase.PasswordManager.AdditionalStorage
});
}
if (purchase.SecretsManager is { Seats: > 0 })
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = purchase.SecretsManager.Seats
});
if (purchase.SecretsManager.AdditionalServiceAccounts > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeServiceAccountPlanId,
Quantity = purchase.SecretsManager.AdditionalServiceAccounts
});
}
}
break;
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
});
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionPlanChange planChange,
BillingAddress billingAddress)
=> HandleAsync<(decimal, decimal)>(async () =>
{
if (organization.PlanType.GetProductTier() == ProductTierType.Free)
{
var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families);
var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType);
var items = new List<InvoiceSubscriptionDetailsItemOptions>
{
new ()
{
Price = newPlan.HasNonSeatBasedPasswordManagerPlan()
? newPlan.PasswordManager.StripePlanId
: newPlan.PasswordManager.StripeSeatPlanId,
Quantity = 2
}
};
if (organization.UseSecretsManager && planChange.Tier != ProductTierType.Families)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.SecretsManager.StripeSeatPlanId,
Quantity = 2
});
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
}
else
{
if (organization is not
{
GatewayCustomerId: not null,
GatewaySubscriptionId: not null
})
{
return new BadRequest("Organization does not have a subscription.");
}
var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families);
var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer"] });
if (subscription.Customer.Discount != null)
{
options.Coupon = subscription.Customer.Discount.Coupon.Id;
}
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType);
var subscriptionItemsByPriceId =
subscription.Items.ToDictionary(subscriptionItem => subscriptionItem.Price.Id);
var items = new List<InvoiceSubscriptionDetailsItemOptions>();
var passwordManagerSeats = subscriptionItemsByPriceId[
currentPlan.HasNonSeatBasedPasswordManagerPlan()
? currentPlan.PasswordManager.StripePlanId
: currentPlan.PasswordManager.StripeSeatPlanId];
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.HasNonSeatBasedPasswordManagerPlan()
? newPlan.PasswordManager.StripePlanId
: newPlan.PasswordManager.StripeSeatPlanId,
Quantity = passwordManagerSeats.Quantity
});
var hasStorage =
subscriptionItemsByPriceId.TryGetValue(newPlan.PasswordManager.StripeStoragePlanId,
out var storage);
if (hasStorage && storage != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.PasswordManager.StripeStoragePlanId,
Quantity = storage.Quantity
});
}
var hasSecretsManagerSeats = subscriptionItemsByPriceId.TryGetValue(
newPlan.SecretsManager.StripeSeatPlanId,
out var secretsManagerSeats);
if (hasSecretsManagerSeats && secretsManagerSeats != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.SecretsManager.StripeSeatPlanId,
Quantity = secretsManagerSeats.Quantity
});
var hasServiceAccounts =
subscriptionItemsByPriceId.TryGetValue(newPlan.SecretsManager.StripeServiceAccountPlanId,
out var serviceAccounts);
if (hasServiceAccounts && serviceAccounts != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.SecretsManager.StripeServiceAccountPlanId,
Quantity = serviceAccounts.Quantity
});
}
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
}
});
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionUpdate update)
=> HandleAsync<(decimal, decimal)>(async () =>
{
if (organization is not
{
GatewayCustomerId: not null,
GatewaySubscriptionId: not null
})
{
return new BadRequest("Organization does not have a subscription.");
}
var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer.tax_ids"] });
var options = GetBaseOptions(subscription.Customer,
organization.GetProductUsageType() == ProductUsageType.Business);
if (subscription.Customer.Discount != null)
{
options.Coupon = subscription.Customer.Discount.Coupon.Id;
}
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var items = new List<InvoiceSubscriptionDetailsItemOptions>();
if (update.PasswordManager?.Seats != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.HasNonSeatBasedPasswordManagerPlan()
? currentPlan.PasswordManager.StripePlanId
: currentPlan.PasswordManager.StripeSeatPlanId,
Quantity = update.PasswordManager.Seats
});
}
if (update.PasswordManager?.AdditionalStorage is > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.PasswordManager.StripeStoragePlanId,
Quantity = update.PasswordManager.AdditionalStorage
});
}
if (update.SecretsManager?.Seats is > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.SecretsManager.StripeSeatPlanId,
Quantity = update.SecretsManager.Seats
});
if (update.SecretsManager.AdditionalServiceAccounts is > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.SecretsManager.StripeServiceAccountPlanId,
Quantity = update.SecretsManager.AdditionalServiceAccounts
});
}
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
});
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
Convert.ToDecimal(invoice.Tax) / 100,
Convert.ToDecimal(invoice.Total) / 100);
private static InvoiceCreatePreviewOptions GetBaseOptions(
OneOf<Customer, BillingAddress> addressChoice,
bool businessUse)
{
var country = addressChoice.Match(
customer => customer.Address.Country,
billingAddress => billingAddress.Country
);
var postalCode = addressChoice.Match(
customer => customer.Address.PostalCode,
billingAddress => billingAddress.PostalCode);
var options = new InvoiceCreatePreviewOptions
{
AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true },
Currency = "usd",
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions { Country = country, PostalCode = postalCode },
TaxExempt = businessUse && country != CountryAbbreviations.UnitedStates
? TaxExempt.Reverse
: TaxExempt.None
}
};
var taxId = addressChoice.Match(
customer =>
{
var taxId = customer.TaxIds?.FirstOrDefault();
return taxId != null ? new TaxID(taxId.Type, taxId.Value) : null;
},
billingAddress => billingAddress.TaxId);
if (taxId == null)
{
return options;
}
options.CustomerDetails.TaxIds =
[
new InvoiceCustomerDetailsTaxIdOptions { Type = taxId.Code, Value = taxId.Value }
];
if (taxId.Code == TaxIdType.SpanishNIF)
{
options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions
{
Type = TaxIdType.EUVAT,
Value = $"ES{taxId.Value}"
});
}
return options;
}
}

View File

@@ -0,0 +1,23 @@
using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Organizations.Models;
public record OrganizationSubscriptionPlanChange
{
public ProductTierType Tier { get; init; }
public PlanCadenceType Cadence { get; init; }
public PlanType PlanType =>
// ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
Tier switch
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => Cadence == PlanCadenceType.Monthly
? PlanType.TeamsMonthly
: PlanType.TeamsAnnually,
ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly
? PlanType.EnterpriseMonthly
: PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException("Cannot change an Organization subscription to a tier that isn't Families, Teams or Enterprise.")
};
}

View File

@@ -0,0 +1,39 @@
using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Organizations.Models;
public record OrganizationSubscriptionPurchase
{
public ProductTierType Tier { get; init; }
public PlanCadenceType Cadence { get; init; }
public required PasswordManagerSelections PasswordManager { get; init; }
public SecretsManagerSelections? SecretsManager { get; init; }
public PlanType PlanType =>
// ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
Tier switch
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => Cadence == PlanCadenceType.Monthly
? PlanType.TeamsMonthly
: PlanType.TeamsAnnually,
ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly
? PlanType.EnterpriseMonthly
: PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException("Cannot purchase an Organization subscription that isn't Families, Teams or Enterprise.")
};
public record PasswordManagerSelections
{
public int Seats { get; init; }
public int AdditionalStorage { get; init; }
public bool Sponsored { get; init; }
}
public record SecretsManagerSelections
{
public int Seats { get; init; }
public int AdditionalServiceAccounts { get; init; }
public bool Standalone { get; init; }
}
}

View File

@@ -0,0 +1,19 @@
namespace Bit.Core.Billing.Organizations.Models;
public record OrganizationSubscriptionUpdate
{
public PasswordManagerSelections? PasswordManager { get; init; }
public SecretsManagerSelections? SecretsManager { get; init; }
public record PasswordManagerSelections
{
public int? Seats { get; init; }
public int? AdditionalStorage { get; init; }
}
public record SecretsManagerSelections
{
public int? Seats { get; init; }
public int? AdditionalServiceAccounts { get; init; }
}
}

View File

@@ -0,0 +1,65 @@
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Premium.Commands;
using static StripeConstants;
public interface IPreviewPremiumTaxCommand
{
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
int additionalStorage,
BillingAddress billingAddress);
}
public class PreviewPremiumTaxCommand(
ILogger<PreviewPremiumTaxCommand> logger,
IStripeAdapter stripeAdapter) : BaseBillingCommand<PreviewPremiumTaxCommand>(logger), IPreviewPremiumTaxCommand
{
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
int additionalStorage,
BillingAddress billingAddress)
=> HandleAsync<(decimal, decimal)>(async () =>
{
var options = new InvoiceCreatePreviewOptions
{
AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true },
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions
{
Country = billingAddress.Country,
PostalCode = billingAddress.PostalCode
}
},
Currency = "usd",
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items =
[
new InvoiceSubscriptionDetailsItemOptions { Price = Prices.PremiumAnnually, Quantity = 1 }
]
}
};
if (additionalStorage > 0)
{
options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = Prices.StoragePlanPersonal,
Quantity = additionalStorage
});
}
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
});
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
Convert.ToDecimal(invoice.Tax) / 100,
Convert.ToDecimal(invoice.Total) / 100);
}

View File

@@ -258,7 +258,7 @@ public class ProviderMigrator(
// Create dummy payment source for legacy migration - this migrator is deprecated and will be removed // Create dummy payment source for legacy migration - this migrator is deprecated and will be removed
var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token"); var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token");
var customer = await providerBillingService.SetupCustomer(provider, taxInfo, dummyPaymentSource); var customer = await providerBillingService.SetupCustomer(provider, null, null);
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{ {

View File

@@ -5,10 +5,10 @@ using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Models;
using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Models.Business;
using Stripe; using Stripe;
namespace Bit.Core.Billing.Providers.Services; namespace Bit.Core.Billing.Providers.Services;
@@ -79,16 +79,16 @@ public interface IProviderBillingService
int seatAdjustment); int seatAdjustment);
/// <summary> /// <summary>
/// For use during the provider setup process, this method creates a Stripe <see cref="Stripe.Customer"/> for the specified <paramref name="provider"/> utilizing the provided <paramref name="taxInfo"/>. /// For use during the provider setup process, this method creates a Stripe <see cref="Stripe.Customer"/> for the specified <paramref name="provider"/> utilizing the provided <paramref name="paymentMethod"/> and <paramref name="billingAddress"/>.
/// </summary> /// </summary>
/// <param name="provider">The <see cref="Provider"/> to create a Stripe customer for.</param> /// <param name="provider">The <see cref="Provider"/> to create a Stripe customer for.</param>
/// <param name="taxInfo">The <see cref="TaxInfo"/> to use for calculating the customer's automatic tax.</param> /// <param name="paymentMethod">The <see cref="TokenizedPaymentMethod"/> (e.g., Credit Card, Bank Account, or PayPal) to attach to the customer.</param>
/// <param name="tokenizedPaymentSource">The <see cref="TokenizedPaymentSource"/> (ex. Credit Card) to attach to the customer.</param> /// <param name="billingAddress">The <see cref="BillingAddress"/> containing the customer's billing information including address and tax ID details.</param>
/// <returns>The newly created <see cref="Stripe.Customer"/> for the <paramref name="provider"/>.</returns> /// <returns>The newly created <see cref="Stripe.Customer"/> for the <paramref name="provider"/>.</returns>
Task<Customer> SetupCustomer( Task<Customer> SetupCustomer(
Provider provider, Provider provider,
TaxInfo taxInfo, TokenizedPaymentMethod paymentMethod,
TokenizedPaymentSource tokenizedPaymentSource); BillingAddress billingAddress);
/// <summary> /// <summary>
/// For use during the provider setup process, this method starts a Stripe <see cref="Stripe.Subscription"/> for the given <paramref name="provider"/>. /// For use during the provider setup process, this method starts a Stripe <see cref="Stripe.Subscription"/> for the given <paramref name="provider"/>.

View File

@@ -0,0 +1,92 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using OneOf.Types;
using Stripe;
namespace Bit.Core.Billing.Subscriptions.Commands;
using static StripeConstants;
public interface IRestartSubscriptionCommand
{
Task<BillingCommandResult<None>> Run(
ISubscriber subscriber);
}
public class RestartSubscriptionCommand(
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserRepository userRepository) : IRestartSubscriptionCommand
{
public async Task<BillingCommandResult<None>> Run(
ISubscriber subscriber)
{
var existingSubscription = await subscriberService.GetSubscription(subscriber);
if (existingSubscription is not { Status: SubscriptionStatus.Canceled })
{
return new BadRequest("Cannot restart a subscription that is not canceled.");
}
var options = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
CollectionMethod = CollectionMethod.ChargeAutomatically,
Customer = existingSubscription.CustomerId,
Items = existingSubscription.Items.Select(subscriptionItem => new SubscriptionItemOptions
{
Price = subscriptionItem.Price.Id,
Quantity = subscriptionItem.Quantity
}).ToList(),
Metadata = existingSubscription.Metadata,
OffSession = true,
TrialPeriodDays = 0
};
var subscription = await stripeAdapter.SubscriptionCreateAsync(options);
await EnableAsync(subscriber, subscription);
return new None();
}
private async Task EnableAsync(ISubscriber subscriber, Subscription subscription)
{
switch (subscriber)
{
case Organization organization:
{
organization.GatewaySubscriptionId = subscription.Id;
organization.Enabled = true;
organization.ExpirationDate = subscription.CurrentPeriodEnd;
organization.RevisionDate = DateTime.UtcNow;
await organizationRepository.ReplaceAsync(organization);
break;
}
case Provider provider:
{
provider.GatewaySubscriptionId = subscription.Id;
provider.Enabled = true;
provider.RevisionDate = DateTime.UtcNow;
await providerRepository.ReplaceAsync(provider);
break;
}
case User user:
{
user.GatewaySubscriptionId = subscription.Id;
user.Premium = true;
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
user.RevisionDate = DateTime.UtcNow;
await userRepository.ReplaceAsync(user);
break;
}
}
}
}

View File

@@ -1,136 +0,0 @@
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Tax.Commands;
public interface IPreviewTaxAmountCommand
{
Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters);
}
public class PreviewTaxAmountCommand(
ILogger<PreviewTaxAmountCommand> logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
ITaxService taxService) : BaseBillingCommand<PreviewTaxAmountCommand>(logger), IPreviewTaxAmountCommand
{
protected override Conflict DefaultConflict
=> new("We had a problem calculating your tax obligation. Please contact support for assistance.");
public Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters)
=> HandleAsync<decimal>(async () =>
{
var (planType, productType, taxInformation) = parameters;
var plan = await pricingClient.GetPlanOrThrow(planType);
var options = new InvoiceCreatePreviewOptions
{
Currency = "usd",
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions
{
Country = taxInformation.Country,
PostalCode = taxInformation.PostalCode
}
},
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items =
[
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.HasNonSeatBasedPasswordManagerPlan()
? plan.PasswordManager.StripePlanId
: plan.PasswordManager.StripeSeatPlanId,
Quantity = 1
}
]
}
};
if (productType == ProductType.SecretsManager)
{
options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = 1
});
options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone;
}
if (!string.IsNullOrEmpty(taxInformation.TaxId))
{
var taxIdType = taxService.GetStripeTaxCode(
taxInformation.Country,
taxInformation.TaxId);
if (string.IsNullOrEmpty(taxIdType))
{
return new BadRequest(
"We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance.");
}
options.CustomerDetails.TaxIds =
[
new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = taxInformation.TaxId }
];
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
{
options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions
{
Type = StripeConstants.TaxIdType.EUVAT,
Value = $"ES{parameters.TaxInformation.TaxId}"
});
}
}
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
if (parameters.PlanType.IsBusinessProductTierType() &&
parameters.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates)
{
options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return Convert.ToDecimal(invoice.Tax) / 100;
});
}
#region Command Parameters
public record OrganizationTrialParameters
{
public required PlanType PlanType { get; set; }
public required ProductType ProductType { get; set; }
public required TaxInformationDTO TaxInformation { get; set; }
public void Deconstruct(
out PlanType planType,
out ProductType productType,
out TaxInformationDTO taxInformation)
{
planType = PlanType;
productType = ProductType;
taxInformation = TaxInformation;
}
public record TaxInformationDTO
{
public required string Country { get; set; }
public required string PostalCode { get; set; }
public string? TaxId { get; set; }
}
}
#endregion

View File

@@ -169,7 +169,6 @@ public static class FeatureFlagKeys
public const string PM17772_AdminInitiatedSponsorships = "pm-17772-admin-initiated-sponsorships"; public const string PM17772_AdminInitiatedSponsorships = "pm-17772-admin-initiated-sponsorships";
public const string UsePricingService = "use-pricing-service"; public const string UsePricingService = "use-pricing-service";
public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates";
public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout";
public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover"; public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover";
public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings"; public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings";
public const string PM23385_UseNewPremiumFlow = "pm-23385-use-new-premium-flow"; public const string PM23385_UseNewPremiumFlow = "pm-23385-use-new-premium-flow";

View File

@@ -3,6 +3,7 @@
using Bit.Core.Models.BitStripe; using Bit.Core.Models.BitStripe;
using Stripe; using Stripe;
using Stripe.Tax;
namespace Bit.Core.Services; namespace Bit.Core.Services;
@@ -23,6 +24,7 @@ public class StripeAdapter : IStripeAdapter
private readonly Stripe.TestHelpers.TestClockService _testClockService; private readonly Stripe.TestHelpers.TestClockService _testClockService;
private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; private readonly CustomerBalanceTransactionService _customerBalanceTransactionService;
private readonly Stripe.Tax.RegistrationService _taxRegistrationService; private readonly Stripe.Tax.RegistrationService _taxRegistrationService;
private readonly CalculationService _calculationService;
public StripeAdapter() public StripeAdapter()
{ {
@@ -41,6 +43,7 @@ public class StripeAdapter : IStripeAdapter
_testClockService = new Stripe.TestHelpers.TestClockService(); _testClockService = new Stripe.TestHelpers.TestClockService();
_customerBalanceTransactionService = new CustomerBalanceTransactionService(); _customerBalanceTransactionService = new CustomerBalanceTransactionService();
_taxRegistrationService = new Stripe.Tax.RegistrationService(); _taxRegistrationService = new Stripe.Tax.RegistrationService();
_calculationService = new CalculationService();
} }
public Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions options) public Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions options)

View File

@@ -0,0 +1,292 @@
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Constants.StripeConstants;
namespace Bit.Core.Test.Billing.Premium.Commands;
public class PreviewPremiumTaxCommandTests
{
private readonly ILogger<PreviewPremiumTaxCommand> _logger = Substitute.For<ILogger<PreviewPremiumTaxCommand>>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly PreviewPremiumTaxCommand _command;
public PreviewPremiumTaxCommandTests()
{
_command = new PreviewPremiumTaxCommand(_logger, _stripeAdapter);
}
[Fact]
public async Task Run_PremiumWithoutStorage_ReturnsCorrectTaxAmounts()
{
var billingAddress = new BillingAddress
{
Country = "US",
PostalCode = "12345"
};
var invoice = new Invoice
{
Tax = 300,
Total = 3300
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(0, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(3.00m, tax);
Assert.Equal(33.00m, total);
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == Prices.PremiumAnnually &&
options.SubscriptionDetails.Items[0].Quantity == 1));
}
[Fact]
public async Task Run_PremiumWithAdditionalStorage_ReturnsCorrectTaxAmounts()
{
var billingAddress = new BillingAddress
{
Country = "CA",
PostalCode = "K1A 0A6"
};
var invoice = new Invoice
{
Tax = 500,
Total = 5500
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(5, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(5.00m, tax);
Assert.Equal(55.00m, total);
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "CA" &&
options.CustomerDetails.Address.PostalCode == "K1A 0A6" &&
options.SubscriptionDetails.Items.Count == 2 &&
options.SubscriptionDetails.Items.Any(item =>
item.Price == Prices.PremiumAnnually && item.Quantity == 1) &&
options.SubscriptionDetails.Items.Any(item =>
item.Price == Prices.StoragePlanPersonal && item.Quantity == 5)));
}
[Fact]
public async Task Run_PremiumWithZeroStorage_ExcludesStorageFromItems()
{
var billingAddress = new BillingAddress
{
Country = "GB",
PostalCode = "SW1A 1AA"
};
var invoice = new Invoice
{
Tax = 250,
Total = 2750
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(0, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(2.50m, tax);
Assert.Equal(27.50m, total);
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "GB" &&
options.CustomerDetails.Address.PostalCode == "SW1A 1AA" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == Prices.PremiumAnnually &&
options.SubscriptionDetails.Items[0].Quantity == 1));
}
[Fact]
public async Task Run_PremiumWithLargeStorage_HandlesMultipleStorageUnits()
{
var billingAddress = new BillingAddress
{
Country = "DE",
PostalCode = "10115"
};
var invoice = new Invoice
{
Tax = 800,
Total = 8800
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(20, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(8.00m, tax);
Assert.Equal(88.00m, total);
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "DE" &&
options.CustomerDetails.Address.PostalCode == "10115" &&
options.SubscriptionDetails.Items.Count == 2 &&
options.SubscriptionDetails.Items.Any(item =>
item.Price == Prices.PremiumAnnually && item.Quantity == 1) &&
options.SubscriptionDetails.Items.Any(item =>
item.Price == Prices.StoragePlanPersonal && item.Quantity == 20)));
}
[Fact]
public async Task Run_PremiumInternationalAddress_UsesCorrectAddressInfo()
{
var billingAddress = new BillingAddress
{
Country = "AU",
PostalCode = "2000"
};
var invoice = new Invoice
{
Tax = 450,
Total = 4950
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(10, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(4.50m, tax);
Assert.Equal(49.50m, total);
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "AU" &&
options.CustomerDetails.Address.PostalCode == "2000" &&
options.SubscriptionDetails.Items.Count == 2 &&
options.SubscriptionDetails.Items.Any(item =>
item.Price == Prices.PremiumAnnually && item.Quantity == 1) &&
options.SubscriptionDetails.Items.Any(item =>
item.Price == Prices.StoragePlanPersonal && item.Quantity == 10)));
}
[Fact]
public async Task Run_PremiumNoTax_ReturnsZeroTax()
{
var billingAddress = new BillingAddress
{
Country = "US",
PostalCode = "97330" // Example of a tax-free jurisdiction
};
var invoice = new Invoice
{
Tax = 0,
Total = 3000
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(0, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(0.00m, tax);
Assert.Equal(30.00m, total);
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "97330" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == Prices.PremiumAnnually &&
options.SubscriptionDetails.Items[0].Quantity == 1));
}
[Fact]
public async Task Run_NegativeStorage_TreatedAsZero()
{
var billingAddress = new BillingAddress
{
Country = "FR",
PostalCode = "75001"
};
var invoice = new Invoice
{
Tax = 600,
Total = 6600
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(-5, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(6.00m, tax);
Assert.Equal(66.00m, total);
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "FR" &&
options.CustomerDetails.Address.PostalCode == "75001" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == Prices.PremiumAnnually &&
options.SubscriptionDetails.Items[0].Quantity == 1));
}
[Fact]
public async Task Run_AmountConversion_CorrectlyConvertsStripeAmounts()
{
var billingAddress = new BillingAddress
{
Country = "US",
PostalCode = "12345"
};
// Stripe amounts are in cents
var invoice = new Invoice
{
Tax = 123, // $1.23
Total = 3123 // $31.23
};
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>()).Returns(invoice);
var result = await _command.Run(0, billingAddress);
Assert.True(result.IsT0);
var (tax, total) = result.AsT0;
Assert.Equal(1.23m, tax);
Assert.Equal(31.23m, total);
}
}

View File

@@ -0,0 +1,198 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Commands;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using NSubstitute;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Billing.Subscriptions;
using static StripeConstants;
public class RestartSubscriptionCommandTests
{
private readonly IOrganizationRepository _organizationRepository = Substitute.For<IOrganizationRepository>();
private readonly IProviderRepository _providerRepository = Substitute.For<IProviderRepository>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
private readonly IUserRepository _userRepository = Substitute.For<IUserRepository>();
private readonly RestartSubscriptionCommand _command;
public RestartSubscriptionCommandTests()
{
_command = new RestartSubscriptionCommand(
_organizationRepository,
_providerRepository,
_stripeAdapter,
_subscriberService,
_userRepository);
}
[Fact]
public async Task Run_SubscriptionNotCanceled_ReturnsBadRequest()
{
var organization = new Organization { Id = Guid.NewGuid() };
var subscription = new Subscription { Status = SubscriptionStatus.Active };
_subscriberService.GetSubscription(organization).Returns(subscription);
var result = await _command.Run(organization);
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Cannot restart a subscription that is not canceled.", badRequest.Response);
}
[Fact]
public async Task Run_NoExistingSubscription_ReturnsBadRequest()
{
var organization = new Organization { Id = Guid.NewGuid() };
_subscriberService.GetSubscription(organization).Returns((Subscription)null);
var result = await _command.Run(organization);
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Cannot restart a subscription that is not canceled.", badRequest.Response);
}
[Fact]
public async Task Run_Organization_Success_ReturnsNone()
{
var organizationId = Guid.NewGuid();
var organization = new Organization { Id = organizationId };
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var existingSubscription = new Subscription
{
Status = SubscriptionStatus.Canceled,
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = "price_1" }, Quantity = 1 },
new SubscriptionItem { Price = new Price { Id = "price_2" }, Quantity = 2 }
]
},
Metadata = new Dictionary<string, string> { ["key"] = "value" }
};
var newSubscription = new Subscription
{
Id = "sub_new",
CurrentPeriodEnd = currentPeriodEnd
};
_subscriberService.GetSubscription(organization).Returns(existingSubscription);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(newSubscription);
var result = await _command.Run(organization);
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is((SubscriptionCreateOptions options) =>
options.AutomaticTax.Enabled == true &&
options.CollectionMethod == CollectionMethod.ChargeAutomatically &&
options.Customer == "cus_123" &&
options.Items.Count == 2 &&
options.Items[0].Price == "price_1" &&
options.Items[0].Quantity == 1 &&
options.Items[1].Price == "price_2" &&
options.Items[1].Quantity == 2 &&
options.Metadata["key"] == "value" &&
options.OffSession == true &&
options.TrialPeriodDays == 0));
await _organizationRepository.Received(1).ReplaceAsync(Arg.Is<Organization>(org =>
org.Id == organizationId &&
org.GatewaySubscriptionId == "sub_new" &&
org.Enabled == true &&
org.ExpirationDate == currentPeriodEnd));
}
[Fact]
public async Task Run_Provider_Success_ReturnsNone()
{
var providerId = Guid.NewGuid();
var provider = new Provider { Id = providerId };
var existingSubscription = new Subscription
{
Status = SubscriptionStatus.Canceled,
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>
{
Data = [new SubscriptionItem { Price = new Price { Id = "price_1" }, Quantity = 1 }]
},
Metadata = new Dictionary<string, string>()
};
var newSubscription = new Subscription
{
Id = "sub_new",
CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1)
};
_subscriberService.GetSubscription(provider).Returns(existingSubscription);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(newSubscription);
var result = await _command.Run(provider);
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
await _providerRepository.Received(1).ReplaceAsync(Arg.Is<Provider>(prov =>
prov.Id == providerId &&
prov.GatewaySubscriptionId == "sub_new" &&
prov.Enabled == true));
}
[Fact]
public async Task Run_User_Success_ReturnsNone()
{
var userId = Guid.NewGuid();
var user = new User { Id = userId };
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var existingSubscription = new Subscription
{
Status = SubscriptionStatus.Canceled,
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>
{
Data = [new SubscriptionItem { Price = new Price { Id = "price_1" }, Quantity = 1 }]
},
Metadata = new Dictionary<string, string>()
};
var newSubscription = new Subscription
{
Id = "sub_new",
CurrentPeriodEnd = currentPeriodEnd
};
_subscriberService.GetSubscription(user).Returns(existingSubscription);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(newSubscription);
var result = await _command.Run(user);
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
await _userRepository.Received(1).ReplaceAsync(Arg.Is<User>(u =>
u.Id == userId &&
u.GatewaySubscriptionId == "sub_new" &&
u.Premium == true &&
u.PremiumExpirationDate == currentPeriodEnd));
}
}

View File

@@ -1,541 +0,0 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Commands;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Tax.Commands.OrganizationTrialParameters;
namespace Bit.Core.Test.Billing.Tax.Commands;
public class PreviewTaxAmountCommandTests
{
private readonly ILogger<PreviewTaxAmountCommand> _logger = Substitute.For<ILogger<PreviewTaxAmountCommand>>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly ITaxService _taxService = Substitute.For<ITaxService>();
private readonly PreviewTaxAmountCommand _command;
public PreviewTaxAmountCommandTests()
{
_command = new PreviewTaxAmountCommand(_logger, _pricingClient, _stripeAdapter, _taxService);
}
[Fact]
public async Task Run_WithSeatBasedPasswordManagerPlan_GetsTaxAmount()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "US",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == plan.PasswordManager.StripeSeatPlanId &&
options.SubscriptionDetails.Items[0].Quantity == 1 &&
options.AutomaticTax.Enabled == true
))
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
Assert.True(result.IsT0);
var taxAmount = result.AsT0;
Assert.Equal(expectedInvoice.Tax, (long)taxAmount * 100);
}
[Fact]
public async Task Run_WithNonSeatBasedPasswordManagerPlan_GetsTaxAmount()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.FamiliesAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "US",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == plan.PasswordManager.StripePlanId &&
options.SubscriptionDetails.Items[0].Quantity == 1 &&
options.AutomaticTax.Enabled == true
))
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
Assert.True(result.IsT0);
var taxAmount = result.AsT0;
Assert.Equal(expectedInvoice.Tax, (long)taxAmount * 100);
}
[Fact]
public async Task Run_WithSecretsManagerPlan_GetsTaxAmount()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.SecretsManager,
TaxInformation = new TaxInformationDTO
{
Country = "US",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.Items.Count == 2 &&
options.SubscriptionDetails.Items[0].Price == plan.PasswordManager.StripeSeatPlanId &&
options.SubscriptionDetails.Items[0].Quantity == 1 &&
options.SubscriptionDetails.Items[1].Price == plan.SecretsManager.StripeSeatPlanId &&
options.SubscriptionDetails.Items[1].Quantity == 1 &&
options.Coupon == StripeConstants.CouponIDs.SecretsManagerStandalone &&
options.AutomaticTax.Enabled == true
))
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
Assert.True(result.IsT0);
var taxAmount = result.AsT0;
Assert.Equal(expectedInvoice.Tax, (long)taxAmount * 100);
}
[Fact]
public async Task Run_NonUSWithoutTaxId_GetsTaxAmount()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "CA",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "CA" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == plan.PasswordManager.StripeSeatPlanId &&
options.SubscriptionDetails.Items[0].Quantity == 1 &&
options.AutomaticTax.Enabled == true
))
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
Assert.True(result.IsT0);
var taxAmount = result.AsT0;
Assert.Equal(expectedInvoice.Tax, (long)taxAmount * 100);
}
[Fact]
public async Task Run_NonUSWithTaxId_GetsTaxAmount()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "CA",
PostalCode = "12345",
TaxId = "123456789"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
_taxService.GetStripeTaxCode(parameters.TaxInformation.Country, parameters.TaxInformation.TaxId)
.Returns("ca_st");
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.Currency == "usd" &&
options.CustomerDetails.Address.Country == "CA" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.CustomerDetails.TaxIds.Count == 1 &&
options.CustomerDetails.TaxIds[0].Type == "ca_st" &&
options.CustomerDetails.TaxIds[0].Value == "123456789" &&
options.SubscriptionDetails.Items.Count == 1 &&
options.SubscriptionDetails.Items[0].Price == plan.PasswordManager.StripeSeatPlanId &&
options.SubscriptionDetails.Items[0].Quantity == 1 &&
options.AutomaticTax.Enabled == true
))
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
Assert.True(result.IsT0);
var taxAmount = result.AsT0;
Assert.Equal(expectedInvoice.Tax, (long)taxAmount * 100);
}
[Fact]
public async Task Run_NonUSWithTaxId_UnknownTaxIdType_BadRequest()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "CA",
PostalCode = "12345",
TaxId = "123456789"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
_taxService.GetStripeTaxCode(parameters.TaxInformation.Country, parameters.TaxInformation.TaxId)
.Returns((string)null);
// Act
var result = await _command.Run(parameters);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance.", badRequest.Response);
}
[Fact]
public async Task Run_USBased_PersonalUse_SetsAutomaticTaxEnabled()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.FamiliesAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "US",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true
));
Assert.True(result.IsT0);
}
[Fact]
public async Task Run_USBased_BusinessUse_SetsAutomaticTaxEnabled()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "US",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true
));
Assert.True(result.IsT0);
}
[Fact]
public async Task Run_NonUSBased_PersonalUse_SetsAutomaticTaxEnabled()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.FamiliesAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "CA",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true
));
Assert.True(result.IsT0);
}
[Fact]
public async Task Run_NonUSBased_BusinessUse_SetsAutomaticTaxEnabled()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "CA",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true
));
Assert.True(result.IsT0);
}
[Fact]
public async Task Run_USBased_PersonalUse_DoesNotSetTaxExempt()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.FamiliesAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "US",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.CustomerDetails.TaxExempt == null
));
Assert.True(result.IsT0);
}
[Fact]
public async Task Run_USBased_BusinessUse_DoesNotSetTaxExempt()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "US",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.CustomerDetails.TaxExempt == null
));
Assert.True(result.IsT0);
}
[Fact]
public async Task Run_NonUSBased_PersonalUse_DoesNotSetTaxExempt()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.FamiliesAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "CA",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.CustomerDetails.TaxExempt == null
));
Assert.True(result.IsT0);
}
[Fact]
public async Task Run_NonUSBased_BusinessUse_SetsTaxExemptReverse()
{
// Arrange
var parameters = new OrganizationTrialParameters
{
PlanType = PlanType.EnterpriseAnnually,
ProductType = ProductType.PasswordManager,
TaxInformation = new TaxInformationDTO
{
Country = "CA",
PostalCode = "12345"
}
};
var plan = StaticStore.GetPlan(parameters.PlanType);
_pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan);
var expectedInvoice = new Invoice { Tax = 1000 }; // $10.00 in cents
_stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(expectedInvoice);
// Act
var result = await _command.Run(parameters);
// Assert
await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.CustomerDetails.TaxExempt == StripeConstants.TaxExempt.Reverse
));
Assert.True(result.IsT0);
}
}