1
0
mirror of https://github.com/bitwarden/server synced 2026-01-04 17:43:53 +00:00

[AC-1910] Allocate seats to a provider organization (#3936)

* Add endpoint to update a provider organization's seats for consolidated billing.

* Fixed failing tests
This commit is contained in:
Alex Morask
2024-03-29 11:18:10 -04:00
committed by GitHub
parent c53e5eeab3
commit e2cb406a95
28 changed files with 1108 additions and 68 deletions

View File

@@ -0,0 +1,12 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
namespace Bit.Core.Billing.Commands;
public interface IAssignSeatsToClientOrganizationCommand
{
Task AssignSeatsToClientOrganization(
Provider provider,
Organization organization,
int seats);
}

View File

@@ -0,0 +1,174 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Queries;
using Bit.Core.Billing.Repositories;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using static Bit.Core.Billing.Utilities;
namespace Bit.Core.Billing.Commands.Implementations;
public class AssignSeatsToClientOrganizationCommand(
ILogger<AssignSeatsToClientOrganizationCommand> logger,
IOrganizationRepository organizationRepository,
IPaymentService paymentService,
IProviderBillingQueries providerBillingQueries,
IProviderPlanRepository providerPlanRepository) : IAssignSeatsToClientOrganizationCommand
{
public async Task AssignSeatsToClientOrganization(
Provider provider,
Organization organization,
int seats)
{
ArgumentNullException.ThrowIfNull(provider);
ArgumentNullException.ThrowIfNull(organization);
if (provider.Type == ProviderType.Reseller)
{
logger.LogError("Reseller-type provider ({ID}) cannot assign seats to client organizations", provider.Id);
throw ContactSupport("Consolidated billing does not support reseller-type providers");
}
if (seats < 0)
{
throw new BillingException(
"You cannot assign negative seats to a client.",
"MSP cannot assign negative seats to a client organization");
}
if (seats == organization.Seats)
{
logger.LogWarning("Client organization ({ID}) already has {Seats} seats assigned", organization.Id, organization.Seats);
return;
}
var providerPlan = await GetProviderPlanAsync(provider, organization);
var providerSeatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0);
// How many seats the provider has assigned to all their client organizations that have the specified plan type.
var providerCurrentlyAssignedSeatTotal = await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType);
// How many seats are being added to or subtracted from this client organization.
var seatDifference = seats - (organization.Seats ?? 0);
// How many seats the provider will have assigned to all of their client organizations after the update.
var providerNewlyAssignedSeatTotal = providerCurrentlyAssignedSeatTotal + seatDifference;
var update = CurryUpdateFunction(
provider,
providerPlan,
organization,
seats,
providerNewlyAssignedSeatTotal);
/*
* Below the limit => Below the limit:
* No subscription update required. We can safely update the organization's seats.
*/
if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum &&
providerNewlyAssignedSeatTotal <= providerSeatMinimum)
{
organization.Seats = seats;
await organizationRepository.ReplaceAsync(organization);
providerPlan.AllocatedSeats = providerNewlyAssignedSeatTotal;
await providerPlanRepository.ReplaceAsync(providerPlan);
}
/*
* Below the limit => Above the limit:
* We have to scale the subscription up from the seat minimum to the newly assigned seat total.
*/
else if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum &&
providerNewlyAssignedSeatTotal > providerSeatMinimum)
{
await update(
providerSeatMinimum,
providerNewlyAssignedSeatTotal);
}
/*
* Above the limit => Above the limit:
* We have to scale the subscription from the currently assigned seat total to the newly assigned seat total.
*/
else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum &&
providerNewlyAssignedSeatTotal > providerSeatMinimum)
{
await update(
providerCurrentlyAssignedSeatTotal,
providerNewlyAssignedSeatTotal);
}
/*
* Above the limit => Below the limit:
* We have to scale the subscription down from the currently assigned seat total to the seat minimum.
*/
else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum &&
providerNewlyAssignedSeatTotal <= providerSeatMinimum)
{
await update(
providerCurrentlyAssignedSeatTotal,
providerSeatMinimum);
}
}
// ReSharper disable once SuggestBaseTypeForParameter
private async Task<ProviderPlan> GetProviderPlanAsync(Provider provider, Organization organization)
{
if (!organization.PlanType.SupportsConsolidatedBilling())
{
logger.LogError("Cannot assign seats to a client organization ({ID}) with a plan type that does not support consolidated billing: {PlanType}", organization.Id, organization.PlanType);
throw ContactSupport();
}
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == organization.PlanType);
if (providerPlan != null && providerPlan.IsConfigured())
{
return providerPlan;
}
logger.LogError("Cannot assign seats to client organization ({ClientOrganizationID}) when provider's ({ProviderID}) matching plan is not configured", organization.Id, provider.Id);
throw ContactSupport();
}
private Func<int, int, Task> CurryUpdateFunction(
Provider provider,
ProviderPlan providerPlan,
Organization organization,
int organizationNewlyAssignedSeats,
int providerNewlyAssignedSeats) => async (providerCurrentlySubscribedSeats, providerNewlySubscribedSeats) =>
{
var plan = StaticStore.GetPlan(providerPlan.PlanType);
await paymentService.AdjustSeats(
provider,
plan,
providerCurrentlySubscribedSeats,
providerNewlySubscribedSeats);
organization.Seats = organizationNewlyAssignedSeats;
await organizationRepository.ReplaceAsync(organization);
var providerNewlyPurchasedSeats = providerNewlySubscribedSeats > providerPlan.SeatMinimum
? providerNewlySubscribedSeats - providerPlan.SeatMinimum
: 0;
providerPlan.PurchasedSeats = providerNewlyPurchasedSeats;
providerPlan.AllocatedSeats = providerNewlyAssignedSeats;
await providerPlanRepository.ReplaceAsync(providerPlan);
};
}

