1
0
mirror of https://github.com/bitwarden/server synced 2026-01-07 19:13:50 +00:00

Merge branch 'main' into auth/pm-22975/client-version-validator

This commit is contained in:
Patrick-Pimentel-Bitwarden
2025-12-17 14:34:11 -05:00
committed by GitHub
83 changed files with 13694 additions and 281 deletions

View File

@@ -796,6 +796,44 @@ public class ProviderBillingService(
}
}
public async Task UpdateProviderNameAndEmail(Provider provider)
{
if (string.IsNullOrWhiteSpace(provider.GatewayCustomerId))
{
logger.LogWarning(
"Provider ({ProviderId}) has no Stripe customer to update",
provider.Id);
return;
}
var newDisplayName = provider.DisplayName();
// Provider.DisplayName() can return null - handle gracefully
if (string.IsNullOrWhiteSpace(newDisplayName))
{
logger.LogWarning(
"Provider ({ProviderId}) has no name to update in Stripe",
provider.Id);
return;
}
await stripeAdapter.UpdateCustomerAsync(provider.GatewayCustomerId,
new CustomerUpdateOptions
{
Email = provider.BillingEmail,
Description = newDisplayName,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields = [
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = newDisplayName
}]
},
});
}
private Func<int, Task> CurrySeatScalingUpdate(
Provider provider,
ProviderPlan providerPlan,

View File

@@ -2150,4 +2150,151 @@ public class ProviderBillingServiceTests
}
#endregion
#region UpdateProviderNameAndEmail
[Theory, BitAutoData]
public async Task UpdateProviderNameAndEmail_NullGatewayCustomerId_LogsWarningAndReturns(
Provider provider,
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.GatewayCustomerId = null;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
[Theory, BitAutoData]
public async Task UpdateProviderNameAndEmail_EmptyGatewayCustomerId_LogsWarningAndReturns(
Provider provider,
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.GatewayCustomerId = "";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
[Theory, BitAutoData]
public async Task UpdateProviderNameAndEmail_NullProviderName_LogsWarningAndReturns(
Provider provider,
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Name = null;
provider.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
[Theory, BitAutoData]
public async Task UpdateProviderNameAndEmail_EmptyProviderName_LogsWarningAndReturns(
Provider provider,
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Name = "";
provider.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
[Theory, BitAutoData]
public async Task UpdateProviderNameAndEmail_ValidProvider_CallsStripeWithCorrectParameters(
Provider provider,
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Name = "Test Provider";
provider.BillingEmail = "billing@test.com";
provider.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
// Assert
await stripeAdapter.Received(1).UpdateCustomerAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerUpdateOptions>(options =>
options.Email == provider.BillingEmail &&
options.Description == provider.Name &&
options.InvoiceSettings.CustomFields.Count == 1 &&
options.InvoiceSettings.CustomFields[0].Name == "Provider" &&
options.InvoiceSettings.CustomFields[0].Value == provider.Name));
}
[Theory, BitAutoData]
public async Task UpdateProviderNameAndEmail_LongProviderName_UsesFullName(
Provider provider,
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
var longName = new string('A', 50); // 50 characters
provider.Name = longName;
provider.BillingEmail = "billing@test.com";
provider.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
// Assert
await stripeAdapter.Received(1).UpdateCustomerAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerUpdateOptions>(options =>
options.InvoiceSettings.CustomFields[0].Value == longName));
}
[Theory, BitAutoData]
public async Task UpdateProviderNameAndEmail_NullBillingEmail_UpdatesWithNull(
Provider provider,
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Name = "Test Provider";
provider.BillingEmail = null;
provider.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
// Assert
await stripeAdapter.Received(1).UpdateCustomerAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerUpdateOptions>(options =>
options.Email == null &&
options.Description == provider.Name));
}
#endregion
}

View File

@@ -99,7 +99,7 @@ services:
- idp
rabbitmq:
image: rabbitmq:4.1.3-management
image: rabbitmq:4.2.0-management
ports:
- "5672:5672"
- "15672:15672"

View File

@@ -14,6 +14,7 @@ using Bit.Core.AdminConsole.Providers.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services;
@@ -57,6 +58,7 @@ public class OrganizationsController : Controller
private readonly IOrganizationInitiateDeleteCommand _organizationInitiateDeleteCommand;
private readonly IPricingClient _pricingClient;
private readonly IResendOrganizationInviteCommand _resendOrganizationInviteCommand;
private readonly IOrganizationBillingService _organizationBillingService;
public OrganizationsController(
IOrganizationRepository organizationRepository,
@@ -81,7 +83,8 @@ public class OrganizationsController : Controller
IProviderBillingService providerBillingService,
IOrganizationInitiateDeleteCommand organizationInitiateDeleteCommand,
IPricingClient pricingClient,
IResendOrganizationInviteCommand resendOrganizationInviteCommand)
IResendOrganizationInviteCommand resendOrganizationInviteCommand,
IOrganizationBillingService organizationBillingService)
{
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
@@ -106,6 +109,7 @@ public class OrganizationsController : Controller
_organizationInitiateDeleteCommand = organizationInitiateDeleteCommand;
_pricingClient = pricingClient;
_resendOrganizationInviteCommand = resendOrganizationInviteCommand;
_organizationBillingService = organizationBillingService;
}
[RequirePermission(Permission.Org_List_View)]
@@ -242,6 +246,8 @@ public class OrganizationsController : Controller
var existingOrganizationData = new Organization
{
Id = organization.Id,
Name = organization.Name,
BillingEmail = organization.BillingEmail,
Status = organization.Status,
PlanType = organization.PlanType,
Seats = organization.Seats
@@ -287,6 +293,22 @@ public class OrganizationsController : Controller
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
// Sync name/email changes to Stripe
if (existingOrganizationData.Name != organization.Name || existingOrganizationData.BillingEmail != organization.BillingEmail)
{
try
{
await _organizationBillingService.UpdateOrganizationNameAndEmail(organization);
}
catch (Exception ex)
{
_logger.LogError(ex,
"Failed to update Stripe customer for organization {OrganizationId}. Database was updated successfully.",
organization.Id);
TempData["Warning"] = "Organization updated successfully, but Stripe customer name/email synchronization failed.";
}
}
return RedirectToAction("Edit", new { id });
}

View File

@@ -56,6 +56,7 @@ public class ProvidersController : Controller
private readonly IStripeAdapter _stripeAdapter;
private readonly IAccessControlService _accessControlService;
private readonly ISubscriberService _subscriberService;
private readonly ILogger<ProvidersController> _logger;
public ProvidersController(IOrganizationRepository organizationRepository,
IResellerClientOrganizationSignUpCommand resellerClientOrganizationSignUpCommand,
@@ -72,7 +73,8 @@ public class ProvidersController : Controller
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
IAccessControlService accessControlService,
ISubscriberService subscriberService)
ISubscriberService subscriberService,
ILogger<ProvidersController> logger)
{
_organizationRepository = organizationRepository;
_resellerClientOrganizationSignUpCommand = resellerClientOrganizationSignUpCommand;
@@ -92,6 +94,7 @@ public class ProvidersController : Controller
_braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl();
_braintreeMerchantId = globalSettings.Braintree.MerchantId;
_subscriberService = subscriberService;
_logger = logger;
}
[RequirePermission(Permission.Provider_List_View)]
@@ -296,6 +299,9 @@ public class ProvidersController : Controller
var originalProviderStatus = provider.Enabled;
// Capture original billing email before modifications for Stripe sync
var originalBillingEmail = provider.BillingEmail;
model.ToProvider(provider);
// validate the stripe ids to prevent saving a bad one
@@ -321,6 +327,22 @@ public class ProvidersController : Controller
await _providerService.UpdateAsync(provider);
await _applicationCacheService.UpsertProviderAbilityAsync(provider);
// Sync billing email changes to Stripe
if (!string.IsNullOrEmpty(provider.GatewayCustomerId) && originalBillingEmail != provider.BillingEmail)
{
try
{
await _providerBillingService.UpdateProviderNameAndEmail(provider);
}
catch (Exception ex)
{
_logger.LogError(ex,
"Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.",
provider.Id);
TempData["Warning"] = "Provider updated successfully, but Stripe customer email synchronization failed.";
}
}
if (!provider.IsBillable())
{
return RedirectToAction("Edit", new { id });

View File

@@ -1,7 +1,7 @@
###############################################
# Node.js build stage #
###############################################
FROM node:20-alpine3.21 AS node-build
FROM --platform=$BUILDPLATFORM node:20-alpine3.21 AS node-build
WORKDIR /app
COPY src/Admin/package*.json ./

View File

@@ -5,6 +5,7 @@ using Bit.Api.AdminConsole.Models.Request.Providers;
using Bit.Api.AdminConsole.Models.Response.Providers;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Services;
@@ -23,15 +24,20 @@ public class ProvidersController : Controller
private readonly IProviderService _providerService;
private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
private readonly IProviderBillingService _providerBillingService;
private readonly ILogger<ProvidersController> _logger;
public ProvidersController(IUserService userService, IProviderRepository providerRepository,
IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings)
IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings,
IProviderBillingService providerBillingService, ILogger<ProvidersController> logger)
{
_userService = userService;
_providerRepository = providerRepository;
_providerService = providerService;
_currentContext = currentContext;
_globalSettings = globalSettings;
_providerBillingService = providerBillingService;
_logger = logger;
}
[HttpGet("{id:guid}")]
@@ -65,7 +71,27 @@ public class ProvidersController : Controller
throw new NotFoundException();
}
// Capture original values before modifications for Stripe sync
var originalName = provider.Name;
var originalBillingEmail = provider.BillingEmail;
await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings));
// Sync name/email changes to Stripe
if (originalName != provider.Name || originalBillingEmail != provider.BillingEmail)
{
try
{
await _providerBillingService.UpdateProviderNameAndEmail(provider);
}
catch (Exception ex)
{
_logger.LogError(ex,
"Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.",
provider.Id);
}
}
return new ProviderResponseModel(provider);
}

View File

@@ -2,6 +2,7 @@
#nullable disable
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Api.Models.Public.Response;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Models.Data;
@@ -13,6 +14,12 @@ namespace Bit.Api.AdminConsole.Public.Models.Response;
/// </summary>
public class GroupResponseModel : GroupBaseModel, IResponseModel
{
[JsonConstructor]
public GroupResponseModel()
{
}
public GroupResponseModel(Group group, IEnumerable<CollectionAccessSelection> collections)
{
if (group == null)

View File

@@ -2,6 +2,7 @@
#nullable disable
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Api.AdminConsole.Public.Models.Response;
using Bit.Core.Entities;
using Bit.Core.Models.Data;
@@ -13,6 +14,12 @@ namespace Bit.Api.Models.Public.Response;
/// </summary>
public class CollectionResponseModel : CollectionBaseModel, IResponseModel
{
[JsonConstructor]
public CollectionResponseModel()
{
}
public CollectionResponseModel(Collection collection, IEnumerable<CollectionAccessSelection> groups)
{
if (collection == null)

View File

@@ -65,10 +65,11 @@ public class CollectionsController : Controller
[ProducesResponseType(typeof(ListResponseModel<CollectionResponseModel>), (int)HttpStatusCode.OK)]
public async Task<IActionResult> List()
{
var collections = await _collectionRepository.GetManySharedCollectionsByOrganizationIdAsync(
_currentContext.OrganizationId.Value);
// TODO: Get all CollectionGroup associations for the organization and marry them up here for the response.
var collectionResponses = collections.Select(c => new CollectionResponseModel(c, null));
var collections = await _collectionRepository.GetManyByOrganizationIdWithAccessAsync(_currentContext.OrganizationId.Value);
var collectionResponses = collections.Select(c =>
new CollectionResponseModel(c.Item1, c.Item2.Groups));
var response = new ListResponseModel<CollectionResponseModel>(collectionResponses);
return new JsonResult(response);
}

View File

@@ -39,15 +39,11 @@ public class ReconcileAdditionalStorageJob(
logger.LogInformation("Starting ReconcileAdditionalStorageJob (live mode: {LiveMode})", liveMode);
var priceIds = new[] { _storageGbMonthlyPriceId, _storageGbAnnuallyPriceId, _personalStorageGbAnnuallyPriceId };
var stripeStatusesToProcess = new[] { StripeConstants.SubscriptionStatus.Active, StripeConstants.SubscriptionStatus.Trialing, StripeConstants.SubscriptionStatus.PastDue };
foreach (var priceId in priceIds)
{
var options = new SubscriptionListOptions
{
Limit = 100,
Status = StripeConstants.SubscriptionStatus.Active,
Price = priceId
};
var options = new SubscriptionListOptions { Limit = 100, Price = priceId };
await foreach (var subscription in stripeFacade.ListSubscriptionsAutoPagingAsync(options))
{
@@ -64,7 +60,7 @@ public class ReconcileAdditionalStorageJob(
failures.Count > 0
? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}"
: string.Empty
);
);
return;
}
@@ -73,6 +69,12 @@ public class ReconcileAdditionalStorageJob(
continue;
}
if (!stripeStatusesToProcess.Contains(subscription.Status))
{
logger.LogInformation("Skipping subscription with unsupported status: {SubscriptionId} - {Status}", subscription.Id, subscription.Status);
continue;
}
logger.LogInformation("Processing subscription: {SubscriptionId}", subscription.Id);
subscriptionsFound++;
@@ -133,7 +135,7 @@ public class ReconcileAdditionalStorageJob(
failures.Count > 0
? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}"
: string.Empty
);
);
}
private SubscriptionUpdateOptions? BuildSubscriptionUpdateOptions(
@@ -145,15 +147,7 @@ public class ReconcileAdditionalStorageJob(
return null;
}
var updateOptions = new SubscriptionUpdateOptions
{
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
Metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o")
},
Items = []
};
var updateOptions = new SubscriptionUpdateOptions { ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, Metadata = new Dictionary<string, string> { [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") }, Items = [] };
var hasUpdates = false;
@@ -172,11 +166,7 @@ public class ReconcileAdditionalStorageJob(
newQuantity,
item.Price.Id);
updateOptions.Items.Add(new SubscriptionItemOptions
{
Id = item.Id,
Quantity = newQuantity
});
updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Quantity = newQuantity });
}
else
{
@@ -185,11 +175,7 @@ public class ReconcileAdditionalStorageJob(
currentQuantity,
item.Price.Id);
updateOptions.Items.Add(new SubscriptionItemOptions
{
Id = item.Id,
Deleted = true
});
updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Deleted = true });
}
}

View File

@@ -36,7 +36,7 @@ public interface IStripeEventUtilityService
/// <param name="userId"></param>
/// /// <param name="providerId"></param>
/// <returns></returns>
Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId);
Task<Transaction> FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId);
/// <summary>
/// Attempts to pay the specified invoice. If a customer is eligible, the invoice is paid using Braintree or Stripe.

View File