View File

@@ -11,6 +11,7 @@ public class ProviderPlan : ITableObject<Guid>
public PlanType PlanType { get; set; }
public int? SeatMinimum { get; set; }
public int? PurchasedSeats { get; set; }
public int? AllocatedSeats { get; set; }
public void SetNewId()
{
@@ -20,5 +21,5 @@ public class ProviderPlan : ITableObject<Guid>
}
}
public bool Configured => SeatMinimum.HasValue && PurchasedSeats.HasValue;
public bool IsConfigured() => SeatMinimum.HasValue && PurchasedSeats.HasValue && AllocatedSeats.HasValue;
}

View File

@@ -0,0 +1,9 @@
using Bit.Core.Enums;
namespace Bit.Core.Billing.Extensions;
public static class BillingExtensions
{
public static bool SupportsConsolidatedBilling(this PlanType planType)
=> planType is PlanType.TeamsMonthly or PlanType.EnterpriseMonthly;
}

View File

@@ -9,15 +9,15 @@ using Microsoft.Extensions.DependencyInjection;
public static class ServiceCollectionExtensions
{
public static void AddBillingCommands(this IServiceCollection services)
public static void AddBillingOperations(this IServiceCollection services)
{
services.AddSingleton<ICancelSubscriptionCommand, CancelSubscriptionCommand>();
services.AddSingleton<IRemovePaymentMethodCommand, RemovePaymentMethodCommand>();
}
// Queries
services.AddTransient<IProviderBillingQueries, ProviderBillingQueries>();
services.AddTransient<ISubscriberQueries, SubscriberQueries>();
public static void AddBillingQueries(this IServiceCollection services)
{
services.AddSingleton<IProviderBillingQueries, ProviderBillingQueries>();
services.AddSingleton<ISubscriberQueries, SubscriberQueries>();
// Commands
services.AddTransient<IAssignSeatsToClientOrganizationCommand, AssignSeatsToClientOrganizationCommand>();
services.AddTransient<ICancelSubscriptionCommand, CancelSubscriptionCommand>();
services.AddTransient<IRemovePaymentMethodCommand, RemovePaymentMethodCommand>();
}
}

View File