@@ -20,6 +20,12 @@ public interface IStripeFacade
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);
IAsyncEnumerable<CustomerCashBalanceTransaction> GetCustomerCashBalanceTransactions(
string customerId,
CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);
Task<Customer> UpdateCustomer(
string customerId,
CustomerUpdateOptions customerUpdateOptions = null,

View File

@@ -38,7 +38,7 @@ public class ChargeRefundedHandler : IChargeRefundedHandler
{
// Attempt to create a transaction for the charge if it doesn't exist
var (organizationId, userId, providerId) = await _stripeEventUtilityService.GetEntityIdsFromChargeAsync(charge);
var tx = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId);
var tx = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId);
try
{
parentTransaction = await _transactionRepository.CreateAsync(tx);

View File

@@ -46,7 +46,7 @@ public class ChargeSucceededHandler : IChargeSucceededHandler
return;
}
var transaction = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId);
var transaction = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId);
if (!transaction.PaymentMethodType.HasValue)
{
_logger.LogWarning("Charge success from unsupported source/method. {ChargeId}", charge.Id);

View File

@@ -124,7 +124,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService
/// <param name="userId"></param>
/// /// <param name="providerId"></param>
/// <returns></returns>
public Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId)
public async Task<Transaction> FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId)
{
var transaction = new Transaction
{
@@ -209,6 +209,24 @@ public class StripeEventUtilityService : IStripeEventUtilityService
transaction.PaymentMethodType = PaymentMethodType.BankAccount;
transaction.Details = $"ACH => {achCreditTransfer.BankName}, {achCreditTransfer.AccountNumber}";
}
else if (charge.PaymentMethodDetails.CustomerBalance != null)
{
var bankTransferType = await GetFundingBankTransferTypeAsync(charge);
if (!string.IsNullOrEmpty(bankTransferType))
{
transaction.PaymentMethodType = PaymentMethodType.BankAccount;
transaction.Details = bankTransferType switch
{
"eu_bank_transfer" => "EU Bank Transfer",
"gb_bank_transfer" => "GB Bank Transfer",
"jp_bank_transfer" => "JP Bank Transfer",
"mx_bank_transfer" => "MX Bank Transfer",
"us_bank_transfer" => "US Bank Transfer",
_ => "Bank Transfer"
};
}
}
break;
}
@@ -289,20 +307,13 @@ public class StripeEventUtilityService : IStripeEventUtilityService
}
var btInvoiceAmount = Math.Round(invoice.AmountDue / 100M, 2);
var existingTransactions = organizationId.HasValue
? await _transactionRepository.GetManyByOrganizationIdAsync(organizationId.Value)
: userId.HasValue
? await _transactionRepository.GetManyByUserIdAsync(userId.Value)
: await _transactionRepository.GetManyByProviderIdAsync(providerId.Value);
var duplicateTimeSpan = TimeSpan.FromHours(24);
var now = DateTime.UtcNow;
var duplicateTransaction = existingTransactions?
.FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan);
if (duplicateTransaction != null)
// Check if this invoice already has a Braintree transaction ID to prevent duplicate charges
if (invoice.Metadata?.ContainsKey("btTransactionId") ?? false)
{
_logger.LogWarning("There is already a recent PayPal transaction ({0}). " +
"Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId);
_logger.LogWarning("Invoice {InvoiceId} already has a Braintree transaction ({TransactionId}). " +
"Do not charge again to prevent duplicate.",
invoice.Id,
invoice.Metadata["btTransactionId"]);
return false;
}
@@ -413,4 +424,55 @@ public class StripeEventUtilityService : IStripeEventUtilityService
throw;
}
}
/// <summary>
/// Retrieves the bank transfer type that funded a charge paid via customer balance.
/// </summary>
/// <param name="charge">The charge to analyze.</param>
/// <returns>
/// The bank transfer type (e.g., "us_bank_transfer", "eu_bank_transfer") if the charge was funded
/// by a bank transfer via customer balance, otherwise null.
/// </returns>
private async Task<string> GetFundingBankTransferTypeAsync(Charge charge)
{
if (charge is not
{
CustomerId: not null,
PaymentIntentId: not null,
PaymentMethodDetails: { Type: "customer_balance" }
})
{
return null;
}
var cashBalanceTransactions = _stripeFacade.GetCustomerCashBalanceTransactions(charge.CustomerId);
string bankTransferType = null;
var matchingPaymentIntentFound = false;
await foreach (var cashBalanceTransaction in cashBalanceTransactions)
{
switch (cashBalanceTransaction)
{
case { Type: "funded", Funded: not null }:
{
bankTransferType = cashBalanceTransaction.Funded.BankTransfer.Type;
break;
}
case { Type: "applied_to_payment", AppliedToPayment: not null }
when cashBalanceTransaction.AppliedToPayment.PaymentIntentId == charge.PaymentIntentId:
{
matchingPaymentIntentFound = true;
break;
}
}
if (matchingPaymentIntentFound && !string.IsNullOrEmpty(bankTransferType))
{
return bankTransferType;
}
}
return null;
}
}

View File

@@ -11,6 +11,7 @@ public class StripeFacade : IStripeFacade
{
private readonly ChargeService _chargeService = new();
private readonly CustomerService _customerService = new();
private readonly CustomerCashBalanceTransactionService _customerCashBalanceTransactionService = new();
private readonly EventService _eventService = new();
private readonly InvoiceService _invoiceService = new();
private readonly PaymentMethodService _paymentMethodService = new();
@@ -41,6 +42,13 @@ public class StripeFacade : IStripeFacade
CancellationToken cancellationToken = default) =>
await _customerService.GetAsync(customerId, customerGetOptions, requestOptions, cancellationToken);
public IAsyncEnumerable<CustomerCashBalanceTransaction> GetCustomerCashBalanceTransactions(
string customerId,
CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default)
=> _customerCashBalanceTransactionService.ListAutoPagingAsync(customerId, customerCashBalanceTransactionListOptions, requestOptions, cancellationToken);
public async Task<Customer> UpdateCustomer(
string customerId,
CustomerUpdateOptions customerUpdateOptions = null,

View File

@@ -528,17 +528,21 @@ public static class EventIntegrationsServiceCollectionExtensions
/// <returns>True if all required RabbitMQ settings are present; otherwise, false.</returns>
/// <remarks>
/// Requires all the following settings to be configured:
/// - EventLogging.RabbitMq.HostName
/// - EventLogging.RabbitMq.Username
/// - EventLogging.RabbitMq.Password
/// - EventLogging.RabbitMq.EventExchangeName
/// <list type="bullet">
/// <item><description>EventLogging.RabbitMq.HostName</description></item>
/// <item><description>EventLogging.RabbitMq.Username</description></item>
/// <item><description>EventLogging.RabbitMq.Password</description></item>
/// <item><description>EventLogging.RabbitMq.EventExchangeName</description></item>
/// <item><description>EventLogging.RabbitMq.IntegrationExchangeName</description></item>
/// </list>
/// </remarks>
internal static bool IsRabbitMqEnabled(GlobalSettings settings)
{
return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) &&
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) &&
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) &&
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName);
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName) &&
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.IntegrationExchangeName);
}
/// <summary>
@@ -547,13 +551,17 @@ public static class EventIntegrationsServiceCollectionExtensions
/// <param name="settings">The global settings containing Azure Service Bus configuration.</param>
/// <returns>True if all required Azure Service Bus settings are present; otherwise, false.</returns>
/// <remarks>
/// Requires both of the following settings to be configured:
/// - EventLogging.AzureServiceBus.ConnectionString
/// - EventLogging.AzureServiceBus.EventTopicName
/// Requires all of the following settings to be configured:
/// <list type="bullet">
/// <item><description>EventLogging.AzureServiceBus.ConnectionString</description></item>
/// <item><description>EventLogging.AzureServiceBus.EventTopicName</description></item>
/// <item><description>EventLogging.AzureServiceBus.IntegrationTopicName</description></item>
/// </list>
/// </remarks>
internal static bool IsAzureServiceBusEnabled(GlobalSettings settings)
{
return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) &&
CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName);
CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName) &&
CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.IntegrationTopicName);
}
}

View File

@@ -0,0 +1,37 @@
namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations;
/// <summary>
/// Categories of event integration failures used for classification and retry logic.
/// </summary>
public enum IntegrationFailureCategory
{
/// <summary>
/// Service is temporarily unavailable (503, upstream outage, maintenance).
/// </summary>
ServiceUnavailable,
/// <summary>
/// Authentication failed (401, 403, invalid_auth, token issues).
/// </summary>
AuthenticationFailed,
/// <summary>
/// Configuration error (invalid config, channel_not_found, etc.).
/// </summary>
ConfigurationError,
/// <summary>
/// Rate limited (429, rate_limited).
/// </summary>
RateLimited,
/// <summary>
/// Transient error (timeouts, 500, network errors).
/// </summary>
TransientError,
/// <summary>
/// Permanent failure unrelated to authentication/config (e.g., unrecoverable payload/format issue).
/// </summary>
PermanentFailure
}

View File

@@ -1,16 +1,84 @@
namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations;
/// <summary>
/// Represents the result of an integration handler operation, including success status,
/// failure categorization, and retry metadata. Use the <see cref="Succeed"/> factory method
/// for successful operations or <see cref="Fail"/> for failures with automatic retry-ability
/// determination based on the failure category.
/// </summary>
public class IntegrationHandlerResult
{
public IntegrationHandlerResult(bool success, IIntegrationMessage message)
/// <summary>
/// True if the integration send succeeded, false otherwise.
/// </summary>
public bool Success { get; }
/// <summary>
/// The integration message that was processed.
/// </summary>
public IIntegrationMessage Message { get; }
/// <summary>
/// Optional UTC date/time indicating when a failed operation should be retried.
/// Will be used by the retry queue to delay re-sending the message.
/// Usually set based on the Retry-After header from rate-limited responses.
/// </summary>
public DateTime? DelayUntilDate { get; private init; }
/// <summary>
/// Category of the failure. Null for successful results.
/// </summary>
public IntegrationFailureCategory? Category { get; private init; }
/// <summary>
/// Detailed failure reason or error message. Empty for successful results.
/// </summary>
public string? FailureReason { get; private init; }
/// <summary>
/// Indicates whether the operation is retryable.
/// Computed from the failure category.
/// </summary>
public bool Retryable => Category switch
{
IntegrationFailureCategory.RateLimited => true,
IntegrationFailureCategory.TransientError => true,
IntegrationFailureCategory.ServiceUnavailable => true,
IntegrationFailureCategory.AuthenticationFailed => false,
IntegrationFailureCategory.ConfigurationError => false,
IntegrationFailureCategory.PermanentFailure => false,
null => false,
_ => false
};
/// <summary>
/// Creates a successful result.
/// </summary>
public static IntegrationHandlerResult Succeed(IIntegrationMessage message)
{
return new IntegrationHandlerResult(success: true, message: message);
}
/// <summary>
/// Creates a failed result with a failure category and reason.
/// </summary>
public static IntegrationHandlerResult Fail(
IIntegrationMessage message,
IntegrationFailureCategory category,
string failureReason,
DateTime? delayUntil = null)
{
return new IntegrationHandlerResult(success: false, message: message)
{
Category = category,
FailureReason = failureReason,
DelayUntilDate = delayUntil
};
}
private IntegrationHandlerResult(bool success, IIntegrationMessage message)
{
Success = success;
Message = message;
}
public bool Success { get; set; } = false;
public bool Retryable { get; set; } = false;
public IIntegrationMessage Message { get; set; }
public DateTime? DelayUntilDate { get; set; }
public string FailureReason { get; set; } = string.Empty;
}

View File

@@ -20,6 +20,12 @@ public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser, I
public string Email { get; set; }
public string AvatarColor { get; set; }
public string TwoFactorProviders { get; set; }
/// <summary>
/// Indicates whether the user has a personal premium subscription.
/// Does not include premium access from organizations -
/// do not use this to check whether the user can access premium features.
/// Null when the organization user is in Invited status (UserId is null).
/// </summary>
public bool? Premium { get; set; }
public OrganizationUserStatusType Status { get; set; }
public OrganizationUserType Type { get; set; }

View File

@@ -270,7 +270,9 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand
ICollection<OrganizationUser> allOrgUsers, User user)
{
var error = (await _automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync(
new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId, allOrgUsers, user)))
new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId,
allOrgUsers.Append(orgUser),
user)))
.Match(
error => error.Message,
_ => string.Empty

View File

@@ -67,7 +67,7 @@ public class OrganizationUpdateCommand(
var shouldUpdateBilling = originalName != organization.Name ||
originalBillingEmail != organization.BillingEmail;
if (!shouldUpdateBilling || string.IsNullOrWhiteSpace(organization.GatewayCustomerId))
if (!shouldUpdateBilling)
{
return;
}

View File

@@ -19,7 +19,8 @@ public class AutomaticUserConfirmationPolicyEnforcementValidator(
var currentOrganizationUser = request.AllOrganizationUsers
.FirstOrDefault(x => x.OrganizationId == request.OrganizationId
&& x.UserId == request.User.Id);
// invited users do not have a userId but will have email
&& (x.UserId == request.User.Id || x.Email == request.User.Email));
if (currentOrganizationUser is null)
{

View File

@@ -29,46 +29,87 @@ public abstract class IntegrationHandlerBase<T> : IIntegrationHandler<T>
IntegrationMessage<T> message,
TimeProvider timeProvider)
{
var result = new IntegrationHandlerResult(success: response.IsSuccessStatusCode, message);
if (response.IsSuccessStatusCode) return result;
switch (response.StatusCode)
if (response.IsSuccessStatusCode)
{
case HttpStatusCode.TooManyRequests:
case HttpStatusCode.RequestTimeout:
case HttpStatusCode.InternalServerError:
case HttpStatusCode.BadGateway:
case HttpStatusCode.ServiceUnavailable:
case HttpStatusCode.GatewayTimeout:
result.Retryable = true;
result.FailureReason = response.ReasonPhrase ?? $"Failure with status code: {(int)response.StatusCode}";
if (response.Headers.TryGetValues("Retry-After", out var values))
{
var value = values.FirstOrDefault();
if (int.TryParse(value, out var seconds))
{
// Retry-after was specified in seconds. Adjust DelayUntilDate by the requested number of seconds.
result.DelayUntilDate = timeProvider.GetUtcNow().AddSeconds(seconds).UtcDateTime;
}
else if (DateTimeOffset.TryParseExact(value,
"r", // "r" is the round-trip format: RFC1123
CultureInfo.InvariantCulture,
DateTimeStyles.AssumeUniversal | DateTimeStyles.AdjustToUniversal,
out var retryDate))
{
// Retry-after was specified as a date. Adjust DelayUntilDate to the specified date.
result.DelayUntilDate = retryDate.UtcDateTime;
}
}
break;
default:
result.Retryable = false;
result.FailureReason = response.ReasonPhrase ?? $"Failure with status code {(int)response.StatusCode}";
break;
return IntegrationHandlerResult.Succeed(message);
}
return result;
var category = ClassifyHttpStatusCode(response.StatusCode);
var failureReason = response.ReasonPhrase ?? $"Failure with status code {(int)response.StatusCode}";
if (category is not (IntegrationFailureCategory.RateLimited
or IntegrationFailureCategory.TransientError
or IntegrationFailureCategory.ServiceUnavailable) ||
!response.Headers.TryGetValues("Retry-After", out var values)
)
{
return IntegrationHandlerResult.Fail(message: message, category: category, failureReason: failureReason);
}
// Handle Retry-After header for rate-limited and retryable errors
DateTime? delayUntil = null;
var value = values.FirstOrDefault();
if (int.TryParse(value, out var seconds))
{
// Retry-after was specified in seconds
delayUntil = timeProvider.GetUtcNow().AddSeconds(seconds).UtcDateTime;
}
else if (DateTimeOffset.TryParseExact(value,
"r", // "r" is the round-trip format: RFC1123
CultureInfo.InvariantCulture,
DateTimeStyles.AssumeUniversal | DateTimeStyles.AdjustToUniversal,
out var retryDate))
{
// Retry-after was specified as a date
delayUntil = retryDate.UtcDateTime;
}
return IntegrationHandlerResult.Fail(
message,
category,
failureReason,
delayUntil
);
}
/// <summary>
/// Classifies an <see cref="HttpStatusCode"/> as an <see cref="IntegrationFailureCategory"/> to drive
/// retry behavior and operator-facing failure reporting.
/// </summary>
/// <param name="statusCode">The HTTP status code.</param>
/// <returns>The corresponding <see cref="IntegrationFailureCategory"/>.</returns>
protected static IntegrationFailureCategory ClassifyHttpStatusCode(HttpStatusCode statusCode)
{
var explicitCategory = statusCode switch
{
HttpStatusCode.Unauthorized => IntegrationFailureCategory.AuthenticationFailed,
HttpStatusCode.Forbidden => IntegrationFailureCategory.AuthenticationFailed,
HttpStatusCode.NotFound => IntegrationFailureCategory.ConfigurationError,
HttpStatusCode.Gone => IntegrationFailureCategory.ConfigurationError,
HttpStatusCode.MovedPermanently => IntegrationFailureCategory.ConfigurationError,
HttpStatusCode.TemporaryRedirect => IntegrationFailureCategory.ConfigurationError,
HttpStatusCode.PermanentRedirect => IntegrationFailureCategory.ConfigurationError,
HttpStatusCode.TooManyRequests => IntegrationFailureCategory.RateLimited,
HttpStatusCode.RequestTimeout => IntegrationFailureCategory.TransientError,
HttpStatusCode.InternalServerError => IntegrationFailureCategory.TransientError,
HttpStatusCode.BadGateway => IntegrationFailureCategory.TransientError,
HttpStatusCode.GatewayTimeout => IntegrationFailureCategory.TransientError,
HttpStatusCode.ServiceUnavailable => IntegrationFailureCategory.ServiceUnavailable,
HttpStatusCode.NotImplemented => IntegrationFailureCategory.PermanentFailure,
_ => (IntegrationFailureCategory?)null
};
if (explicitCategory is not null)
{
return explicitCategory.Value;
}
return (int)statusCode switch
{
>= 300 and <= 399 => IntegrationFailureCategory.ConfigurationError,
>= 400 and <= 499 => IntegrationFailureCategory.ConfigurationError,
>= 500 and <= 599 => IntegrationFailureCategory.ServiceUnavailable,
_ => IntegrationFailureCategory.ServiceUnavailable
};
}
}

View File

@@ -85,6 +85,17 @@ public class AzureServiceBusIntegrationListenerService<TConfiguration> : Backgro
{
// Non-recoverable failure or exceeded the max number of retries
// Return false to indicate this message should be dead-lettered
_logger.LogWarning(
"Integration failure - non-recoverable error or max retries exceeded. " +
"MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " +
"FailureCategory: {Category}, Reason: {Reason}, RetryCount: {RetryCount}, MaxRetries: {MaxRetries}",
message.MessageId,
message.IntegrationType,
message.OrganizationId,
result.Category,
result.FailureReason,
message.RetryCount,
_maxRetries);
return false;
}
}

View File

@@ -106,14 +106,32 @@ public class RabbitMqIntegrationListenerService<TConfiguration> : BackgroundServ
{
// Exceeded the max number of retries; fail and send to dead letter queue
await _rabbitMqService.PublishToDeadLetterAsync(channel, message, cancellationToken);
_logger.LogWarning("Max retry attempts reached. Sent to DLQ.");
_logger.LogWarning(
"Integration failure - max retries exceeded. " +
"MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " +
"FailureCategory: {Category}, Reason: {Reason}, RetryCount: {RetryCount}, MaxRetries: {MaxRetries}",
message.MessageId,
message.IntegrationType,
message.OrganizationId,
result.Category,
result.FailureReason,
message.RetryCount,
_maxRetries);
}
}
else
{
// Fatal error (i.e. not retryable) occurred. Send message to dead letter queue without any retries
await _rabbitMqService.PublishToDeadLetterAsync(channel, message, cancellationToken);
_logger.LogWarning("Non-retryable failure. Sent to DLQ.");
_logger.LogWarning(
"Integration failure - non-retryable. " +
"MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " +
"FailureCategory: {Category}, Reason: {Reason}",
message.MessageId,
message.IntegrationType,
message.OrganizationId,
result.Category,
result.FailureReason);
}
// Message has been sent to retry or dead letter queues.

View File

@@ -6,15 +6,6 @@ public class SlackIntegrationHandler(
ISlackService slackService)
: IntegrationHandlerBase<SlackIntegrationConfigurationDetails>
{
private static readonly HashSet<string> _retryableErrors = new(StringComparer.Ordinal)
{
"internal_error",
"message_limit_exceeded",
"rate_limited",
"ratelimited",
"service_unavailable"
};
public override async Task<IntegrationHandlerResult> HandleAsync(IntegrationMessage<SlackIntegrationConfigurationDetails> message)
{
var slackResponse = await slackService.SendSlackMessageByChannelIdAsync(
@@ -25,24 +16,61 @@ public class SlackIntegrationHandler(
if (slackResponse is null)
{
return new IntegrationHandlerResult(success: false, message: message)
{
FailureReason = "Slack response was null"
};
return IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.TransientError,
"Slack response was null"
);
}
if (slackResponse.Ok)
{
return new IntegrationHandlerResult(success: true, message: message);
return IntegrationHandlerResult.Succeed(message);
}
var result = new IntegrationHandlerResult(success: false, message: message) { FailureReason = slackResponse.Error };
var category = ClassifySlackError(slackResponse.Error);
return IntegrationHandlerResult.Fail(
message,
category,
slackResponse.Error
);
}
if (_retryableErrors.Contains(slackResponse.Error))
/// <summary>
/// Classifies a Slack API error code string as an <see cref="IntegrationFailureCategory"/> to drive
/// retry behavior and operator-facing failure reporting.
/// </summary>
/// <remarks>
/// <para>
/// Slack responses commonly return an <c>error</c> string when <c>ok</c> is false. This method maps
/// known Slack error codes to failure categories.
/// </para>
/// <para>
/// Any unrecognized error codes default to <see cref="IntegrationFailureCategory.TransientError"/> to avoid
/// incorrectly marking new/unknown Slack failures as non-retryable.
/// </para>
/// </remarks>
/// <param name="error">The Slack error code string (e.g. <c>invalid_auth</c>, <c>rate_limited</c>).</param>
/// <returns>The corresponding <see cref="IntegrationFailureCategory"/>.</returns>
private static IntegrationFailureCategory ClassifySlackError(string error)
{
return error switch
{
result.Retryable = true;
}
return result;
"invalid_auth" => IntegrationFailureCategory.AuthenticationFailed,
"access_denied" => IntegrationFailureCategory.AuthenticationFailed,
"token_expired" => IntegrationFailureCategory.AuthenticationFailed,
"token_revoked" => IntegrationFailureCategory.AuthenticationFailed,
"account_inactive" => IntegrationFailureCategory.AuthenticationFailed,
"not_authed" => IntegrationFailureCategory.AuthenticationFailed,
"channel_not_found" => IntegrationFailureCategory.ConfigurationError,
"is_archived" => IntegrationFailureCategory.ConfigurationError,
"rate_limited" => IntegrationFailureCategory.RateLimited,
"ratelimited" => IntegrationFailureCategory.RateLimited,
"message_limit_exceeded" => IntegrationFailureCategory.RateLimited,
"internal_error" => IntegrationFailureCategory.TransientError,
"service_unavailable" => IntegrationFailureCategory.ServiceUnavailable,
"fatal_error" => IntegrationFailureCategory.ServiceUnavailable,
_ => IntegrationFailureCategory.TransientError
};
}
}

View File

@@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using System.Text.Json;
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Microsoft.Rest;
namespace Bit.Core.Services;
@@ -18,24 +19,48 @@ public class TeamsIntegrationHandler(
channelId: message.Configuration.ChannelId
);
return new IntegrationHandlerResult(success: true, message: message);
return IntegrationHandlerResult.Succeed(message);
}
catch (HttpOperationException ex)
{
var result = new IntegrationHandlerResult(success: false, message: message);
var statusCode = (int)ex.Response.StatusCode;
result.Retryable = statusCode is 429 or >= 500 and < 600;
result.FailureReason = ex.Message;
return result;
var category = ClassifyHttpStatusCode(ex.Response.StatusCode);
return IntegrationHandlerResult.Fail(
message,
category,
ex.Message
);
}
catch (ArgumentException ex)
{
return IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.ConfigurationError,
ex.Message
);
}
catch (UriFormatException ex)
{
return IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.ConfigurationError,
ex.Message
);
}
catch (JsonException ex)
{
return IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.PermanentFailure,
ex.Message
);
}
catch (Exception ex)
{
var result = new IntegrationHandlerResult(success: false, message: message);
result.Retryable = false;
result.FailureReason = ex.Message;
return result;
return IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.TransientError,
ex.Message
);
}
}
}

View File

@@ -4,16 +4,37 @@
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
namespace Bit.Core.Auth.UserFeatures.TwoFactorAuth;
public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFactorIsEnabledQuery
public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
{
private readonly IUserRepository _userRepository = userRepository;
private readonly IUserRepository _userRepository;
private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery;
private readonly IFeatureService _featureService;
public TwoFactorIsEnabledQuery(
IUserRepository userRepository,
IHasPremiumAccessQuery hasPremiumAccessQuery,
IFeatureService featureService)
{
_userRepository = userRepository;
_hasPremiumAccessQuery = hasPremiumAccessQuery;
_featureService = featureService;
}
public async Task<IEnumerable<(Guid userId, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync(IEnumerable<Guid> userIds)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return await TwoFactorIsEnabledVNextAsync(userIds);
}
var result = new List<(Guid userId, bool hasTwoFactor)>();
if (userIds == null || !userIds.Any())
{
@@ -36,6 +57,11 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto
public async Task<IEnumerable<(T user, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync<T>(IEnumerable<T> users) where T : ITwoFactorProvidersUser
{
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return await TwoFactorIsEnabledVNextAsync(users);
}
var userIds = users
.Select(u => u.GetUserId())
.Where(u => u.HasValue)
@@ -71,13 +97,134 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto
return false;
}
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
var userEntity = user as User ?? await _userRepository.GetByIdAsync(userId.Value);
if (userEntity == null)
{
throw new NotFoundException();
}
return await TwoFactorIsEnabledVNextAsync(userEntity);
}
return await TwoFactorEnabledAsync(
user.GetTwoFactorProviders(),
async () =>
{
var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value);
return calcUser?.HasPremiumAccess ?? false;
});
user.GetTwoFactorProviders(),
async () =>
{
var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value);
return calcUser?.HasPremiumAccess ?? false;
});
}
private async Task<IEnumerable<(Guid userId, bool twoFactorIsEnabled)>> TwoFactorIsEnabledVNextAsync(IEnumerable<Guid> userIds)
{
var result = new List<(Guid userId, bool hasTwoFactor)>();
if (userIds == null || !userIds.Any())
{
return result;
}
var users = await _userRepository.GetManyAsync([.. userIds]);
// Get enabled providers for each user
var usersTwoFactorProvidersMap = users.ToDictionary(u => u.Id, GetEnabledTwoFactorProviders);
// Bulk fetch premium status only for users who need it (those with only premium providers)
var userIdsNeedingPremium = usersTwoFactorProvidersMap
.Where(kvp => kvp.Value.Any() && kvp.Value.All(TwoFactorProvider.RequiresPremium))
.Select(kvp => kvp.Key)
.ToList();
var premiumStatusMap = userIdsNeedingPremium.Count > 0
? await _hasPremiumAccessQuery.HasPremiumAccessAsync(userIdsNeedingPremium)
: new Dictionary<Guid, bool>();
foreach (var user in users)
{
var userTwoFactorProviders = usersTwoFactorProvidersMap[user.Id];
if (!userTwoFactorProviders.Any())
{
result.Add((user.Id, false));
continue;
}
// User has providers. If they're in the premium check map, verify premium status
var twoFactorIsEnabled = !premiumStatusMap.TryGetValue(user.Id, out var hasPremium) || hasPremium;
result.Add((user.Id, twoFactorIsEnabled));
}
return result;
}
private async Task<IEnumerable<(T user, bool twoFactorIsEnabled)>> TwoFactorIsEnabledVNextAsync<T>(IEnumerable<T> users)
where T : ITwoFactorProvidersUser
{
var userIds = users
.Select(u => u.GetUserId())
.Where(u => u.HasValue)
.Select(u => u.Value)
.ToList();
var twoFactorResults = await TwoFactorIsEnabledVNextAsync(userIds);
var result = new List<(T user, bool twoFactorIsEnabled)>();
foreach (var user in users)
{
var userId = user.GetUserId();
if (userId.HasValue)
{
var hasTwoFactor = twoFactorResults.FirstOrDefault(res => res.userId == userId.Value).twoFactorIsEnabled;
result.Add((user, hasTwoFactor));
}
else
{
result.Add((user, false));
}
}
return result;
}
private async Task<bool> TwoFactorIsEnabledVNextAsync(User user)
{
var enabledProviders = GetEnabledTwoFactorProviders(user);
if (!enabledProviders.Any())
{
return false;
}
// If all providers require premium, check if user has premium access
if (enabledProviders.All(TwoFactorProvider.RequiresPremium))
{
return await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id);
}
// User has at least one non-premium provider
return true;
}
/// <summary>
/// Gets all enabled two-factor provider types for a user.
/// </summary>
/// <param name="user">user with two factor providers</param>
/// <returns>list of enabled provider types</returns>
private static IList<TwoFactorProviderType> GetEnabledTwoFactorProviders(User user)
{
var providers = user.GetTwoFactorProviders();
if (providers == null || providers.Count == 0)
{
return Array.Empty<TwoFactorProviderType>();
}
// TODO: PM-21210: In practice we don't save disabled providers to the database, worth looking into.
return (from provider in providers
where provider.Value?.Enabled ?? false
select provider.Key).ToList();
}
/// <summary>

View File

@@ -6,6 +6,7 @@ using Bit.Core.Billing.Organizations.Queries;
using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Payment;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
@@ -31,6 +32,7 @@ public static class ServiceCollectionExtensions
services.AddPaymentOperations();
services.AddOrganizationLicenseCommandsQueries();
services.AddPremiumCommands();
services.AddPremiumQueries();
services.AddTransient<IGetOrganizationMetadataQuery, GetOrganizationMetadataQuery>();
services.AddTransient<IGetOrganizationWarningsQuery, GetOrganizationWarningsQuery>();
services.AddTransient<IRestartSubscriptionCommand, RestartSubscriptionCommand>();
@@ -50,4 +52,9 @@ public static class ServiceCollectionExtensions
services.AddScoped<ICreatePremiumSelfHostedSubscriptionCommand, CreatePremiumSelfHostedSubscriptionCommand>();
services.AddTransient<IPreviewPremiumTaxCommand, PreviewPremiumTaxCommand>();
}
private static void AddPremiumQueries(this IServiceCollection services)
{
services.AddScoped<IHasPremiumAccessQuery, HasPremiumAccessQuery>();
}
}

View File

@@ -61,10 +61,6 @@ public interface IOrganizationBillingService
/// Updates the organization name and email on the Stripe customer entry.
/// This only updates Stripe, not the Bitwarden database.
/// </summary>
/// <remarks>
/// The caller should ensure that the organization has a GatewayCustomerId before calling this method.
/// </remarks>
/// <param name="organization">The organization to update in Stripe.</param>
/// <exception cref="BillingException">Thrown when the organization does not have a GatewayCustomerId.</exception>
Task UpdateOrganizationNameAndEmail(Organization organization);
}

View File