@@ -8,15 +8,17 @@ public record ConfiguredProviderPlan(
Guid ProviderId,
PlanType PlanType,
int SeatMinimum,
int PurchasedSeats)
int PurchasedSeats,
int AssignedSeats)
{
public static ConfiguredProviderPlan From(ProviderPlan providerPlan) =>
providerPlan.Configured
providerPlan.IsConfigured()
? new ConfiguredProviderPlan(
providerPlan.Id,
providerPlan.ProviderId,
providerPlan.PlanType,
providerPlan.SeatMinimum.GetValueOrDefault(0),
providerPlan.PurchasedSeats.GetValueOrDefault(0))
providerPlan.PurchasedSeats.GetValueOrDefault(0),
providerPlan.AllocatedSeats.GetValueOrDefault(0))
: null;
}

View File

@@ -1,9 +1,22 @@
using Bit.Core.Billing.Models;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Models;
using Bit.Core.Enums;
namespace Bit.Core.Billing.Queries;
public interface IProviderBillingQueries
{
/// <summary>
/// Retrieves the number of seats an MSP has assigned to its client organizations with a specified <paramref name="planType"/>.
/// </summary>
/// <param name="providerId">The ID of the MSP to retrieve the assigned seat total for.</param>
/// <param name="planType">The type of plan to retrieve the assigned seat total for.</param>
/// <returns>An <see cref="int"/> representing the number of seats the provider has assigned to its client organizations with the specified <paramref name="planType"/>.</returns>
/// <exception cref="BillingException">Thrown when the provider represented by the <paramref name="providerId"/> is <see langword="null"/>.</exception>
/// <exception cref="BillingException">Thrown when the provider represented by the <paramref name="providerId"/> has <see cref="Provider.Type"/> <see cref="ProviderType.Reseller"/>.</exception>
Task<int> GetAssignedSeatTotalForPlanOrThrow(Guid providerId, PlanType planType);
/// <summary>
/// Retrieves a provider's billing subscription data.
/// </summary>

View File

@@ -1,17 +1,53 @@
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories;
using Bit.Core.Enums;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using Stripe;
using static Bit.Core.Billing.Utilities;
namespace Bit.Core.Billing.Queries.Implementations;
public class ProviderBillingQueries(
ILogger<ProviderBillingQueries> logger,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderPlanRepository providerPlanRepository,
IProviderRepository providerRepository,
ISubscriberQueries subscriberQueries) : IProviderBillingQueries
{
public async Task<int> GetAssignedSeatTotalForPlanOrThrow(
Guid providerId,
PlanType planType)
{
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
logger.LogError(
"Could not find provider ({ID}) when retrieving assigned seat total",
providerId);
throw ContactSupport();
}
if (provider.Type == ProviderType.Reseller)
{
logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId);
throw ContactSupport("Consolidated billing does not support reseller-type providers");
}
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId);
var plan = StaticStore.GetPlan(planType);
return providerOrganizations
.Where(providerOrganization => providerOrganization.Plan == plan.Name)
.Sum(providerOrganization => providerOrganization.Seats ?? 0);
}
public async Task<ProviderSubscriptionData> GetSubscriptionData(Guid providerId)
{
var provider = await providerRepository.GetByIdAsync(providerId);
@@ -25,6 +61,13 @@ public class ProviderBillingQueries(
return null;
}
if (provider.Type == ProviderType.Reseller)
{
logger.LogError("Subscription data cannot be retrieved for reseller-type provider ({ID})", providerId);
throw ContactSupport("Consolidated billing does not support reseller-type providers");
}
var subscription = await subscriberQueries.GetSubscription(provider, new SubscriptionGetOptions
{
Expand = ["customer"]
@@ -38,7 +81,7 @@ public class ProviderBillingQueries(
var providerPlans = await providerPlanRepository.GetByProviderId(providerId);
var configuredProviderPlans = providerPlans
.Where(providerPlan => providerPlan.Configured)
.Where(providerPlan => providerPlan.IsConfigured())
.Select(ConfiguredProviderPlan.From)
.ToList();