@@ -177,13 +177,25 @@ public class OrganizationBillingService(
public async Task UpdateOrganizationNameAndEmail(Organization organization)
{
if (organization.GatewayCustomerId is null)
if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId))
{
throw new BillingException("Cannot update an organization in Stripe without a GatewayCustomerId.");
logger.LogWarning(
"Organization ({OrganizationId}) has no Stripe customer to update",
organization.Id);
return;
}
var newDisplayName = organization.DisplayName();
// Organization.DisplayName() can return null - handle gracefully
if (string.IsNullOrWhiteSpace(newDisplayName))
{
logger.LogWarning(
"Organization ({OrganizationId}) has no name to update in Stripe",
organization.Id);
return;
}
await stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId,
new CustomerUpdateOptions
{
@@ -196,9 +208,7 @@ public class OrganizationBillingService(
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = organization.SubscriberType(),
Value = newDisplayName.Length <= 30
? newDisplayName
: newDisplayName[..30]
Value = newDisplayName
}]
},
});

View File

@@ -0,0 +1,29 @@
namespace Bit.Core.Billing.Premium.Models;
/// <summary>
/// Represents user premium access status from personal subscriptions and organization memberships.
/// </summary>
public class UserPremiumAccess
{
/// <summary>
/// The unique identifier for the user.
/// </summary>
public Guid Id { get; set; }
/// <summary>
/// Indicates whether the user has a personal premium subscription.
/// This does NOT include premium access from organizations.
/// </summary>
public bool PersonalPremium { get; set; }
/// <summary>
/// Indicates whether the user has premium access through any organization membership.
/// This is true if the user is a member of at least one enabled organization that grants premium access to users.
/// </summary>
public bool OrganizationPremium { get; set; }
/// <summary>
/// Indicates whether the user has premium access from any source (personal subscription or organization).
/// </summary>
public bool HasPremiumAccess => PersonalPremium || OrganizationPremium;
}

View File

@@ -0,0 +1,49 @@
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
namespace Bit.Core.Billing.Premium.Queries;
public class HasPremiumAccessQuery : IHasPremiumAccessQuery
{
private readonly IUserRepository _userRepository;
public HasPremiumAccessQuery(IUserRepository userRepository)
{
_userRepository = userRepository;
}
public async Task<bool> HasPremiumAccessAsync(Guid userId)
{
var user = await _userRepository.GetPremiumAccessAsync(userId);
if (user == null)
{
throw new NotFoundException();
}
return user.HasPremiumAccess;
}
public async Task<Dictionary<Guid, bool>> HasPremiumAccessAsync(IEnumerable<Guid> userIds)
{
var distinctUserIds = userIds.Distinct().ToList();
var usersWithPremium = await _userRepository.GetPremiumAccessByIdsAsync(distinctUserIds);
if (usersWithPremium.Count() != distinctUserIds.Count)
{
throw new NotFoundException();
}
return usersWithPremium.ToDictionary(u => u.Id, u => u.HasPremiumAccess);
}
public async Task<bool> HasPremiumFromOrganizationAsync(Guid userId)
{
var user = await _userRepository.GetPremiumAccessAsync(userId);
if (user == null)
{
throw new NotFoundException();
}
return user.OrganizationPremium;
}
}

View File

@@ -0,0 +1,30 @@
namespace Bit.Core.Billing.Premium.Queries;
/// <summary>
/// Centralized query for checking if users have premium access through personal subscriptions or organizations.
/// Note: Different from User.Premium which only checks personal subscriptions.
/// </summary>
public interface IHasPremiumAccessQuery
{
/// <summary>
/// Checks if a user has premium access (personal or organization).
/// </summary>
/// <param name="userId">The user ID to check</param>
/// <returns>True if user can access premium features</returns>
Task<bool> HasPremiumAccessAsync(Guid userId);
/// <summary>
/// Checks premium access for multiple users.
/// </summary>
/// <param name="userIds">The user IDs to check</param>
/// <returns>Dictionary mapping user IDs to their premium access status</returns>
Task<Dictionary<Guid, bool>> HasPremiumAccessAsync(IEnumerable<Guid> userIds);
/// <summary>
/// Checks if a user belongs to any organization that grants premium (enabled org with UsersGetPremium).
/// Returns true regardless of personal subscription. Useful for UI decisions like showing subscription options.
/// </summary>
/// <param name="userId">The user ID to check</param>
/// <returns>True if user is in any organization that grants premium</returns>
Task<bool> HasPremiumFromOrganizationAsync(Guid userId);
}

View File

@@ -113,4 +113,11 @@ public interface IProviderBillingService
TaxInformation taxInformation);
Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command);
/// <summary>
/// Updates the provider name and email on the Stripe customer entry.
/// This only updates Stripe, not the Bitwarden database.
/// </summary>
/// <param name="provider">The provider to update in Stripe.</param>
Task UpdateProviderNameAndEmail(Provider provider);
}

View File

@@ -143,6 +143,7 @@ public static class FeatureFlagKeys
public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration";
public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud";
public const string BulkRevokeUsersV2 = "pm-28456-bulk-revoke-users-v2";
public const string PremiumAccessQuery = "pm-21411-premium-access-query";
/* Architecture */
public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1";

View File

@@ -70,6 +70,11 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
/// The security state is a signed object attesting to the version of the user's account.
/// </summary>
public string? SecurityState { get; set; }
/// <summary>
/// Indicates whether the user has a personal premium subscription.
/// Does not include premium access from organizations -
/// do not use this to check whether the user can access premium features.
/// </summary>
public bool Premium { get; set; }
public DateTime? PremiumExpirationDate { get; set; }
public DateTime? RenewalReminderDate { get; set; }

View File

@@ -1,4 +1,5 @@
using Bit.Core.Entities;
using Bit.Core.Billing.Premium.Models;
using Bit.Core.Entities;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Models.Data;
@@ -24,6 +25,7 @@ public interface IUserRepository : IRepository<User, Guid>
/// Retrieves the data for the requested user IDs and includes an additional property indicating
/// whether the user has premium access directly or through an organization.
/// </summary>
[Obsolete("Use GetPremiumAccessByIdsAsync instead. This method will be removed in a future version.")]
Task<IEnumerable<UserWithCalculatedPremium>> GetManyWithCalculatedPremiumAsync(IEnumerable<Guid> ids);
/// <summary>
/// Retrieves the data for the requested user ID and includes additional property indicating
@@ -34,8 +36,23 @@ public interface IUserRepository : IRepository<User, Guid>
/// </summary>
/// <param name="userId">The user ID to retrieve data for.</param>
/// <returns>User data with calculated premium access; null if nothing is found</returns>
[Obsolete("Use GetPremiumAccessAsync instead. This method will be removed in a future version.")]
Task<UserWithCalculatedPremium?> GetCalculatedPremiumAsync(Guid userId);
/// <summary>
/// Retrieves premium access status for multiple users.
/// For internal use - consumers should use IHasPremiumAccessQuery instead.
/// </summary>
/// <param name="ids">The user IDs to check</param>
/// <returns>Collection of UserPremiumAccess objects containing premium status information</returns>
Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids);
/// <summary>
/// Retrieves premium access status for a single user.
/// For internal use - consumers should use IHasPremiumAccessQuery instead.
/// </summary>
/// <param name="userId">The user ID to check</param>
/// <returns>UserPremiumAccess object containing premium status information, or null if user not found</returns>
Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId);
/// <summary>
/// Sets a new user key and updates all encrypted data.
/// <para>Warning: Any user key encrypted data not included will be lost.</para>
/// </summary>

View File

@@ -60,7 +60,7 @@ public interface IUserService
/// <summary>
/// Checks if the user has access to premium features, either through a personal subscription or through an organization.
///
/// This is the preferred way to definitively know if a user has access to premium features.
/// This is the preferred way to definitively know if a user has access to premium features when you already have the User object.
/// </summary>
/// <param name="user">user being acted on</param>
/// <returns>true if they can access premium; false otherwise.</returns>
@@ -74,6 +74,7 @@ public interface IUserService
/// </summary>
/// <param name="user">user being acted on</param>
/// <returns>true if they can access premium because of organization membership; false otherwise.</returns>
[Obsolete("Use IHasPremiumAccessQuery.HasPremiumFromOrganizationAsync instead. This method will be removed in a future version.")]
Task<bool> HasPremiumFromOrganization(User user);
Task<string> GenerateSignInTokenAsync(User user, string purpose);

View File

@@ -17,6 +17,7 @@ using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
@@ -73,6 +74,7 @@ public class UserService : UserManager<User>, IUserService
private readonly IDistributedCache _distributedCache;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IPricingClient _pricingClient;
private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery;
public UserService(
IUserRepository userRepository,
@@ -108,7 +110,8 @@ public class UserService : UserManager<User>, IUserService
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
IDistributedCache distributedCache,
IPolicyRequirementQuery policyRequirementQuery,
IPricingClient pricingClient)
IPricingClient pricingClient,
IHasPremiumAccessQuery hasPremiumAccessQuery)
: base(
store,
optionsAccessor,
@@ -149,6 +152,7 @@ public class UserService : UserManager<User>, IUserService
_distributedCache = distributedCache;
_policyRequirementQuery = policyRequirementQuery;
_pricingClient = pricingClient;
_hasPremiumAccessQuery = hasPremiumAccessQuery;
}
public Guid? GetProperUserId(ClaimsPrincipal principal)
@@ -1112,6 +1116,11 @@ public class UserService : UserManager<User>, IUserService
return false;
}
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return user.Premium || await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value);
}
return user.Premium || await HasPremiumFromOrganization(user);
}
@@ -1123,6 +1132,11 @@ public class UserService : UserManager<User>, IUserService
return false;
}
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value);
}
// orgUsers in the Invited status are not associated with a userId yet, so this will get
// orgUsers in Accepted and Confirmed states only
var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value);

View File

@@ -703,7 +703,6 @@ public abstract class BaseRequestValidator<T> where T : class
customResponse.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user));
customResponse.Add("ForcePasswordReset", user.ForcePasswordReset);
customResponse.Add("ResetMasterPassword", string.IsNullOrWhiteSpace(user.MasterPassword));
customResponse.Add("Kdf", (byte)user.Kdf);
customResponse.Add("KdfIterations", user.KdfIterations);
customResponse.Add("KdfMemory", user.KdfMemory);

View File

@@ -4,7 +4,6 @@ using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Auth.IdentityServer;
using Bit.Core.Auth.Models.Api.Response;
using Bit.Core.Auth.Repositories;
using Bit.Core.Context;
using Bit.Core.Entities;
@@ -154,23 +153,7 @@ public class CustomTokenRequestValidator : BaseRequestValidator<CustomTokenReque
{
// KeyConnectorUrl is configured in the CLI client, we just need to tell the client to use it
context.Result.CustomResponse["ApiUseKeyConnector"] = true;
context.Result.CustomResponse["ResetMasterPassword"] = false;
}
return Task.CompletedTask;
}
// Key connector data should have already been set in the decryption options
// for backwards compatibility we set them this way too. We can eventually get rid of this once we clean up
// ResetMasterPassword
if (!context.Result.CustomResponse.TryGetValue("UserDecryptionOptions", out var userDecryptionOptionsObj) ||
userDecryptionOptionsObj is not UserDecryptionOptions userDecryptionOptions)
{
return Task.CompletedTask;
}
if (userDecryptionOptions is { KeyConnectorOption: { } })
{
context.Result.CustomResponse["ResetMasterPassword"] = false;
}
return Task.CompletedTask;

View File

@@ -226,7 +226,6 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
{
obj.SetNewId();
var objWithGroupsAndUsers = JsonSerializer.Deserialize<CollectionWithGroupsAndUsers>(JsonSerializer.Serialize(obj))!;
objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
@@ -243,18 +242,52 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
public async Task ReplaceAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
{
var objWithGroupsAndUsers = JsonSerializer.Deserialize<CollectionWithGroupsAndUsers>(JsonSerializer.Serialize(obj))!;
objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
objWithGroupsAndUsers.Users = users != null ? users.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
using (var connection = new SqlConnection(ConnectionString))
await using var connection = new SqlConnection(ConnectionString);
await connection.OpenAsync();
await using var transaction = await connection.BeginTransactionAsync();
try
{
var results = await connection.ExecuteAsync(
$"[{Schema}].[Collection_UpdateWithGroupsAndUsers]",
objWithGroupsAndUsers,
commandType: CommandType.StoredProcedure);
if (groups == null && users == null)
{
await connection.ExecuteAsync(
$"[{Schema}].[Collection_Update]",
obj,
commandType: CommandType.StoredProcedure,
transaction: transaction);
}
else if (groups != null && users == null)
{
await connection.ExecuteAsync(
$"[{Schema}].[Collection_UpdateWithGroups]",
new CollectionWithGroups(obj, groups),
commandType: CommandType.StoredProcedure,
transaction: transaction);
}
else if (groups == null && users != null)
{
await connection.ExecuteAsync(
$"[{Schema}].[Collection_UpdateWithUsers]",
new CollectionWithUsers(obj, users),
commandType: CommandType.StoredProcedure,
transaction: transaction);
}
else if (groups != null && users != null)
{
await connection.ExecuteAsync(
$"[{Schema}].[Collection_UpdateWithGroupsAndUsers]",
new CollectionWithGroupsAndUsers(obj, groups, users),
commandType: CommandType.StoredProcedure,
transaction: transaction);
}
await transaction.CommitAsync();
}
catch
{
await transaction.RollbackAsync();
throw;
}
}
public async Task DeleteManyAsync(IEnumerable<Guid> collectionIds)
@@ -424,9 +457,70 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
public class CollectionWithGroupsAndUsers : Collection
{
public CollectionWithGroupsAndUsers() { }
public CollectionWithGroupsAndUsers(Collection collection,
IEnumerable<CollectionAccessSelection> groups,
IEnumerable<CollectionAccessSelection> users)
{
Id = collection.Id;
Name = collection.Name;
OrganizationId = collection.OrganizationId;
CreationDate = collection.CreationDate;
RevisionDate = collection.RevisionDate;
Type = collection.Type;
ExternalId = collection.ExternalId;
DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail;
Groups = groups.ToArrayTVP();
Users = users.ToArrayTVP();
}
[DisallowNull]
public DataTable? Groups { get; set; }
[DisallowNull]
public DataTable? Users { get; set; }
}
public class CollectionWithGroups : Collection
{
public CollectionWithGroups() { }
public CollectionWithGroups(Collection collection, IEnumerable<CollectionAccessSelection> groups)
{
Id = collection.Id;
Name = collection.Name;
OrganizationId = collection.OrganizationId;
CreationDate = collection.CreationDate;
RevisionDate = collection.RevisionDate;
Type = collection.Type;
ExternalId = collection.ExternalId;
DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail;
Groups = groups.ToArrayTVP();
}
[DisallowNull]
public DataTable? Groups { get; set; }
}
public class CollectionWithUsers : Collection
{
public CollectionWithUsers() { }
public CollectionWithUsers(Collection collection, IEnumerable<CollectionAccessSelection> users)
{
Id = collection.Id;
Name = collection.Name;
OrganizationId = collection.OrganizationId;
CreationDate = collection.CreationDate;
RevisionDate = collection.RevisionDate;
Type = collection.Type;
ExternalId = collection.ExternalId;
DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail;
Users = users.ToArrayTVP();
}
[DisallowNull]
public DataTable? Users { get; set; }
}
}

View File

@@ -1,6 +1,7 @@
using System.Data;
using System.Text.Json;
using Bit.Core;
using Bit.Core.Billing.Premium.Models;
using Bit.Core.Entities;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
@@ -381,6 +382,25 @@ public class UserRepository : Repository<User, Guid>, IUserRepository
return result.SingleOrDefault();
}
public async Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids)
{
using (var connection = new SqlConnection(ReadOnlyConnectionString))
{
var results = await connection.QueryAsync<UserPremiumAccess>(
$"[{Schema}].[{Table}_ReadPremiumAccessByIds]",
new { Ids = ids.ToGuidIdArrayTVP() },
commandType: CommandType.StoredProcedure);
return results.ToList();
}
}
public async Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId)
{
var result = await GetPremiumAccessByIdsAsync([userId]);
return result.SingleOrDefault();
}
private async Task ProtectDataAndSaveAsync(User user, Func<Task> saveTask)
{
if (user == null)

View File

@@ -18,7 +18,7 @@ public class OrganizationEntityTypeConfiguration : IEntityTypeConfiguration<Orga
NpgsqlIndexBuilderExtensions.IncludeProperties(
builder.HasIndex(o => new { o.Id, o.Enabled }),
o => o.UseTotp);
o => new { o.UseTotp, o.UsersGetPremium });
builder.ToTable(nameof(Organization));
}

View File

@@ -1,4 +1,5 @@
using AutoMapper;
using Bit.Core.Billing.Premium.Models;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Models.Data;
@@ -350,6 +351,36 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
return result.FirstOrDefault();
}
public async Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = await dbContext.Users
.Where(x => ids.Contains(x.Id))
.Include(u => u.OrganizationUsers)
.ThenInclude(ou => ou.Organization)
.ToListAsync();
return users.Select(user => new UserPremiumAccess
{
Id = user.Id,
PersonalPremium = user.Premium,
OrganizationPremium = user.OrganizationUsers
.Any(ou => ou.Organization != null &&
ou.Organization.Enabled == true &&
ou.Organization.UsersGetPremium == true)
}).ToList();
}
}
public async Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId)
{
var result = await GetPremiumAccessByIdsAsync([userId]);
return result.FirstOrDefault();
}
public override async Task DeleteAsync(Core.Entities.User user)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@@ -0,0 +1,74 @@
CREATE PROCEDURE [dbo].[Collection_UpdateWithGroups]
@Id UNIQUEIDENTIFIER,
@OrganizationId UNIQUEIDENTIFIER,
@Name VARCHAR(MAX),
@ExternalId NVARCHAR(300),
@CreationDate DATETIME2(7),
@RevisionDate DATETIME2(7),
@Groups AS [dbo].[CollectionAccessSelectionType] READONLY,
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
@Type TINYINT = 0
AS
BEGIN
SET NOCOUNT ON
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
-- Groups
-- Delete groups that are no longer in source
DELETE
cg
FROM
[dbo].[CollectionGroup] cg
LEFT JOIN
@Groups g ON cg.GroupId = g.Id
WHERE
cg.CollectionId = @Id
AND g.Id IS NULL;
-- Update existing groups
UPDATE
cg
SET
cg.ReadOnly = g.ReadOnly,
cg.HidePasswords = g.HidePasswords,
cg.Manage = g.Manage
FROM
[dbo].[CollectionGroup] cg
INNER JOIN
@Groups g ON cg.GroupId = g.Id
WHERE
cg.CollectionId = @Id
AND (
cg.ReadOnly != g.ReadOnly
OR cg.HidePasswords != g.HidePasswords
OR cg.Manage != g.Manage
);
-- Insert new groups
INSERT INTO [dbo].[CollectionGroup]
(
[CollectionId],
[GroupId],
[ReadOnly],
[HidePasswords],
[Manage]
)
SELECT
@Id,
g.Id,
g.ReadOnly,
g.HidePasswords,
g.Manage
FROM
@Groups g
INNER JOIN
[dbo].[Group] grp ON grp.Id = g.Id
LEFT JOIN
[dbo].[CollectionGroup] cg ON cg.CollectionId = @Id AND cg.GroupId = g.Id
WHERE
grp.OrganizationId = @OrganizationId
AND cg.CollectionId IS NULL;
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
END

View File

@@ -0,0 +1,74 @@
CREATE PROCEDURE [dbo].[Collection_UpdateWithUsers]
@Id UNIQUEIDENTIFIER,
@OrganizationId UNIQUEIDENTIFIER,
@Name VARCHAR(MAX),
@ExternalId NVARCHAR(300),
@CreationDate DATETIME2(7),
@RevisionDate DATETIME2(7),
@Users AS [dbo].[CollectionAccessSelectionType] READONLY,
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
@Type TINYINT = 0
AS
BEGIN
SET NOCOUNT ON
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
-- Users
-- Delete users that are no longer in source
DELETE
cu
FROM
[dbo].[CollectionUser] cu
LEFT JOIN
@Users u ON cu.OrganizationUserId = u.Id
WHERE
cu.CollectionId = @Id
AND u.Id IS NULL;
-- Update existing users
UPDATE
cu
SET
cu.ReadOnly = u.ReadOnly,
cu.HidePasswords = u.HidePasswords,
cu.Manage = u.Manage
FROM
[dbo].[CollectionUser] cu
INNER JOIN
@Users u ON cu.OrganizationUserId = u.Id
WHERE
cu.CollectionId = @Id
AND (
cu.ReadOnly != u.ReadOnly
OR cu.HidePasswords != u.HidePasswords
OR cu.Manage != u.Manage
);
-- Insert new users
INSERT INTO [dbo].[CollectionUser]
(
[CollectionId],
[OrganizationUserId],
[ReadOnly],
[HidePasswords],
[Manage]
)
SELECT
@Id,
u.Id,
u.ReadOnly,
u.HidePasswords,
u.Manage
FROM
@Users u
INNER JOIN
[dbo].[OrganizationUser] ou ON ou.Id = u.Id
LEFT JOIN
[dbo].[CollectionUser] cu ON cu.CollectionId = @Id AND cu.OrganizationUserId = u.Id
WHERE
ou.OrganizationId = @OrganizationId
AND cu.CollectionId IS NULL;
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
END

View File

@@ -0,0 +1,15 @@
CREATE PROCEDURE [dbo].[User_ReadPremiumAccessByIds]
@Ids [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
SELECT
UPA.[Id],
UPA.[PersonalPremium],
UPA.[OrganizationPremium]
FROM
[dbo].[UserPremiumAccessView] UPA
WHERE
UPA.[Id] IN (SELECT [Id] FROM @Ids)
END

View File

@@ -69,7 +69,7 @@ CREATE TABLE [dbo].[Organization] (
GO
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
INCLUDE ([UseTotp]);
INCLUDE ([UseTotp], [UsersGetPremium]);
GO
CREATE UNIQUE NONCLUSTERED INDEX [IX_Organization_Identifier]

View File

@@ -0,0 +1,21 @@
CREATE VIEW [dbo].[UserPremiumAccessView]
AS
SELECT
U.[Id],
U.[Premium] AS [PersonalPremium],
CAST(
MAX(CASE
WHEN O.[Id] IS NOT NULL THEN 1
ELSE 0
END) AS BIT
) AS [OrganizationPremium]
FROM
[dbo].[User] U
LEFT JOIN
[dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id]
LEFT JOIN
[dbo].[Organization] O ON O.[Id] = OU.[OrganizationId]
AND O.[UsersGetPremium] = 1
AND O.[Enabled] = 1
GROUP BY
U.[Id], U.[Premium];

View File

@@ -0,0 +1,117 @@
using Bit.Api.AdminConsole.Public.Models.Request;
using Bit.Api.IntegrationTest.Factories;
using Bit.Api.IntegrationTest.Helpers;
using Bit.Api.Models.Public.Request;
using Bit.Api.Models.Public.Response;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Xunit;
namespace Bit.Api.IntegrationTest.Controllers.Public;
public class CollectionsControllerTests : IClassFixture<ApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly ApiApplicationFactory _factory;
private readonly LoginHelper _loginHelper;
private string _ownerEmail = null!;
private Organization _organization = null!;
public CollectionsControllerTests(ApiApplicationFactory factory)
{
_factory = factory;
_factory.SubstituteService<IPushNotificationService>(_ => { });
_factory.SubstituteService<IFeatureService>(_ => { });
_client = factory.CreateClient();
_loginHelper = new LoginHelper(_factory, _client);
}
public async Task InitializeAsync()
{
_ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(_ownerEmail);
(_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory,
plan: PlanType.EnterpriseAnnually,
ownerEmail: _ownerEmail,
passwordManagerSeats: 10,
paymentMethod: PaymentMethodType.Card);
await _loginHelper.LoginWithOrganizationApiKeyAsync(_organization.Id);
}
public Task DisposeAsync()
{
_client.Dispose();
return Task.CompletedTask;
}
[Fact]
public async Task CreateCollectionWithMultipleUsersAndVariedPermissions_Success()
{
// Arrange
_organization.AllowAdminAccessToAllCollectionItems = true;
await _factory.GetService<IOrganizationRepository>().UpsertAsync(_organization);
var groupRepository = _factory.GetService<IGroupRepository>();
var group = await groupRepository.CreateAsync(new Group
{
OrganizationId = _organization.Id,
Name = "CollectionControllerTests.CreateCollectionWithMultipleUsersAndVariedPermissions_Success",
ExternalId = $"CollectionControllerTests.CreateCollectionWithMultipleUsersAndVariedPermissions_Success{Guid.NewGuid()}",
});
var (_, user) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
_factory,
_organization.Id,
OrganizationUserType.User);
var collection = await OrganizationTestHelpers.CreateCollectionAsync(
_factory,
_organization.Id,
"Shared Collection with a group",
externalId: "shared-collection-with-group",
groups:
[
new CollectionAccessSelection { Id = group.Id, ReadOnly = false, HidePasswords = false, Manage = true }
],
users:
[
new CollectionAccessSelection { Id = user.Id, ReadOnly = false, HidePasswords = false, Manage = true }
]);
var getCollectionsResponse = await _client.GetFromJsonAsync<ListResponseModel<CollectionResponseModel>>("public/collections");
var getCollectionResponse = await _client.GetFromJsonAsync<CollectionResponseModel>($"public/collections/{collection.Id}");
var firstCollection = getCollectionsResponse.Data.First(x => x.ExternalId == "shared-collection-with-group");
var update = new CollectionUpdateRequestModel
{
ExternalId = firstCollection.ExternalId,
Groups = firstCollection.Groups?.Select(x => new AssociationWithPermissionsRequestModel
{
Id = x.Id,
ReadOnly = x.ReadOnly,
HidePasswords = x.HidePasswords,
Manage = x.Manage
}),
};
await _client.PutAsJsonAsync($"public/collections/{firstCollection.Id}", update);
var result = await _factory.GetService<ICollectionRepository>()
.GetByIdWithAccessAsync(firstCollection.Id);
Assert.NotNull(result);
Assert.NotEmpty(result.Item2.Groups);
Assert.NotEmpty(result.Item2.Users);
}
}

View File

@@ -159,14 +159,16 @@ public static class OrganizationTestHelpers
Guid organizationId,
string name,
IEnumerable<CollectionAccessSelection>? users = null,
IEnumerable<CollectionAccessSelection>? groups = null)
IEnumerable<CollectionAccessSelection>? groups = null,
string? externalId = null)
{
var collectionRepository = factory.GetService<ICollectionRepository>();
var collection = new Collection
{
OrganizationId = organizationId,
Name = name,
Type = CollectionType.SharedCollection
Type = CollectionType.SharedCollection,
ExternalId = externalId
};
await collectionRepository.CreateAsync(collection, groups, users);

View File

@@ -62,7 +62,7 @@ public class ReconcileAdditionalStorageJobTests
// Assert
_stripeFacade.Received(3).ListSubscriptionsAutoPagingAsync(
Arg.Is<SubscriptionListOptions>(o => o.Status == "active"));
Arg.Is<SubscriptionListOptions>(o => o.Limit == 100));
}
#endregion
@@ -553,6 +553,152 @@ public class ReconcileAdditionalStorageJobTests
#endregion
#region Subscription Status Filtering Tests
[Fact]
public async Task Execute_ActiveStatusSubscription_ProcessesSubscription()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active);
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(subscription);
// Act
await _sut.Execute(context);
// Assert
await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any<SubscriptionUpdateOptions>());
}
[Fact]
public async Task Execute_TrialingStatusSubscription_ProcessesSubscription()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Trialing);
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(subscription);
// Act
await _sut.Execute(context);
// Assert
await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any<SubscriptionUpdateOptions>());
}
[Fact]
public async Task Execute_PastDueStatusSubscription_ProcessesSubscription()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.PastDue);
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(subscription);
// Act
await _sut.Execute(context);
// Assert
await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any<SubscriptionUpdateOptions>());
}
[Fact]
public async Task Execute_CanceledStatusSubscription_SkipsSubscription()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Canceled);
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
// Act
await _sut.Execute(context);
// Assert
await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!);
}
[Fact]
public async Task Execute_IncompleteStatusSubscription_SkipsSubscription()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Incomplete);
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
// Act
await _sut.Execute(context);
// Assert
await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!);
}
[Fact]
public async Task Execute_MixedSubscriptionStatuses_OnlyProcessesValidStatuses()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var activeSubscription = CreateSubscription("sub_active", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active);
var trialingSubscription = CreateSubscription("sub_trialing", "storage-gb-monthly", quantity: 8, status: StripeConstants.SubscriptionStatus.Trialing);
var pastDueSubscription = CreateSubscription("sub_pastdue", "storage-gb-monthly", quantity: 6, status: StripeConstants.SubscriptionStatus.PastDue);
var canceledSubscription = CreateSubscription("sub_canceled", "storage-gb-monthly", quantity: 5, status: StripeConstants.SubscriptionStatus.Canceled);
var incompleteSubscription = CreateSubscription("sub_incomplete", "storage-gb-monthly", quantity: 4, status: StripeConstants.SubscriptionStatus.Incomplete);
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(activeSubscription, trialingSubscription, pastDueSubscription, canceledSubscription, incompleteSubscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(callInfo => callInfo.Arg<string>() switch
{
"sub_active" => activeSubscription,
"sub_trialing" => trialingSubscription,
"sub_pastdue" => pastDueSubscription,
_ => null
});
// Act
await _sut.Execute(context);
// Assert
await _stripeFacade.Received(1).UpdateSubscription("sub_active", Arg.Any<SubscriptionUpdateOptions>());
await _stripeFacade.Received(1).UpdateSubscription("sub_trialing", Arg.Any<SubscriptionUpdateOptions>());
await _stripeFacade.Received(1).UpdateSubscription("sub_pastdue", Arg.Any<SubscriptionUpdateOptions>());
await _stripeFacade.DidNotReceive().UpdateSubscription("sub_canceled", Arg.Any<SubscriptionUpdateOptions>());
await _stripeFacade.DidNotReceive().UpdateSubscription("sub_incomplete", Arg.Any<SubscriptionUpdateOptions>());
}
#endregion
#region Cancellation Tests
[Fact]
@@ -598,7 +744,8 @@ public class ReconcileAdditionalStorageJobTests
string id,
string priceId,
long? quantity = null,
Dictionary<string, string>? metadata = null)
Dictionary<string, string>? metadata = null,
string status = StripeConstants.SubscriptionStatus.Active)
{
var price = new Price { Id = priceId };
var item = new SubscriptionItem
@@ -611,6 +758,7 @@ public class ReconcileAdditionalStorageJobTests
return new Subscription
{
Id = id,
Status = status,
Metadata = metadata,
Items = new StripeList<SubscriptionItem>
{

View File

@@ -200,7 +200,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
Assert.True(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
@@ -214,7 +215,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:HostName"] = null,
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
@@ -228,7 +230,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = null,
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
@@ -242,21 +245,38 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = null,
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
}
[Fact]
public void IsRabbitMqEnabled_MissingExchangeName_ReturnsFalse()
public void IsRabbitMqEnabled_MissingEventExchangeName_ReturnsFalse()
{
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = null
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = null,
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
}
[Fact]
public void IsRabbitMqEnabled_MissingIntegrationExchangeName_ReturnsFalse()
{
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = null
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
@@ -268,7 +288,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
});
Assert.True(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
@@ -280,19 +301,34 @@ public class EventIntegrationServiceCollectionExtensionsTests
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = null,
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
}
[Fact]
public void IsAzureServiceBusEnabled_MissingTopicName_ReturnsFalse()
public void IsAzureServiceBusEnabled_MissingEventTopicName_ReturnsFalse()
{
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = null
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = null,
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
}
[Fact]
public void IsAzureServiceBusEnabled_MissingIntegrationTopicName_ReturnsFalse()
{
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = null
});
Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
@@ -601,7 +637,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
// Add prerequisites
@@ -624,7 +661,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
});
// Add prerequisites
@@ -650,8 +688,10 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration",
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
});
// Add prerequisites
@@ -694,7 +734,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
});
services.AddEventWriteServices(globalSettings);
@@ -712,7 +753,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
services.AddEventWriteServices(globalSettings);
@@ -769,10 +811,12 @@ public class EventIntegrationServiceCollectionExtensionsTests
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration",
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
services.AddEventWriteServices(globalSettings);
@@ -789,7 +833,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
{
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
});
// Add prerequisites
@@ -826,7 +871,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
});
// Add prerequisites

View File

@@ -0,0 +1,128 @@
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Core.Test.AdminConsole.Models.Data.EventIntegrations;
public class IntegrationHandlerResultTests
{
[Theory, BitAutoData]
public void Succeed_SetsSuccessTrue_CategoryNull(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Succeed(message);
Assert.True(result.Success);
Assert.Null(result.Category);
Assert.Equal(message, result.Message);
Assert.Null(result.FailureReason);
}
[Theory, BitAutoData]
public void Fail_WithCategory_SetsSuccessFalse_CategorySet(IntegrationMessage message)
{
var category = IntegrationFailureCategory.AuthenticationFailed;
var failureReason = "Invalid credentials";
var result = IntegrationHandlerResult.Fail(message, category, failureReason);
Assert.False(result.Success);
Assert.Equal(category, result.Category);
Assert.Equal(failureReason, result.FailureReason);
Assert.Equal(message, result.Message);
}
[Theory, BitAutoData]
public void Fail_WithDelayUntil_SetsDelayUntilDate(IntegrationMessage message)
{
var delayUntil = DateTime.UtcNow.AddMinutes(5);
var result = IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.RateLimited,
"Rate limited",
delayUntil
);
Assert.Equal(delayUntil, result.DelayUntilDate);
}
[Theory, BitAutoData]
public void Retryable_RateLimited_ReturnsTrue(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.RateLimited,
"Rate limited"
);
Assert.True(result.Retryable);
}
[Theory, BitAutoData]
public void Retryable_TransientError_ReturnsTrue(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.TransientError,
"Temporary network issue"
);
Assert.True(result.Retryable);
}
[Theory, BitAutoData]
public void Retryable_AuthenticationFailed_ReturnsFalse(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.AuthenticationFailed,
"Invalid token"
);
Assert.False(result.Retryable);
}
[Theory, BitAutoData]
public void Retryable_ConfigurationError_ReturnsFalse(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.ConfigurationError,
"Channel not found"
);
Assert.False(result.Retryable);
}
[Theory, BitAutoData]
public void Retryable_ServiceUnavailable_ReturnsTrue(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.ServiceUnavailable,
"Service is down"
);
Assert.True(result.Retryable);
}
[Theory, BitAutoData]
public void Retryable_PermanentFailure_ReturnsFalse(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Fail(
message,
IntegrationFailureCategory.PermanentFailure,
"Permanent failure"
);
Assert.False(result.Retryable);
}
[Theory, BitAutoData]
public void Retryable_SuccessCase_ReturnsFalse(IntegrationMessage message)
{
var result = IntegrationHandlerResult.Succeed(message);
Assert.False(result.Retryable);
}
}

View File

@@ -30,7 +30,7 @@ public class OrganizationUpdateCommandTests
var organizationBillingService = sutProvider.GetDependency<IOrganizationBillingService>();
organization.Id = organizationId;
organization.GatewayCustomerId = null; // No Stripe customer, so no billing update
organization.GatewayCustomerId = null; // No Stripe customer, but billing update is still called
organizationRepository
.GetByIdAsync(organizationId)
@@ -61,8 +61,8 @@ public class OrganizationUpdateCommandTests
result,
EventType.Organization_Updated);
await organizationBillingService
.DidNotReceiveWithAnyArgs()
.UpdateOrganizationNameAndEmail(Arg.Any<Organization>());
.Received(1)
.UpdateOrganizationNameAndEmail(result);
}
[Theory, BitAutoData]
@@ -93,7 +93,7 @@ public class OrganizationUpdateCommandTests
[Theory]
[BitAutoData("")]
[BitAutoData((string)null)]
public async Task UpdateAsync_WhenGatewayCustomerIdIsNullOrEmpty_SkipsBillingUpdate(
public async Task UpdateAsync_WhenGatewayCustomerIdIsNullOrEmpty_CallsBillingUpdateButHandledGracefully(
string gatewayCustomerId,
Guid organizationId,
Organization organization,
@@ -133,8 +133,8 @@ public class OrganizationUpdateCommandTests
result,
EventType.Organization_Updated);
await organizationBillingService
.DidNotReceiveWithAnyArgs()
.UpdateOrganizationNameAndEmail(Arg.Any<Organization>());
.Received(1)
.UpdateOrganizationNameAndEmail(result);
}
[Theory, BitAutoData]

View File

@@ -78,8 +78,10 @@ public class AzureServiceBusIntegrationListenerServiceTests
var sutProvider = GetSutProvider();
message.RetryCount = 0;
var result = new IntegrationHandlerResult(false, message);
result.Retryable = false;
var result = IntegrationHandlerResult.Fail(
message: message,
category: IntegrationFailureCategory.AuthenticationFailed, // NOT retryable
failureReason: "403");
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
@@ -89,6 +91,12 @@ public class AzureServiceBusIntegrationListenerServiceTests
await _handler.Received(1).HandleAsync(Arg.Is(expected.ToJson()));
await _serviceBusService.DidNotReceiveWithAnyArgs().PublishToRetryAsync(Arg.Any<IIntegrationMessage>());
_logger.Received().Log(
LogLevel.Warning,
Arg.Any<EventId>(),
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - non-recoverable error or max retries exceeded.")),
Arg.Any<Exception?>(),
Arg.Any<Func<object, Exception?, string>>());
}
[Theory, BitAutoData]
@@ -96,9 +104,10 @@ public class AzureServiceBusIntegrationListenerServiceTests
{
var sutProvider = GetSutProvider();
message.RetryCount = _config.MaxRetries;
var result = new IntegrationHandlerResult(false, message);
result.Retryable = true;
var result = IntegrationHandlerResult.Fail(
message: message,
category: IntegrationFailureCategory.TransientError, // Retryable
failureReason: "403");
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
@@ -108,6 +117,12 @@ public class AzureServiceBusIntegrationListenerServiceTests
await _handler.Received(1).HandleAsync(Arg.Is(expected.ToJson()));
await _serviceBusService.DidNotReceiveWithAnyArgs().PublishToRetryAsync(Arg.Any<IIntegrationMessage>());
_logger.Received().Log(
LogLevel.Warning,
Arg.Any<EventId>(),
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - non-recoverable error or max retries exceeded.")),
Arg.Any<Exception?>(),
Arg.Any<Func<object, Exception?, string>>());
}
[Theory, BitAutoData]
@@ -116,8 +131,10 @@ public class AzureServiceBusIntegrationListenerServiceTests
var sutProvider = GetSutProvider();
message.RetryCount = 0;
var result = new IntegrationHandlerResult(false, message);
result.Retryable = true;
var result = IntegrationHandlerResult.Fail(
message: message,
category: IntegrationFailureCategory.TransientError, // Retryable
failureReason: "403");
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
@@ -133,7 +150,7 @@ public class AzureServiceBusIntegrationListenerServiceTests
public async Task HandleMessageAsync_SuccessfulResult_Succeeds(IntegrationMessage<WebhookIntegrationConfiguration> message)
{
var sutProvider = GetSutProvider();
var result = new IntegrationHandlerResult(true, message);
var result = IntegrationHandlerResult.Succeed(message);
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
@@ -156,7 +173,7 @@ public class AzureServiceBusIntegrationListenerServiceTests
_logger.Received(1).Log(
LogLevel.Error,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Unhandled error processing ASB message")),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception?, string>>());

View File

@@ -51,7 +51,7 @@ public class DatadogIntegrationHandlerTests
Assert.True(result.Success);
Assert.Equal(result.Message, message);
Assert.Empty(result.FailureReason);
Assert.Null(result.FailureReason);
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
Arg.Is(AssertHelper.AssertPropertyEqual(DatadogIntegrationHandler.HttpClientName))

View File

@@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using System.Net;
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Bit.Core.Enums;
using Bit.Core.Services;
using Xunit;
@@ -7,7 +8,6 @@ namespace Bit.Core.Test.Services;
public class IntegrationHandlerTests
{
[Fact]
public async Task HandleAsync_ConvertsJsonToTypedIntegrationMessage()
{
@@ -33,13 +33,113 @@ public class IntegrationHandlerTests
Assert.Equal(expected.IntegrationType, typedResult.IntegrationType);
}
[Theory]
[InlineData(HttpStatusCode.Unauthorized)]
[InlineData(HttpStatusCode.Forbidden)]
public void ClassifyHttpStatusCode_AuthenticationFailed(HttpStatusCode code)
{
Assert.Equal(
IntegrationFailureCategory.AuthenticationFailed,
TestIntegrationHandler.Classify(code));
}
[Theory]
[InlineData(HttpStatusCode.NotFound)]
[InlineData(HttpStatusCode.Gone)]
[InlineData(HttpStatusCode.MovedPermanently)]
[InlineData(HttpStatusCode.TemporaryRedirect)]
[InlineData(HttpStatusCode.PermanentRedirect)]
public void ClassifyHttpStatusCode_ConfigurationError(HttpStatusCode code)
{
Assert.Equal(
IntegrationFailureCategory.ConfigurationError,
TestIntegrationHandler.Classify(code));
}
[Fact]
public void ClassifyHttpStatusCode_TooManyRequests_IsRateLimited()
{
Assert.Equal(
IntegrationFailureCategory.RateLimited,
TestIntegrationHandler.Classify(HttpStatusCode.TooManyRequests));
}
[Fact]
public void ClassifyHttpStatusCode_RequestTimeout_IsTransient()
{
Assert.Equal(
IntegrationFailureCategory.TransientError,
TestIntegrationHandler.Classify(HttpStatusCode.RequestTimeout));
}
[Theory]
[InlineData(HttpStatusCode.InternalServerError)]
[InlineData(HttpStatusCode.BadGateway)]
[InlineData(HttpStatusCode.GatewayTimeout)]
public void ClassifyHttpStatusCode_Common5xx_AreTransient(HttpStatusCode code)
{
Assert.Equal(
IntegrationFailureCategory.TransientError,
TestIntegrationHandler.Classify(code));
}
[Fact]
public void ClassifyHttpStatusCode_ServiceUnavailable_IsServiceUnavailable()
{
Assert.Equal(
IntegrationFailureCategory.ServiceUnavailable,
TestIntegrationHandler.Classify(HttpStatusCode.ServiceUnavailable));
}
[Fact]
public void ClassifyHttpStatusCode_NotImplemented_IsPermanentFailure()
{
Assert.Equal(
IntegrationFailureCategory.PermanentFailure,
TestIntegrationHandler.Classify(HttpStatusCode.NotImplemented));
}
[Fact]
public void FClassifyHttpStatusCode_Unhandled3xx_IsConfigurationError()
{
Assert.Equal(
IntegrationFailureCategory.ConfigurationError,
TestIntegrationHandler.Classify(HttpStatusCode.Found));
}
[Fact]
public void ClassifyHttpStatusCode_Unhandled4xx_IsConfigurationError()
{
Assert.Equal(
IntegrationFailureCategory.ConfigurationError,
TestIntegrationHandler.Classify(HttpStatusCode.BadRequest));
}
[Fact]
public void ClassifyHttpStatusCode_Unhandled5xx_IsServiceUnavailable()
{
Assert.Equal(
IntegrationFailureCategory.ServiceUnavailable,
TestIntegrationHandler.Classify(HttpStatusCode.HttpVersionNotSupported));
}
[Fact]
public void ClassifyHttpStatusCode_UnknownCode_DefaultsToServiceUnavailable()
{
// cast an out-of-range value to ensure default path is stable
Assert.Equal(
IntegrationFailureCategory.ServiceUnavailable,
TestIntegrationHandler.Classify((HttpStatusCode)799));
}
private class TestIntegrationHandler : IntegrationHandlerBase<WebhookIntegrationConfigurationDetails>
{
public override Task<IntegrationHandlerResult> HandleAsync(
IntegrationMessage<WebhookIntegrationConfigurationDetails> message)
{
var result = new IntegrationHandlerResult(success: true, message: message);
return Task.FromResult(result);
return Task.FromResult(IntegrationHandlerResult.Succeed(message: message));
}
public static IntegrationFailureCategory Classify(HttpStatusCode code) => ClassifyHttpStatusCode(code);
}
}

View File

@@ -86,8 +86,10 @@ public class RabbitMqIntegrationListenerServiceTests
new BasicProperties(),
body: Encoding.UTF8.GetBytes(message.ToJson())
);
var result = new IntegrationHandlerResult(false, message);
result.Retryable = false;
var result = IntegrationHandlerResult.Fail(
message: message,
category: IntegrationFailureCategory.AuthenticationFailed, // NOT retryable
failureReason: "403");
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
@@ -105,7 +107,7 @@ public class RabbitMqIntegrationListenerServiceTests
_logger.Received().Log(
LogLevel.Warning,
Arg.Any<EventId>(),
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Non-retryable failure")),
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - non-retryable.")),
Arg.Any<Exception?>(),
Arg.Any<Func<object, Exception?, string>>());
@@ -133,8 +135,10 @@ public class RabbitMqIntegrationListenerServiceTests
new BasicProperties(),
body: Encoding.UTF8.GetBytes(message.ToJson())
);
var result = new IntegrationHandlerResult(false, message);
result.Retryable = true;
var result = IntegrationHandlerResult.Fail(
message: message,
category: IntegrationFailureCategory.TransientError, // Retryable
failureReason: "403");
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
@@ -151,7 +155,7 @@ public class RabbitMqIntegrationListenerServiceTests
_logger.Received().Log(
LogLevel.Warning,
Arg.Any<EventId>(),
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Max retry attempts reached")),
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - max retries exceeded.")),
Arg.Any<Exception?>(),
Arg.Any<Func<object, Exception?, string>>());
@@ -179,9 +183,10 @@ public class RabbitMqIntegrationListenerServiceTests
new BasicProperties(),
body: Encoding.UTF8.GetBytes(message.ToJson())
);
var result = new IntegrationHandlerResult(false, message);
result.Retryable = true;
result.DelayUntilDate = _now.AddMinutes(1);
var result = IntegrationHandlerResult.Fail(
message: message,
category: IntegrationFailureCategory.TransientError, // Retryable
failureReason: "403");
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
@@ -220,7 +225,7 @@ public class RabbitMqIntegrationListenerServiceTests
new BasicProperties(),
body: Encoding.UTF8.GetBytes(message.ToJson())
);
var result = new IntegrationHandlerResult(true, message);
var result = IntegrationHandlerResult.Succeed(message);
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
await sutProvider.Sut.ProcessReceivedMessageAsync(eventArgs, cancellationToken);

View File

@@ -110,7 +110,7 @@ public class SlackIntegrationHandlerTests
}
[Fact]
public async Task HandleAsync_NullResponse_ReturnsNonRetryableFailure()
public async Task HandleAsync_NullResponse_ReturnsRetryableFailure()
{
var sutProvider = GetSutProvider();
var message = new IntegrationMessage<SlackIntegrationConfigurationDetails>()
@@ -126,7 +126,7 @@ public class SlackIntegrationHandlerTests
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.False(result.Retryable);
Assert.True(result.Retryable); // Null response is classified as TransientError (retryable)
Assert.Equal("Slack response was null", result.FailureReason);
Assert.Equal(result.Message, message);

View File

@@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using System.Text.Json;
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -42,9 +43,77 @@ public class TeamsIntegrationHandlerTests
);
}
[Theory, BitAutoData]
public async Task HandleAsync_ArgumentException_ReturnsConfigurationError(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
{
var sutProvider = GetSutProvider();
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
sutProvider.GetDependency<ITeamsService>()
.SendMessageToChannelAsync(Arg.Any<Uri>(), Arg.Any<string>(), Arg.Any<string>())
.ThrowsAsync(new ArgumentException("argument error"));
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.Equal(IntegrationFailureCategory.ConfigurationError, result.Category);
Assert.False(result.Retryable);
Assert.Equal(result.Message, message);
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(
Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)),
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)),
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate))
);
}
[Theory, BitAutoData]
public async Task HandleAsync_HttpExceptionNonRetryable_ReturnsFalseAndNotRetryable(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
public async Task HandleAsync_JsonException_ReturnsPermanentFailure(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
{
var sutProvider = GetSutProvider();
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
sutProvider.GetDependency<ITeamsService>()
.SendMessageToChannelAsync(Arg.Any<Uri>(), Arg.Any<string>(), Arg.Any<string>())
.ThrowsAsync(new JsonException("JSON error"));
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.Equal(IntegrationFailureCategory.PermanentFailure, result.Category);
Assert.False(result.Retryable);
Assert.Equal(result.Message, message);
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(
Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)),
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)),
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate))
);
}
[Theory, BitAutoData]
public async Task HandleAsync_UriFormatException_ReturnsConfigurationError(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
{
var sutProvider = GetSutProvider();
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
sutProvider.GetDependency<ITeamsService>()
.SendMessageToChannelAsync(Arg.Any<Uri>(), Arg.Any<string>(), Arg.Any<string>())
.ThrowsAsync(new UriFormatException("Bad URI"));
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.Equal(IntegrationFailureCategory.ConfigurationError, result.Category);
Assert.False(result.Retryable);
Assert.Equal(result.Message, message);
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(
Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)),
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)),
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate))
);
}
[Theory, BitAutoData]
public async Task HandleAsync_HttpExceptionForbidden_ReturnsAuthenticationFailed(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
{
var sutProvider = GetSutProvider();
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
@@ -62,6 +131,7 @@ public class TeamsIntegrationHandlerTests
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.Equal(IntegrationFailureCategory.AuthenticationFailed, result.Category);
Assert.False(result.Retryable);
Assert.Equal(result.Message, message);
@@ -73,7 +143,7 @@ public class TeamsIntegrationHandlerTests
}
[Theory, BitAutoData]
public async Task HandleAsync_HttpExceptionRetryable_ReturnsFalseAndRetryable(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
public async Task HandleAsync_HttpExceptionTooManyRequests_ReturnsRateLimited(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
{
var sutProvider = GetSutProvider();
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
@@ -92,6 +162,7 @@ public class TeamsIntegrationHandlerTests
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.Equal(IntegrationFailureCategory.RateLimited, result.Category);
Assert.True(result.Retryable);
Assert.Equal(result.Message, message);
@@ -103,7 +174,7 @@ public class TeamsIntegrationHandlerTests
}
[Theory, BitAutoData]
public async Task HandleAsync_UnknownException_ReturnsFalseAndNotRetryable(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
public async Task HandleAsync_UnknownException_ReturnsTransientError(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
{
var sutProvider = GetSutProvider();
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
@@ -114,7 +185,8 @@ public class TeamsIntegrationHandlerTests
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.False(result.Retryable);
Assert.Equal(IntegrationFailureCategory.TransientError, result.Category);
Assert.True(result.Retryable);
Assert.Equal(result.Message, message);
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(

View File

@@ -51,7 +51,7 @@ public class WebhookIntegrationHandlerTests
Assert.True(result.Success);
Assert.Equal(result.Message, message);
Assert.Empty(result.FailureReason);
Assert.Null(result.FailureReason);
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
Arg.Is(AssertHelper.AssertPropertyEqual(WebhookIntegrationHandler.HttpClientName))
@@ -79,7 +79,7 @@ public class WebhookIntegrationHandlerTests
Assert.True(result.Success);
Assert.Equal(result.Message, message);
Assert.Empty(result.FailureReason);
Assert.Null(result.FailureReason);
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
Arg.Is(AssertHelper.AssertPropertyEqual(WebhookIntegrationHandler.HttpClientName))

View File

@@ -1,10 +1,13 @@
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -404,6 +407,277 @@ public class TwoFactorIsEnabledQueryTests
.GetCalculatedPremiumAsync(default);
}
[Theory]
[BitAutoData((IEnumerable<Guid>)null)]
[BitAutoData([])]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsEmpty(
IEnumerable<Guid> userIds,
SutProvider<TwoFactorIsEnabledQuery> sutProvider)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
// Assert
Assert.Empty(result);
}
[Theory]
[BitAutoData(TwoFactorProviderType.Duo)]
[BitAutoData(TwoFactorProviderType.YubiKey)]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithMixedScenarios_ReturnsCorrectResults(
TwoFactorProviderType premiumProviderType,
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user1,
User user2,
User user3)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
var users = new List<User> { user1, user2, user3 };
var userIds = users.Select(u => u.Id).ToList();
// User 1: Non-premium provider → 2FA enabled
user1.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } }
});
// User 2: Premium provider + has premium → 2FA enabled
user2.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
// User 3: Premium provider + no premium → 2FA disabled
user3.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
var premiumStatus = new Dictionary<Guid, bool>
{
{ user2.Id, true },
{ user3.Id, false }
};
sutProvider.GetDependency<IUserRepository>()
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
.Returns(users);
sutProvider.GetDependency<IHasPremiumAccessQuery>()
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)))
.Returns(premiumStatus);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
// Assert
Assert.Contains(result, res => res.userId == user1.Id && res.twoFactorIsEnabled == true); // Non-premium provider
Assert.Contains(result, res => res.userId == user2.Id && res.twoFactorIsEnabled == true); // Premium + has premium
Assert.Contains(result, res => res.userId == user3.Id && res.twoFactorIsEnabled == false); // Premium + no premium
}
[Theory]
[BitAutoData(TwoFactorProviderType.Duo)]
[BitAutoData(TwoFactorProviderType.YubiKey)]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_OnlyChecksPremiumAccessForUsersWhoNeedIt(
TwoFactorProviderType premiumProviderType,
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user1,
User user2,
User user3)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
var users = new List<User> { user1, user2, user3 };
var userIds = users.Select(u => u.Id).ToList();
// User 1: Has non-premium provider - should NOT trigger premium check
user1.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } }
});
// User 2 & 3: Have only premium providers - SHOULD trigger premium check
user2.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
user3.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
var premiumStatus = new Dictionary<Guid, bool>
{
{ user2.Id, true },
{ user3.Id, false }
};
sutProvider.GetDependency<IUserRepository>()
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
.Returns(users);
sutProvider.GetDependency<IHasPremiumAccessQuery>()
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)))
.Returns(premiumStatus);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
// Assert - Verify optimization: premium checked ONLY for users 2 and 3 (not user 1)
await sutProvider.GetDependency<IHasPremiumAccessQuery>()
.Received(1)
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)));
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsAllTwoFactorDisabled(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
List<OrganizationUserUserDetails> users)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
foreach (var user in users)
{
user.UserId = null;
}
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(users);
// Assert
foreach (var user in users)
{
Assert.Contains(result, res => res.user.Equals(user) && res.twoFactorIsEnabled == false);
}
// No UserIds were supplied so no calls to the UserRepository should have been made
await sutProvider.GetDependency<IUserRepository>()
.DidNotReceiveWithAnyArgs()
.GetManyAsync(default);
}
[Theory]
[BitAutoData(TwoFactorProviderType.Authenticator, true)] // Non-premium provider
[BitAutoData(TwoFactorProviderType.Duo, true)] // Premium provider with premium access
[BitAutoData(TwoFactorProviderType.YubiKey, false)] // Premium provider without premium access
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_SingleUser_VariousScenarios(
TwoFactorProviderType providerType,
bool hasPremiumAccess,
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
user.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ providerType, new TwoFactorProvider { Enabled = true } }
});
sutProvider.GetDependency<IHasPremiumAccessQuery>()
.HasPremiumAccessAsync(user.Id)
.Returns(hasPremiumAccess);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
// Assert
var requiresPremium = TwoFactorProvider.RequiresPremium(providerType);
var expectedResult = !requiresPremium || hasPremiumAccess;
Assert.Equal(expectedResult, result);
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoEnabledProviders_ReturnsFalse(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
user.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ TwoFactorProviderType.Email, new TwoFactorProvider { Enabled = false } }
});
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
// Assert
Assert.False(result);
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNullProviders_ReturnsFalse(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
user.TwoFactorProviders = null;
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
// Assert
Assert.False(result);
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_UserNotFound_ThrowsNotFoundException(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
Guid userId)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
var testUser = new TestTwoFactorProviderUser
{
Id = userId,
TwoFactorProviders = null
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(userId)
.Returns((User)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(
async () => await sutProvider.Sut.TwoFactorIsEnabledAsync(testUser));
}
private class TestTwoFactorProviderUser : ITwoFactorProvidersUser
{
public Guid? Id { get; set; }
@@ -418,10 +692,5 @@ public class TwoFactorIsEnabledQueryTests
{
return Id;
}
public bool GetPremium()
{
return Premium;
}
}
}

View File

@@ -0,0 +1,234 @@
using Bit.Core.Billing.Premium.Models;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Billing.Premium.Queries;
[SutProviderCustomize]
public class HasPremiumAccessQueryTests
{
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserHasPersonalPremium_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = true;
user.OrganizationPremium = false;
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
// Assert
Assert.True(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumButHasOrgPremium_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false;
user.OrganizationPremium = true; // Has org premium
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
// Assert
Assert.True(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumAndNoOrgPremium_ReturnsFalse(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false;
user.OrganizationPremium = false;
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
// Assert
Assert.False(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserNotFound_ThrowsNotFoundException(
Guid userId,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(userId)
.Returns((UserPremiumAccess?)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(
() => sutProvider.Sut.HasPremiumAccessAsync(userId));
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasNoOrganizations_ReturnsFalse(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false;
user.OrganizationPremium = false; // No premium from anywhere
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.False(result);
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasPremiumFromOrg_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false; // No personal premium
user.OrganizationPremium = true; // But has premium from org
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.True(result);
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasOnlyPersonalPremium_ReturnsFalse(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = true; // Has personal premium
user.OrganizationPremium = false; // Not in any org that grants premium
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.False(result); // Should return false because user is not in an org that grants premium
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasBothPersonalAndOrgPremium_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = true; // Has personal premium
user.OrganizationPremium = true; // Also in an org that grants premium
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.True(result); // Should return true because user IS in an org that grants premium (regardless of personal premium)
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserNotFound_ThrowsNotFoundException(
Guid userId,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(userId)
.Returns((UserPremiumAccess?)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(
() => sutProvider.Sut.HasPremiumFromOrganizationAsync(userId));
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_Bulk_WhenEmptyList_ReturnsEmptyDictionary(
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
var userIds = new List<Guid>();
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessByIdsAsync(userIds)
.Returns(new List<UserPremiumAccess>());
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds);
// Assert
Assert.Empty(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_Bulk_ReturnsCorrectStatus(
UserPremiumAccess user1,
UserPremiumAccess user2,
UserPremiumAccess user3,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user1.PersonalPremium = true;
user1.OrganizationPremium = false;
user2.PersonalPremium = false;
user2.OrganizationPremium = false;
user3.PersonalPremium = false;
user3.OrganizationPremium = true;
var users = new List<UserPremiumAccess> { user1, user2, user3 };
var userIds = users.Select(u => u.Id).ToList();
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessByIdsAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
.Returns(users);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds);
// Assert
Assert.Equal(3, result.Count);
Assert.True(result[user1.Id]); // Personal premium
Assert.False(result[user2.Id]); // No premium
Assert.True(result[user3.Id]); // Organization premium
}
}

View File

@@ -1,5 +1,4 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models.Sales;
@@ -391,12 +390,13 @@ public class OrganizationBillingServiceTests
}
[Theory, BitAutoData]
public async Task UpdateOrganizationNameAndEmail_WhenNameIsLong_TruncatesTo30Characters(
public async Task UpdateOrganizationNameAndEmail_WhenNameIsLong_UsesFullName(
Organization organization,
SutProvider<OrganizationBillingService> sutProvider)
{
// Arrange
organization.Name = "This is a very long organization name that exceeds thirty characters";
var longName = "This is a very long organization name that exceeds thirty characters";
organization.Name = longName;
CustomerUpdateOptions capturedOptions = null;
sutProvider.GetDependency<IStripeAdapter>()
@@ -420,14 +420,11 @@ public class OrganizationBillingServiceTests
Assert.NotNull(capturedOptions.InvoiceSettings.CustomFields);
var customField = capturedOptions.InvoiceSettings.CustomFields.First();
Assert.Equal(30, customField.Value.Length);
var expectedCustomFieldDisplayName = "This is a very long organizati";
Assert.Equal(expectedCustomFieldDisplayName, customField.Value);
Assert.Equal(longName, customField.Value);
}
[Theory, BitAutoData]
public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsNull_ThrowsBillingException(
public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsNull_LogsWarningAndReturns(
Organization organization,
SutProvider<OrganizationBillingService> sutProvider)
{
@@ -435,15 +432,93 @@ public class OrganizationBillingServiceTests
organization.GatewayCustomerId = null;
organization.Name = "Test Organization";
organization.BillingEmail = "billing@example.com";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act & Assert
var exception = await Assert.ThrowsAsync<BillingException>(
() => sutProvider.Sut.UpdateOrganizationNameAndEmail(organization));
// Act
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
Assert.Contains("Cannot update an organization in Stripe without a GatewayCustomerId.", exception.Response);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs()
.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>());
[Theory, BitAutoData]
public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsEmpty_LogsWarningAndReturns(
Organization organization,
SutProvider<OrganizationBillingService> sutProvider)
{
// Arrange
organization.GatewayCustomerId = "";
organization.Name = "Test Organization";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
[Theory, BitAutoData]
public async Task UpdateOrganizationNameAndEmail_WhenNameIsNull_LogsWarningAndReturns(
Organization organization,
SutProvider<OrganizationBillingService> sutProvider)
{
// Arrange
organization.Name = null;
organization.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
[Theory, BitAutoData]
public async Task UpdateOrganizationNameAndEmail_WhenNameIsEmpty_LogsWarningAndReturns(
Organization organization,
SutProvider<OrganizationBillingService> sutProvider)
{
// Arrange
organization.Name = "";
organization.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
// Assert
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
Arg.Any<string>(),
Arg.Any<CustomerUpdateOptions>());
}
[Theory, BitAutoData]
public async Task UpdateOrganizationNameAndEmail_WhenBillingEmailIsNull_UpdatesWithNull(
Organization organization,
SutProvider<OrganizationBillingService> sutProvider)
{
// Arrange
organization.Name = "Test Organization";
organization.BillingEmail = null;
organization.GatewayCustomerId = "cus_test123";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
// Act
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
// Assert
await stripeAdapter.Received(1).UpdateCustomerAsync(
organization.GatewayCustomerId,
Arg.Is<CustomerUpdateOptions>(options =>
options.Email == null &&
options.Description == organization.Name));
}
}

View File

@@ -81,7 +81,6 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
var root = body.RootElement;
AssertRefreshTokenExists(root);
AssertHelper.AssertJsonProperty(root, "ForcePasswordReset", JsonValueKind.False);
AssertHelper.AssertJsonProperty(root, "ResetMasterPassword", JsonValueKind.False);
var kdf = AssertHelper.AssertJsonProperty(root, "Kdf", JsonValueKind.Number).GetInt32();
Assert.Equal(0, kdf);
var kdfIterations = AssertHelper.AssertJsonProperty(root, "KdfIterations", JsonValueKind.Number).GetInt32();

View File

@@ -144,4 +144,69 @@ public class CollectionRepositoryReplaceTests
await userRepository.DeleteAsync(user);
await organizationRepository.DeleteAsync(organization);
}
[Theory, DatabaseData]
public async Task ReplaceAsync_WhenNotPassingGroupsOrUsers_DoesNotDeleteAccess(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IGroupRepository groupRepository,
ICollectionRepository collectionRepository)
{
// Arrange
var organization = await organizationRepository.CreateTestOrganizationAsync();
var user1 = await userRepository.CreateTestUserAsync();
var orgUser1 = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user1);
var user2 = await userRepository.CreateTestUserAsync();
var orgUser2 = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user2);
var group1 = await groupRepository.CreateTestGroupAsync(organization);
var group2 = await groupRepository.CreateTestGroupAsync(organization);
var collection = new Collection
{
Name = "Test Collection Name",
OrganizationId = organization.Id,
};
await collectionRepository.CreateAsync(collection,
[
new CollectionAccessSelection { Id = group1.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
new CollectionAccessSelection { Id = group2.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
],
[
new CollectionAccessSelection { Id = orgUser1.Id, Manage = true, HidePasswords = false, ReadOnly = true },
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = true, ReadOnly = false },
]
);
// Act
collection.Name = "Updated Collection Name";
await collectionRepository.ReplaceAsync(collection, null, null);
// Assert
var (actualCollection, actualAccess) = await collectionRepository.GetByIdWithAccessAsync(collection.Id);
Assert.NotNull(actualCollection);
Assert.Equal("Updated Collection Name", actualCollection.Name);
var groups = actualAccess.Groups.ToArray();
Assert.Equal(2, groups.Length);
Assert.Single(groups, g => g.Id == group1.Id && g.Manage && g.HidePasswords && !g.ReadOnly);
Assert.Single(groups, g => g.Id == group2.Id && !g.Manage && !g.HidePasswords && g.ReadOnly);
var users = actualAccess.Users.ToArray();
Assert.Equal(2, users.Length);
Assert.Single(users, u => u.Id == orgUser1.Id && u.Manage && !u.HidePasswords && u.ReadOnly);
Assert.Single(users, u => u.Id == orgUser2.Id && !u.Manage && u.HidePasswords && !u.ReadOnly);
// Clean up data
await userRepository.DeleteAsync(user1);
await userRepository.DeleteAsync(user2);
await organizationRepository.DeleteAsync(organization);
}
}

View File

@@ -179,4 +179,325 @@ public class UserRepositoryTests
Assert.Equal(CollectionType.SharedCollection, updatedCollection2.Type);
Assert.Equal(user2.Email, updatedCollection2.DefaultUserCollectionEmail);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithPersonalPremium_ReturnsCorrectAccess(
IUserRepository userRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "Premium User",
Email = $"premium+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.True(result.PersonalPremium);
Assert.False(result.OrganizationPremium);
Assert.True(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithOrganizationPremium_ReturnsCorrectAccess(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "Org User",
Email = $"org+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.PersonalPremium);
Assert.True(result.OrganizationPremium);
Assert.True(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithDisabledOrganization_ReturnsNoOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User",
Email = $"user+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
organization.Enabled = false;
await organizationRepository.ReplaceAsync(organization);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.OrganizationPremium);
Assert.False(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithOrganizationUsersGetPremiumFalse_ReturnsNoOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
organization.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(organization);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.OrganizationPremium);
Assert.False(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithMultipleOrganizations_OneProvidesPremium_ReturnsOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User With Premium Org",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var orgWithPremium = await organizationRepository.CreateTestOrganizationAsync();
await organizationUserRepository.CreateTestOrganizationUserAsync(orgWithPremium, user);
var orgNoPremium = await organizationRepository.CreateTestOrganizationAsync();
orgNoPremium.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(orgNoPremium);
await organizationUserRepository.CreateTestOrganizationUserAsync(orgNoPremium, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.PersonalPremium);
Assert.True(result.OrganizationPremium);
Assert.True(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithMultipleOrganizations_NoneProvidePremium_ReturnsNoOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User With No Premium Orgs",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var disabledOrg = await organizationRepository.CreateTestOrganizationAsync();
disabledOrg.Enabled = false;
await organizationRepository.ReplaceAsync(disabledOrg);
await organizationUserRepository.CreateTestOrganizationUserAsync(disabledOrg, user);
var orgNoPremium = await organizationRepository.CreateTestOrganizationAsync();
orgNoPremium.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(orgNoPremium);
await organizationUserRepository.CreateTestOrganizationUserAsync(orgNoPremium, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.PersonalPremium);
Assert.False(result.OrganizationPremium);
Assert.False(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_NonExistentUser_ReturnsNull(
IUserRepository userRepository)
{
// Act
var result = await userRepository.GetPremiumAccessAsync(Guid.NewGuid());
// Assert
Assert.Null(result);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessByIdsAsync_MultipleUsers_ReturnsCorrectAccessForEach(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var personalPremiumUser = await userRepository.CreateAsync(new User
{
Name = "Personal Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
var orgPremiumUser = await userRepository.CreateAsync(new User
{
Name = "Org Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var bothPremiumUser = await userRepository.CreateAsync(new User
{
Name = "Both Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
var noPremiumUser = await userRepository.CreateAsync(new User
{
Name = "No Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var multiOrgUser = await userRepository.CreateAsync(new User
{
Name = "Multi Org User",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var personalPremiumWithDisabledOrg = await userRepository.CreateAsync(new User
{
Name = "Personal Premium With Disabled Org",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, orgPremiumUser);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, bothPremiumUser);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, multiOrgUser);
var orgWithoutPremium = await organizationRepository.CreateTestOrganizationAsync();
orgWithoutPremium.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(orgWithoutPremium);
await organizationUserRepository.CreateTestOrganizationUserAsync(orgWithoutPremium, multiOrgUser);
var disabledOrg = await organizationRepository.CreateTestOrganizationAsync();
disabledOrg.Enabled = false;
await organizationRepository.ReplaceAsync(disabledOrg);
await organizationUserRepository.CreateTestOrganizationUserAsync(disabledOrg, personalPremiumWithDisabledOrg);
// Act
var results = await userRepository.GetPremiumAccessByIdsAsync([
personalPremiumUser.Id,
orgPremiumUser.Id,
bothPremiumUser.Id,
noPremiumUser.Id,
multiOrgUser.Id,
personalPremiumWithDisabledOrg.Id
]);
var resultsList = results.ToList();
// Assert
Assert.Equal(6, resultsList.Count);
var personalResult = resultsList.First(r => r.Id == personalPremiumUser.Id);
Assert.True(personalResult.PersonalPremium);
Assert.False(personalResult.OrganizationPremium);
var orgResult = resultsList.First(r => r.Id == orgPremiumUser.Id);
Assert.False(orgResult.PersonalPremium);
Assert.True(orgResult.OrganizationPremium);
var bothResult = resultsList.First(r => r.Id == bothPremiumUser.Id);
Assert.True(bothResult.PersonalPremium);
Assert.True(bothResult.OrganizationPremium);
var noneResult = resultsList.First(r => r.Id == noPremiumUser.Id);
Assert.False(noneResult.PersonalPremium);
Assert.False(noneResult.OrganizationPremium);
var multiResult = resultsList.First(r => r.Id == multiOrgUser.Id);
Assert.False(multiResult.PersonalPremium);
Assert.True(multiResult.OrganizationPremium);
var personalWithDisabledOrgResult = resultsList.First(r => r.Id == personalPremiumWithDisabledOrg.Id);
Assert.True(personalWithDisabledOrgResult.PersonalPremium);
Assert.False(personalWithDisabledOrgResult.OrganizationPremium);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessByIdsAsync_EmptyList_ReturnsEmptyResult(
IUserRepository userRepository)
{
// Act
var results = await userRepository.GetPremiumAccessByIdsAsync([]);
// Assert
Assert.Empty(results);
}
}

View File

@@ -0,0 +1,151 @@
CREATE OR ALTER PROCEDURE [dbo].[Collection_UpdateWithUsers]
@Id UNIQUEIDENTIFIER,
@OrganizationId UNIQUEIDENTIFIER,
@Name VARCHAR(MAX),
@ExternalId NVARCHAR(300),
@CreationDate DATETIME2(7),
@RevisionDate DATETIME2(7),
@Users AS [dbo].[CollectionAccessSelectionType] READONLY,
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
@Type TINYINT = 0
AS
BEGIN
SET NOCOUNT ON
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
-- Users
-- Delete users that are no longer in source
DELETE
cu
FROM
[dbo].[CollectionUser] cu
LEFT JOIN
@Users u ON cu.OrganizationUserId = u.Id
WHERE
cu.CollectionId = @Id
AND u.Id IS NULL;
-- Update existing users
UPDATE
cu
SET
cu.ReadOnly = u.ReadOnly,
cu.HidePasswords = u.HidePasswords,
cu.Manage = u.Manage
FROM
[dbo].[CollectionUser] cu
INNER JOIN
@Users u ON cu.OrganizationUserId = u.Id
WHERE
cu.CollectionId = @Id
AND (
cu.ReadOnly != u.ReadOnly
OR cu.HidePasswords != u.HidePasswords
OR cu.Manage != u.Manage
);
-- Insert new users
INSERT INTO [dbo].[CollectionUser]
(
[CollectionId],
[OrganizationUserId],
[ReadOnly],
[HidePasswords],
[Manage]
)
SELECT
@Id,
u.Id,
u.ReadOnly,
u.HidePasswords,
u.Manage
FROM
@Users u
INNER JOIN
[dbo].[OrganizationUser] ou ON ou.Id = u.Id
LEFT JOIN
[dbo].[CollectionUser] cu ON cu.CollectionId = @Id AND cu.OrganizationUserId = u.Id
WHERE
ou.OrganizationId = @OrganizationId
AND cu.CollectionId IS NULL;
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
END
GO
CREATE OR ALTER PROCEDURE [dbo].[Collection_UpdateWithGroups]
@Id UNIQUEIDENTIFIER,
@OrganizationId UNIQUEIDENTIFIER,
@Name VARCHAR(MAX),
@ExternalId NVARCHAR(300),
@CreationDate DATETIME2(7),
@RevisionDate DATETIME2(7),
@Groups AS [dbo].[CollectionAccessSelectionType] READONLY,
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
@Type TINYINT = 0
AS
BEGIN
SET NOCOUNT ON
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
-- Groups
-- Delete groups that are no longer in source
DELETE
cg
FROM
[dbo].[CollectionGroup] cg
LEFT JOIN
@Groups g ON cg.GroupId = g.Id
WHERE
cg.CollectionId = @Id
AND g.Id IS NULL;
-- Update existing groups
UPDATE
cg
SET
cg.ReadOnly = g.ReadOnly,
cg.HidePasswords = g.HidePasswords,
cg.Manage = g.Manage
FROM
[dbo].[CollectionGroup] cg
INNER JOIN
@Groups g ON cg.GroupId = g.Id
WHERE
cg.CollectionId = @Id
AND (
cg.ReadOnly != g.ReadOnly
OR cg.HidePasswords != g.HidePasswords
OR cg.Manage != g.Manage
);
-- Insert new groups
INSERT INTO [dbo].[CollectionGroup]
(
[CollectionId],
[GroupId],
[ReadOnly],
[HidePasswords],
[Manage]
)
SELECT
@Id,
g.Id,
g.ReadOnly,
g.HidePasswords,
g.Manage
FROM
@Groups g
INNER JOIN
[dbo].[Group] grp ON grp.Id = g.Id
LEFT JOIN
[dbo].[CollectionGroup] cg ON cg.CollectionId = @Id AND cg.GroupId = g.Id
WHERE
grp.OrganizationId = @OrganizationId
AND cg.CollectionId IS NULL;
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
END
GO

View File

@@ -0,0 +1,60 @@
-- Add UsersGetPremium to IX_Organization_Enabled index to support premium access queries
IF EXISTS (
SELECT * FROM sys.indexes
WHERE name = 'IX_Organization_Enabled'
AND object_id = OBJECT_ID('[dbo].[Organization]')
)
BEGIN
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
INCLUDE ([UseTotp], [UsersGetPremium])
WITH (DROP_EXISTING = ON);
END
ELSE
BEGIN
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
INCLUDE ([UseTotp], [UsersGetPremium]);
END
GO
CREATE OR ALTER VIEW [dbo].[UserPremiumAccessView]
AS
SELECT
U.[Id],
U.[Premium] AS [PersonalPremium],
CAST(
MAX(CASE
WHEN O.[Id] IS NOT NULL THEN 1
ELSE 0
END) AS BIT
) AS [OrganizationPremium]
FROM
[dbo].[User] U
LEFT JOIN
[dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id]
LEFT JOIN
[dbo].[Organization] O ON O.[Id] = OU.[OrganizationId]
AND O.[UsersGetPremium] = 1
AND O.[Enabled] = 1
GROUP BY
U.[Id], U.[Premium];
GO
CREATE OR ALTER PROCEDURE [dbo].[User_ReadPremiumAccessByIds]
@Ids [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
SELECT
UPA.[Id],
UPA.[PersonalPremium],
UPA.[OrganizationPremium]
FROM
[dbo].[UserPremiumAccessView] UPA
WHERE
UPA.[Id] IN (SELECT [Id] FROM @Ids)
END
GO

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace Bit.MySqlMigrations.Migrations;
/// <inheritdoc />
public partial class OrganizationUsersGetPremiumIndex : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
}
}

View File

@@ -274,7 +274,7 @@ namespace Bit.MySqlMigrations.Migrations
b.HasKey("Id");
b.HasIndex("Id", "Enabled")
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp" });
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
b.ToTable("Organization", (string)null);
});

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,37 @@
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace Bit.PostgresMigrations.Migrations;
/// <inheritdoc />
public partial class OrganizationUsersGetPremiumIndex : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization");
migrationBuilder.CreateIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization",
columns: new[] { "Id", "Enabled" })
.Annotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization");
migrationBuilder.CreateIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization",
columns: new[] { "Id", "Enabled" })
.Annotation("Npgsql:IndexInclude", new[] { "UseTotp" });
}
}

View File

@@ -277,7 +277,7 @@ namespace Bit.PostgresMigrations.Migrations
b.HasIndex("Id", "Enabled");
NpgsqlIndexBuilderExtensions.IncludeProperties(b.HasIndex("Id", "Enabled"), new[] { "UseTotp" });
NpgsqlIndexBuilderExtensions.IncludeProperties(b.HasIndex("Id", "Enabled"), new[] { "UseTotp", "UsersGetPremium" });
b.ToTable("Organization", (string)null);
});

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace Bit.SqliteMigrations.Migrations;
/// <inheritdoc />
public partial class OrganizationUsersGetPremiumIndex : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
}
}

View File

@@ -269,7 +269,7 @@ namespace Bit.SqliteMigrations.Migrations
b.HasKey("Id");
b.HasIndex("Id", "Enabled")
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp" });
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
b.ToTable("Organization", (string)null);
});