mirror of
https://github.com/bitwarden/server
synced 2026-01-07 19:13:50 +00:00
Merge branch 'main' into auth/pm-22975/client-version-validator
This commit is contained in:
@@ -796,6 +796,44 @@ public class ProviderBillingService(
|
||||
}
|
||||
}
|
||||
|
||||
public async Task UpdateProviderNameAndEmail(Provider provider)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(provider.GatewayCustomerId))
|
||||
{
|
||||
logger.LogWarning(
|
||||
"Provider ({ProviderId}) has no Stripe customer to update",
|
||||
provider.Id);
|
||||
return;
|
||||
}
|
||||
|
||||
var newDisplayName = provider.DisplayName();
|
||||
|
||||
// Provider.DisplayName() can return null - handle gracefully
|
||||
if (string.IsNullOrWhiteSpace(newDisplayName))
|
||||
{
|
||||
logger.LogWarning(
|
||||
"Provider ({ProviderId}) has no name to update in Stripe",
|
||||
provider.Id);
|
||||
return;
|
||||
}
|
||||
|
||||
await stripeAdapter.UpdateCustomerAsync(provider.GatewayCustomerId,
|
||||
new CustomerUpdateOptions
|
||||
{
|
||||
Email = provider.BillingEmail,
|
||||
Description = newDisplayName,
|
||||
InvoiceSettings = new CustomerInvoiceSettingsOptions
|
||||
{
|
||||
CustomFields = [
|
||||
new CustomerInvoiceSettingsCustomFieldOptions
|
||||
{
|
||||
Name = provider.SubscriberType(),
|
||||
Value = newDisplayName
|
||||
}]
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
private Func<int, Task> CurrySeatScalingUpdate(
|
||||
Provider provider,
|
||||
ProviderPlan providerPlan,
|
||||
|
||||
@@ -2150,4 +2150,151 @@ public class ProviderBillingServiceTests
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region UpdateProviderNameAndEmail
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateProviderNameAndEmail_NullGatewayCustomerId_LogsWarningAndReturns(
|
||||
Provider provider,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
provider.GatewayCustomerId = null;
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateProviderNameAndEmail_EmptyGatewayCustomerId_LogsWarningAndReturns(
|
||||
Provider provider,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
provider.GatewayCustomerId = "";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateProviderNameAndEmail_NullProviderName_LogsWarningAndReturns(
|
||||
Provider provider,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
provider.Name = null;
|
||||
provider.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateProviderNameAndEmail_EmptyProviderName_LogsWarningAndReturns(
|
||||
Provider provider,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
provider.Name = "";
|
||||
provider.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateProviderNameAndEmail_ValidProvider_CallsStripeWithCorrectParameters(
|
||||
Provider provider,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
provider.Name = "Test Provider";
|
||||
provider.BillingEmail = "billing@test.com";
|
||||
provider.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.Received(1).UpdateCustomerAsync(
|
||||
provider.GatewayCustomerId,
|
||||
Arg.Is<CustomerUpdateOptions>(options =>
|
||||
options.Email == provider.BillingEmail &&
|
||||
options.Description == provider.Name &&
|
||||
options.InvoiceSettings.CustomFields.Count == 1 &&
|
||||
options.InvoiceSettings.CustomFields[0].Name == "Provider" &&
|
||||
options.InvoiceSettings.CustomFields[0].Value == provider.Name));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateProviderNameAndEmail_LongProviderName_UsesFullName(
|
||||
Provider provider,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
var longName = new string('A', 50); // 50 characters
|
||||
provider.Name = longName;
|
||||
provider.BillingEmail = "billing@test.com";
|
||||
provider.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.Received(1).UpdateCustomerAsync(
|
||||
provider.GatewayCustomerId,
|
||||
Arg.Is<CustomerUpdateOptions>(options =>
|
||||
options.InvoiceSettings.CustomFields[0].Value == longName));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateProviderNameAndEmail_NullBillingEmail_UpdatesWithNull(
|
||||
Provider provider,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
provider.Name = "Test Provider";
|
||||
provider.BillingEmail = null;
|
||||
provider.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateProviderNameAndEmail(provider);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.Received(1).UpdateCustomerAsync(
|
||||
provider.GatewayCustomerId,
|
||||
Arg.Is<CustomerUpdateOptions>(options =>
|
||||
options.Email == null &&
|
||||
options.Description == provider.Name));
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
|
||||
@@ -99,7 +99,7 @@ services:
|
||||
- idp
|
||||
|
||||
rabbitmq:
|
||||
image: rabbitmq:4.1.3-management
|
||||
image: rabbitmq:4.2.0-management
|
||||
ports:
|
||||
- "5672:5672"
|
||||
- "15672:15672"
|
||||
|
||||
@@ -14,6 +14,7 @@ using Bit.Core.AdminConsole.Providers.Interfaces;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Organizations.Services;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Providers.Services;
|
||||
using Bit.Core.Billing.Services;
|
||||
@@ -57,6 +58,7 @@ public class OrganizationsController : Controller
|
||||
private readonly IOrganizationInitiateDeleteCommand _organizationInitiateDeleteCommand;
|
||||
private readonly IPricingClient _pricingClient;
|
||||
private readonly IResendOrganizationInviteCommand _resendOrganizationInviteCommand;
|
||||
private readonly IOrganizationBillingService _organizationBillingService;
|
||||
|
||||
public OrganizationsController(
|
||||
IOrganizationRepository organizationRepository,
|
||||
@@ -81,7 +83,8 @@ public class OrganizationsController : Controller
|
||||
IProviderBillingService providerBillingService,
|
||||
IOrganizationInitiateDeleteCommand organizationInitiateDeleteCommand,
|
||||
IPricingClient pricingClient,
|
||||
IResendOrganizationInviteCommand resendOrganizationInviteCommand)
|
||||
IResendOrganizationInviteCommand resendOrganizationInviteCommand,
|
||||
IOrganizationBillingService organizationBillingService)
|
||||
{
|
||||
_organizationRepository = organizationRepository;
|
||||
_organizationUserRepository = organizationUserRepository;
|
||||
@@ -106,6 +109,7 @@ public class OrganizationsController : Controller
|
||||
_organizationInitiateDeleteCommand = organizationInitiateDeleteCommand;
|
||||
_pricingClient = pricingClient;
|
||||
_resendOrganizationInviteCommand = resendOrganizationInviteCommand;
|
||||
_organizationBillingService = organizationBillingService;
|
||||
}
|
||||
|
||||
[RequirePermission(Permission.Org_List_View)]
|
||||
@@ -242,6 +246,8 @@ public class OrganizationsController : Controller
|
||||
var existingOrganizationData = new Organization
|
||||
{
|
||||
Id = organization.Id,
|
||||
Name = organization.Name,
|
||||
BillingEmail = organization.BillingEmail,
|
||||
Status = organization.Status,
|
||||
PlanType = organization.PlanType,
|
||||
Seats = organization.Seats
|
||||
@@ -287,6 +293,22 @@ public class OrganizationsController : Controller
|
||||
|
||||
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
|
||||
|
||||
// Sync name/email changes to Stripe
|
||||
if (existingOrganizationData.Name != organization.Name || existingOrganizationData.BillingEmail != organization.BillingEmail)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _organizationBillingService.UpdateOrganizationNameAndEmail(organization);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex,
|
||||
"Failed to update Stripe customer for organization {OrganizationId}. Database was updated successfully.",
|
||||
organization.Id);
|
||||
TempData["Warning"] = "Organization updated successfully, but Stripe customer name/email synchronization failed.";
|
||||
}
|
||||
}
|
||||
|
||||
return RedirectToAction("Edit", new { id });
|
||||
}
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ public class ProvidersController : Controller
|
||||
private readonly IStripeAdapter _stripeAdapter;
|
||||
private readonly IAccessControlService _accessControlService;
|
||||
private readonly ISubscriberService _subscriberService;
|
||||
private readonly ILogger<ProvidersController> _logger;
|
||||
|
||||
public ProvidersController(IOrganizationRepository organizationRepository,
|
||||
IResellerClientOrganizationSignUpCommand resellerClientOrganizationSignUpCommand,
|
||||
@@ -72,7 +73,8 @@ public class ProvidersController : Controller
|
||||
IPricingClient pricingClient,
|
||||
IStripeAdapter stripeAdapter,
|
||||
IAccessControlService accessControlService,
|
||||
ISubscriberService subscriberService)
|
||||
ISubscriberService subscriberService,
|
||||
ILogger<ProvidersController> logger)
|
||||
{
|
||||
_organizationRepository = organizationRepository;
|
||||
_resellerClientOrganizationSignUpCommand = resellerClientOrganizationSignUpCommand;
|
||||
@@ -92,6 +94,7 @@ public class ProvidersController : Controller
|
||||
_braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl();
|
||||
_braintreeMerchantId = globalSettings.Braintree.MerchantId;
|
||||
_subscriberService = subscriberService;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
[RequirePermission(Permission.Provider_List_View)]
|
||||
@@ -296,6 +299,9 @@ public class ProvidersController : Controller
|
||||
|
||||
var originalProviderStatus = provider.Enabled;
|
||||
|
||||
// Capture original billing email before modifications for Stripe sync
|
||||
var originalBillingEmail = provider.BillingEmail;
|
||||
|
||||
model.ToProvider(provider);
|
||||
|
||||
// validate the stripe ids to prevent saving a bad one
|
||||
@@ -321,6 +327,22 @@ public class ProvidersController : Controller
|
||||
await _providerService.UpdateAsync(provider);
|
||||
await _applicationCacheService.UpsertProviderAbilityAsync(provider);
|
||||
|
||||
// Sync billing email changes to Stripe
|
||||
if (!string.IsNullOrEmpty(provider.GatewayCustomerId) && originalBillingEmail != provider.BillingEmail)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _providerBillingService.UpdateProviderNameAndEmail(provider);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex,
|
||||
"Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.",
|
||||
provider.Id);
|
||||
TempData["Warning"] = "Provider updated successfully, but Stripe customer email synchronization failed.";
|
||||
}
|
||||
}
|
||||
|
||||
if (!provider.IsBillable())
|
||||
{
|
||||
return RedirectToAction("Edit", new { id });
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
###############################################
|
||||
# Node.js build stage #
|
||||
###############################################
|
||||
FROM node:20-alpine3.21 AS node-build
|
||||
FROM --platform=$BUILDPLATFORM node:20-alpine3.21 AS node-build
|
||||
|
||||
WORKDIR /app
|
||||
COPY src/Admin/package*.json ./
|
||||
|
||||
@@ -5,6 +5,7 @@ using Bit.Api.AdminConsole.Models.Request.Providers;
|
||||
using Bit.Api.AdminConsole.Models.Response.Providers;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.AdminConsole.Services;
|
||||
using Bit.Core.Billing.Providers.Services;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Services;
|
||||
@@ -23,15 +24,20 @@ public class ProvidersController : Controller
|
||||
private readonly IProviderService _providerService;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly GlobalSettings _globalSettings;
|
||||
private readonly IProviderBillingService _providerBillingService;
|
||||
private readonly ILogger<ProvidersController> _logger;
|
||||
|
||||
public ProvidersController(IUserService userService, IProviderRepository providerRepository,
|
||||
IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings)
|
||||
IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings,
|
||||
IProviderBillingService providerBillingService, ILogger<ProvidersController> logger)
|
||||
{
|
||||
_userService = userService;
|
||||
_providerRepository = providerRepository;
|
||||
_providerService = providerService;
|
||||
_currentContext = currentContext;
|
||||
_globalSettings = globalSettings;
|
||||
_providerBillingService = providerBillingService;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
[HttpGet("{id:guid}")]
|
||||
@@ -65,7 +71,27 @@ public class ProvidersController : Controller
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
// Capture original values before modifications for Stripe sync
|
||||
var originalName = provider.Name;
|
||||
var originalBillingEmail = provider.BillingEmail;
|
||||
|
||||
await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings));
|
||||
|
||||
// Sync name/email changes to Stripe
|
||||
if (originalName != provider.Name || originalBillingEmail != provider.BillingEmail)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _providerBillingService.UpdateProviderNameAndEmail(provider);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex,
|
||||
"Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.",
|
||||
provider.Id);
|
||||
}
|
||||
}
|
||||
|
||||
return new ProviderResponseModel(provider);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#nullable disable
|
||||
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Api.Models.Public.Response;
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.Models.Data;
|
||||
@@ -13,6 +14,12 @@ namespace Bit.Api.AdminConsole.Public.Models.Response;
|
||||
/// </summary>
|
||||
public class GroupResponseModel : GroupBaseModel, IResponseModel
|
||||
{
|
||||
[JsonConstructor]
|
||||
public GroupResponseModel()
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
public GroupResponseModel(Group group, IEnumerable<CollectionAccessSelection> collections)
|
||||
{
|
||||
if (group == null)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#nullable disable
|
||||
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Api.AdminConsole.Public.Models.Response;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Models.Data;
|
||||
@@ -13,6 +14,12 @@ namespace Bit.Api.Models.Public.Response;
|
||||
/// </summary>
|
||||
public class CollectionResponseModel : CollectionBaseModel, IResponseModel
|
||||
{
|
||||
[JsonConstructor]
|
||||
public CollectionResponseModel()
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
public CollectionResponseModel(Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
||||
{
|
||||
if (collection == null)
|
||||
|
||||
@@ -65,10 +65,11 @@ public class CollectionsController : Controller
|
||||
[ProducesResponseType(typeof(ListResponseModel<CollectionResponseModel>), (int)HttpStatusCode.OK)]
|
||||
public async Task<IActionResult> List()
|
||||
{
|
||||
var collections = await _collectionRepository.GetManySharedCollectionsByOrganizationIdAsync(
|
||||
_currentContext.OrganizationId.Value);
|
||||
// TODO: Get all CollectionGroup associations for the organization and marry them up here for the response.
|
||||
var collectionResponses = collections.Select(c => new CollectionResponseModel(c, null));
|
||||
var collections = await _collectionRepository.GetManyByOrganizationIdWithAccessAsync(_currentContext.OrganizationId.Value);
|
||||
|
||||
var collectionResponses = collections.Select(c =>
|
||||
new CollectionResponseModel(c.Item1, c.Item2.Groups));
|
||||
|
||||
var response = new ListResponseModel<CollectionResponseModel>(collectionResponses);
|
||||
return new JsonResult(response);
|
||||
}
|
||||
|
||||
@@ -39,15 +39,11 @@ public class ReconcileAdditionalStorageJob(
|
||||
logger.LogInformation("Starting ReconcileAdditionalStorageJob (live mode: {LiveMode})", liveMode);
|
||||
|
||||
var priceIds = new[] { _storageGbMonthlyPriceId, _storageGbAnnuallyPriceId, _personalStorageGbAnnuallyPriceId };
|
||||
var stripeStatusesToProcess = new[] { StripeConstants.SubscriptionStatus.Active, StripeConstants.SubscriptionStatus.Trialing, StripeConstants.SubscriptionStatus.PastDue };
|
||||
|
||||
foreach (var priceId in priceIds)
|
||||
{
|
||||
var options = new SubscriptionListOptions
|
||||
{
|
||||
Limit = 100,
|
||||
Status = StripeConstants.SubscriptionStatus.Active,
|
||||
Price = priceId
|
||||
};
|
||||
var options = new SubscriptionListOptions { Limit = 100, Price = priceId };
|
||||
|
||||
await foreach (var subscription in stripeFacade.ListSubscriptionsAutoPagingAsync(options))
|
||||
{
|
||||
@@ -64,7 +60,7 @@ public class ReconcileAdditionalStorageJob(
|
||||
failures.Count > 0
|
||||
? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}"
|
||||
: string.Empty
|
||||
);
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -73,6 +69,12 @@ public class ReconcileAdditionalStorageJob(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!stripeStatusesToProcess.Contains(subscription.Status))
|
||||
{
|
||||
logger.LogInformation("Skipping subscription with unsupported status: {SubscriptionId} - {Status}", subscription.Id, subscription.Status);
|
||||
continue;
|
||||
}
|
||||
|
||||
logger.LogInformation("Processing subscription: {SubscriptionId}", subscription.Id);
|
||||
subscriptionsFound++;
|
||||
|
||||
@@ -133,7 +135,7 @@ public class ReconcileAdditionalStorageJob(
|
||||
failures.Count > 0
|
||||
? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}"
|
||||
: string.Empty
|
||||
);
|
||||
);
|
||||
}
|
||||
|
||||
private SubscriptionUpdateOptions? BuildSubscriptionUpdateOptions(
|
||||
@@ -145,15 +147,7 @@ public class ReconcileAdditionalStorageJob(
|
||||
return null;
|
||||
}
|
||||
|
||||
var updateOptions = new SubscriptionUpdateOptions
|
||||
{
|
||||
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
|
||||
Metadata = new Dictionary<string, string>
|
||||
{
|
||||
[StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o")
|
||||
},
|
||||
Items = []
|
||||
};
|
||||
var updateOptions = new SubscriptionUpdateOptions { ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, Metadata = new Dictionary<string, string> { [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") }, Items = [] };
|
||||
|
||||
var hasUpdates = false;
|
||||
|
||||
@@ -172,11 +166,7 @@ public class ReconcileAdditionalStorageJob(
|
||||
newQuantity,
|
||||
item.Price.Id);
|
||||
|
||||
updateOptions.Items.Add(new SubscriptionItemOptions
|
||||
{
|
||||
Id = item.Id,
|
||||
Quantity = newQuantity
|
||||
});
|
||||
updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Quantity = newQuantity });
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -185,11 +175,7 @@ public class ReconcileAdditionalStorageJob(
|
||||
currentQuantity,
|
||||
item.Price.Id);
|
||||
|
||||
updateOptions.Items.Add(new SubscriptionItemOptions
|
||||
{
|
||||
Id = item.Id,
|
||||
Deleted = true
|
||||
});
|
||||
updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Deleted = true });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ public interface IStripeEventUtilityService
|
||||
/// <param name="userId"></param>
|
||||
/// /// <param name="providerId"></param>
|
||||
/// <returns></returns>
|
||||
Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId);
|
||||
Task<Transaction> FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId);
|
||||
|
||||
/// <summary>
|
||||
/// Attempts to pay the specified invoice. If a customer is eligible, the invoice is paid using Braintree or Stripe.
|
||||
|
||||
@@ -20,6 +20,12 @@ public interface IStripeFacade
|
||||
RequestOptions requestOptions = null,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
IAsyncEnumerable<CustomerCashBalanceTransaction> GetCustomerCashBalanceTransactions(
|
||||
string customerId,
|
||||
CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null,
|
||||
RequestOptions requestOptions = null,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
Task<Customer> UpdateCustomer(
|
||||
string customerId,
|
||||
CustomerUpdateOptions customerUpdateOptions = null,
|
||||
|
||||
@@ -38,7 +38,7 @@ public class ChargeRefundedHandler : IChargeRefundedHandler
|
||||
{
|
||||
// Attempt to create a transaction for the charge if it doesn't exist
|
||||
var (organizationId, userId, providerId) = await _stripeEventUtilityService.GetEntityIdsFromChargeAsync(charge);
|
||||
var tx = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId);
|
||||
var tx = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId);
|
||||
try
|
||||
{
|
||||
parentTransaction = await _transactionRepository.CreateAsync(tx);
|
||||
|
||||
@@ -46,7 +46,7 @@ public class ChargeSucceededHandler : IChargeSucceededHandler
|
||||
return;
|
||||
}
|
||||
|
||||
var transaction = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId);
|
||||
var transaction = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId);
|
||||
if (!transaction.PaymentMethodType.HasValue)
|
||||
{
|
||||
_logger.LogWarning("Charge success from unsupported source/method. {ChargeId}", charge.Id);
|
||||
|
||||
@@ -124,7 +124,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService
|
||||
/// <param name="userId"></param>
|
||||
/// /// <param name="providerId"></param>
|
||||
/// <returns></returns>
|
||||
public Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId)
|
||||
public async Task<Transaction> FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId)
|
||||
{
|
||||
var transaction = new Transaction
|
||||
{
|
||||
@@ -209,6 +209,24 @@ public class StripeEventUtilityService : IStripeEventUtilityService
|
||||
transaction.PaymentMethodType = PaymentMethodType.BankAccount;
|
||||
transaction.Details = $"ACH => {achCreditTransfer.BankName}, {achCreditTransfer.AccountNumber}";
|
||||
}
|
||||
else if (charge.PaymentMethodDetails.CustomerBalance != null)
|
||||
{
|
||||
var bankTransferType = await GetFundingBankTransferTypeAsync(charge);
|
||||
|
||||
if (!string.IsNullOrEmpty(bankTransferType))
|
||||
{
|
||||
transaction.PaymentMethodType = PaymentMethodType.BankAccount;
|
||||
transaction.Details = bankTransferType switch
|
||||
{
|
||||
"eu_bank_transfer" => "EU Bank Transfer",
|
||||
"gb_bank_transfer" => "GB Bank Transfer",
|
||||
"jp_bank_transfer" => "JP Bank Transfer",
|
||||
"mx_bank_transfer" => "MX Bank Transfer",
|
||||
"us_bank_transfer" => "US Bank Transfer",
|
||||
_ => "Bank Transfer"
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
@@ -289,20 +307,13 @@ public class StripeEventUtilityService : IStripeEventUtilityService
|
||||
}
|
||||
var btInvoiceAmount = Math.Round(invoice.AmountDue / 100M, 2);
|
||||
|
||||
var existingTransactions = organizationId.HasValue
|
||||
? await _transactionRepository.GetManyByOrganizationIdAsync(organizationId.Value)
|
||||
: userId.HasValue
|
||||
? await _transactionRepository.GetManyByUserIdAsync(userId.Value)
|
||||
: await _transactionRepository.GetManyByProviderIdAsync(providerId.Value);
|
||||
|
||||
var duplicateTimeSpan = TimeSpan.FromHours(24);
|
||||
var now = DateTime.UtcNow;
|
||||
var duplicateTransaction = existingTransactions?
|
||||
.FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan);
|
||||
if (duplicateTransaction != null)
|
||||
// Check if this invoice already has a Braintree transaction ID to prevent duplicate charges
|
||||
if (invoice.Metadata?.ContainsKey("btTransactionId") ?? false)
|
||||
{
|
||||
_logger.LogWarning("There is already a recent PayPal transaction ({0}). " +
|
||||
"Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId);
|
||||
_logger.LogWarning("Invoice {InvoiceId} already has a Braintree transaction ({TransactionId}). " +
|
||||
"Do not charge again to prevent duplicate.",
|
||||
invoice.Id,
|
||||
invoice.Metadata["btTransactionId"]);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -413,4 +424,55 @@ public class StripeEventUtilityService : IStripeEventUtilityService
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Retrieves the bank transfer type that funded a charge paid via customer balance.
|
||||
/// </summary>
|
||||
/// <param name="charge">The charge to analyze.</param>
|
||||
/// <returns>
|
||||
/// The bank transfer type (e.g., "us_bank_transfer", "eu_bank_transfer") if the charge was funded
|
||||
/// by a bank transfer via customer balance, otherwise null.
|
||||
/// </returns>
|
||||
private async Task<string> GetFundingBankTransferTypeAsync(Charge charge)
|
||||
{
|
||||
if (charge is not
|
||||
{
|
||||
CustomerId: not null,
|
||||
PaymentIntentId: not null,
|
||||
PaymentMethodDetails: { Type: "customer_balance" }
|
||||
})
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var cashBalanceTransactions = _stripeFacade.GetCustomerCashBalanceTransactions(charge.CustomerId);
|
||||
|
||||
string bankTransferType = null;
|
||||
var matchingPaymentIntentFound = false;
|
||||
|
||||
await foreach (var cashBalanceTransaction in cashBalanceTransactions)
|
||||
{
|
||||
switch (cashBalanceTransaction)
|
||||
{
|
||||
case { Type: "funded", Funded: not null }:
|
||||
{
|
||||
bankTransferType = cashBalanceTransaction.Funded.BankTransfer.Type;
|
||||
break;
|
||||
}
|
||||
case { Type: "applied_to_payment", AppliedToPayment: not null }
|
||||
when cashBalanceTransaction.AppliedToPayment.PaymentIntentId == charge.PaymentIntentId:
|
||||
{
|
||||
matchingPaymentIntentFound = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (matchingPaymentIntentFound && !string.IsNullOrEmpty(bankTransferType))
|
||||
{
|
||||
return bankTransferType;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ public class StripeFacade : IStripeFacade
|
||||
{
|
||||
private readonly ChargeService _chargeService = new();
|
||||
private readonly CustomerService _customerService = new();
|
||||
private readonly CustomerCashBalanceTransactionService _customerCashBalanceTransactionService = new();
|
||||
private readonly EventService _eventService = new();
|
||||
private readonly InvoiceService _invoiceService = new();
|
||||
private readonly PaymentMethodService _paymentMethodService = new();
|
||||
@@ -41,6 +42,13 @@ public class StripeFacade : IStripeFacade
|
||||
CancellationToken cancellationToken = default) =>
|
||||
await _customerService.GetAsync(customerId, customerGetOptions, requestOptions, cancellationToken);
|
||||
|
||||
public IAsyncEnumerable<CustomerCashBalanceTransaction> GetCustomerCashBalanceTransactions(
|
||||
string customerId,
|
||||
CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null,
|
||||
RequestOptions requestOptions = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
=> _customerCashBalanceTransactionService.ListAutoPagingAsync(customerId, customerCashBalanceTransactionListOptions, requestOptions, cancellationToken);
|
||||
|
||||
public async Task<Customer> UpdateCustomer(
|
||||
string customerId,
|
||||
CustomerUpdateOptions customerUpdateOptions = null,
|
||||
|
||||
@@ -528,17 +528,21 @@ public static class EventIntegrationsServiceCollectionExtensions
|
||||
/// <returns>True if all required RabbitMQ settings are present; otherwise, false.</returns>
|
||||
/// <remarks>
|
||||
/// Requires all the following settings to be configured:
|
||||
/// - EventLogging.RabbitMq.HostName
|
||||
/// - EventLogging.RabbitMq.Username
|
||||
/// - EventLogging.RabbitMq.Password
|
||||
/// - EventLogging.RabbitMq.EventExchangeName
|
||||
/// <list type="bullet">
|
||||
/// <item><description>EventLogging.RabbitMq.HostName</description></item>
|
||||
/// <item><description>EventLogging.RabbitMq.Username</description></item>
|
||||
/// <item><description>EventLogging.RabbitMq.Password</description></item>
|
||||
/// <item><description>EventLogging.RabbitMq.EventExchangeName</description></item>
|
||||
/// <item><description>EventLogging.RabbitMq.IntegrationExchangeName</description></item>
|
||||
/// </list>
|
||||
/// </remarks>
|
||||
internal static bool IsRabbitMqEnabled(GlobalSettings settings)
|
||||
{
|
||||
return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) &&
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) &&
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) &&
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName);
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName) &&
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.IntegrationExchangeName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -547,13 +551,17 @@ public static class EventIntegrationsServiceCollectionExtensions
|
||||
/// <param name="settings">The global settings containing Azure Service Bus configuration.</param>
|
||||
/// <returns>True if all required Azure Service Bus settings are present; otherwise, false.</returns>
|
||||
/// <remarks>
|
||||
/// Requires both of the following settings to be configured:
|
||||
/// - EventLogging.AzureServiceBus.ConnectionString
|
||||
/// - EventLogging.AzureServiceBus.EventTopicName
|
||||
/// Requires all of the following settings to be configured:
|
||||
/// <list type="bullet">
|
||||
/// <item><description>EventLogging.AzureServiceBus.ConnectionString</description></item>
|
||||
/// <item><description>EventLogging.AzureServiceBus.EventTopicName</description></item>
|
||||
/// <item><description>EventLogging.AzureServiceBus.IntegrationTopicName</description></item>
|
||||
/// </list>
|
||||
/// </remarks>
|
||||
internal static bool IsAzureServiceBusEnabled(GlobalSettings settings)
|
||||
{
|
||||
return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) &&
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName);
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName) &&
|
||||
CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.IntegrationTopicName);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
|
||||
/// <summary>
|
||||
/// Categories of event integration failures used for classification and retry logic.
|
||||
/// </summary>
|
||||
public enum IntegrationFailureCategory
|
||||
{
|
||||
/// <summary>
|
||||
/// Service is temporarily unavailable (503, upstream outage, maintenance).
|
||||
/// </summary>
|
||||
ServiceUnavailable,
|
||||
|
||||
/// <summary>
|
||||
/// Authentication failed (401, 403, invalid_auth, token issues).
|
||||
/// </summary>
|
||||
AuthenticationFailed,
|
||||
|
||||
/// <summary>
|
||||
/// Configuration error (invalid config, channel_not_found, etc.).
|
||||
/// </summary>
|
||||
ConfigurationError,
|
||||
|
||||
/// <summary>
|
||||
/// Rate limited (429, rate_limited).
|
||||
/// </summary>
|
||||
RateLimited,
|
||||
|
||||
/// <summary>
|
||||
/// Transient error (timeouts, 500, network errors).
|
||||
/// </summary>
|
||||
TransientError,
|
||||
|
||||
/// <summary>
|
||||
/// Permanent failure unrelated to authentication/config (e.g., unrecoverable payload/format issue).
|
||||
/// </summary>
|
||||
PermanentFailure
|
||||
}
|
||||
@@ -1,16 +1,84 @@
|
||||
namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
|
||||
/// <summary>
|
||||
/// Represents the result of an integration handler operation, including success status,
|
||||
/// failure categorization, and retry metadata. Use the <see cref="Succeed"/> factory method
|
||||
/// for successful operations or <see cref="Fail"/> for failures with automatic retry-ability
|
||||
/// determination based on the failure category.
|
||||
/// </summary>
|
||||
public class IntegrationHandlerResult
|
||||
{
|
||||
public IntegrationHandlerResult(bool success, IIntegrationMessage message)
|
||||
/// <summary>
|
||||
/// True if the integration send succeeded, false otherwise.
|
||||
/// </summary>
|
||||
public bool Success { get; }
|
||||
|
||||
/// <summary>
|
||||
/// The integration message that was processed.
|
||||
/// </summary>
|
||||
public IIntegrationMessage Message { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Optional UTC date/time indicating when a failed operation should be retried.
|
||||
/// Will be used by the retry queue to delay re-sending the message.
|
||||
/// Usually set based on the Retry-After header from rate-limited responses.
|
||||
/// </summary>
|
||||
public DateTime? DelayUntilDate { get; private init; }
|
||||
|
||||
/// <summary>
|
||||
/// Category of the failure. Null for successful results.
|
||||
/// </summary>
|
||||
public IntegrationFailureCategory? Category { get; private init; }
|
||||
|
||||
/// <summary>
|
||||
/// Detailed failure reason or error message. Empty for successful results.
|
||||
/// </summary>
|
||||
public string? FailureReason { get; private init; }
|
||||
|
||||
/// <summary>
|
||||
/// Indicates whether the operation is retryable.
|
||||
/// Computed from the failure category.
|
||||
/// </summary>
|
||||
public bool Retryable => Category switch
|
||||
{
|
||||
IntegrationFailureCategory.RateLimited => true,
|
||||
IntegrationFailureCategory.TransientError => true,
|
||||
IntegrationFailureCategory.ServiceUnavailable => true,
|
||||
IntegrationFailureCategory.AuthenticationFailed => false,
|
||||
IntegrationFailureCategory.ConfigurationError => false,
|
||||
IntegrationFailureCategory.PermanentFailure => false,
|
||||
null => false,
|
||||
_ => false
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Creates a successful result.
|
||||
/// </summary>
|
||||
public static IntegrationHandlerResult Succeed(IIntegrationMessage message)
|
||||
{
|
||||
return new IntegrationHandlerResult(success: true, message: message);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a failed result with a failure category and reason.
|
||||
/// </summary>
|
||||
public static IntegrationHandlerResult Fail(
|
||||
IIntegrationMessage message,
|
||||
IntegrationFailureCategory category,
|
||||
string failureReason,
|
||||
DateTime? delayUntil = null)
|
||||
{
|
||||
return new IntegrationHandlerResult(success: false, message: message)
|
||||
{
|
||||
Category = category,
|
||||
FailureReason = failureReason,
|
||||
DelayUntilDate = delayUntil
|
||||
};
|
||||
}
|
||||
|
||||
private IntegrationHandlerResult(bool success, IIntegrationMessage message)
|
||||
{
|
||||
Success = success;
|
||||
Message = message;
|
||||
}
|
||||
|
||||
public bool Success { get; set; } = false;
|
||||
public bool Retryable { get; set; } = false;
|
||||
public IIntegrationMessage Message { get; set; }
|
||||
public DateTime? DelayUntilDate { get; set; }
|
||||
public string FailureReason { get; set; } = string.Empty;
|
||||
}
|
||||
|
||||
@@ -20,6 +20,12 @@ public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser, I
|
||||
public string Email { get; set; }
|
||||
public string AvatarColor { get; set; }
|
||||
public string TwoFactorProviders { get; set; }
|
||||
/// <summary>
|
||||
/// Indicates whether the user has a personal premium subscription.
|
||||
/// Does not include premium access from organizations -
|
||||
/// do not use this to check whether the user can access premium features.
|
||||
/// Null when the organization user is in Invited status (UserId is null).
|
||||
/// </summary>
|
||||
public bool? Premium { get; set; }
|
||||
public OrganizationUserStatusType Status { get; set; }
|
||||
public OrganizationUserType Type { get; set; }
|
||||
|
||||
@@ -270,7 +270,9 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand
|
||||
ICollection<OrganizationUser> allOrgUsers, User user)
|
||||
{
|
||||
var error = (await _automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync(
|
||||
new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId, allOrgUsers, user)))
|
||||
new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId,
|
||||
allOrgUsers.Append(orgUser),
|
||||
user)))
|
||||
.Match(
|
||||
error => error.Message,
|
||||
_ => string.Empty
|
||||
|
||||
@@ -67,7 +67,7 @@ public class OrganizationUpdateCommand(
|
||||
var shouldUpdateBilling = originalName != organization.Name ||
|
||||
originalBillingEmail != organization.BillingEmail;
|
||||
|
||||
if (!shouldUpdateBilling || string.IsNullOrWhiteSpace(organization.GatewayCustomerId))
|
||||
if (!shouldUpdateBilling)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -19,7 +19,8 @@ public class AutomaticUserConfirmationPolicyEnforcementValidator(
|
||||
|
||||
var currentOrganizationUser = request.AllOrganizationUsers
|
||||
.FirstOrDefault(x => x.OrganizationId == request.OrganizationId
|
||||
&& x.UserId == request.User.Id);
|
||||
// invited users do not have a userId but will have email
|
||||
&& (x.UserId == request.User.Id || x.Email == request.User.Email));
|
||||
|
||||
if (currentOrganizationUser is null)
|
||||
{
|
||||
|
||||
@@ -29,46 +29,87 @@ public abstract class IntegrationHandlerBase<T> : IIntegrationHandler<T>
|
||||
IntegrationMessage<T> message,
|
||||
TimeProvider timeProvider)
|
||||
{
|
||||
var result = new IntegrationHandlerResult(success: response.IsSuccessStatusCode, message);
|
||||
|
||||
if (response.IsSuccessStatusCode) return result;
|
||||
|
||||
switch (response.StatusCode)
|
||||
if (response.IsSuccessStatusCode)
|
||||
{
|
||||
case HttpStatusCode.TooManyRequests:
|
||||
case HttpStatusCode.RequestTimeout:
|
||||
case HttpStatusCode.InternalServerError:
|
||||
case HttpStatusCode.BadGateway:
|
||||
case HttpStatusCode.ServiceUnavailable:
|
||||
case HttpStatusCode.GatewayTimeout:
|
||||
result.Retryable = true;
|
||||
result.FailureReason = response.ReasonPhrase ?? $"Failure with status code: {(int)response.StatusCode}";
|
||||
|
||||
if (response.Headers.TryGetValues("Retry-After", out var values))
|
||||
{
|
||||
var value = values.FirstOrDefault();
|
||||
if (int.TryParse(value, out var seconds))
|
||||
{
|
||||
// Retry-after was specified in seconds. Adjust DelayUntilDate by the requested number of seconds.
|
||||
result.DelayUntilDate = timeProvider.GetUtcNow().AddSeconds(seconds).UtcDateTime;
|
||||
}
|
||||
else if (DateTimeOffset.TryParseExact(value,
|
||||
"r", // "r" is the round-trip format: RFC1123
|
||||
CultureInfo.InvariantCulture,
|
||||
DateTimeStyles.AssumeUniversal | DateTimeStyles.AdjustToUniversal,
|
||||
out var retryDate))
|
||||
{
|
||||
// Retry-after was specified as a date. Adjust DelayUntilDate to the specified date.
|
||||
result.DelayUntilDate = retryDate.UtcDateTime;
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
result.Retryable = false;
|
||||
result.FailureReason = response.ReasonPhrase ?? $"Failure with status code {(int)response.StatusCode}";
|
||||
break;
|
||||
return IntegrationHandlerResult.Succeed(message);
|
||||
}
|
||||
|
||||
return result;
|
||||
var category = ClassifyHttpStatusCode(response.StatusCode);
|
||||
var failureReason = response.ReasonPhrase ?? $"Failure with status code {(int)response.StatusCode}";
|
||||
|
||||
if (category is not (IntegrationFailureCategory.RateLimited
|
||||
or IntegrationFailureCategory.TransientError
|
||||
or IntegrationFailureCategory.ServiceUnavailable) ||
|
||||
!response.Headers.TryGetValues("Retry-After", out var values)
|
||||
)
|
||||
{
|
||||
return IntegrationHandlerResult.Fail(message: message, category: category, failureReason: failureReason);
|
||||
}
|
||||
|
||||
// Handle Retry-After header for rate-limited and retryable errors
|
||||
DateTime? delayUntil = null;
|
||||
var value = values.FirstOrDefault();
|
||||
if (int.TryParse(value, out var seconds))
|
||||
{
|
||||
// Retry-after was specified in seconds
|
||||
delayUntil = timeProvider.GetUtcNow().AddSeconds(seconds).UtcDateTime;
|
||||
}
|
||||
else if (DateTimeOffset.TryParseExact(value,
|
||||
"r", // "r" is the round-trip format: RFC1123
|
||||
CultureInfo.InvariantCulture,
|
||||
DateTimeStyles.AssumeUniversal | DateTimeStyles.AdjustToUniversal,
|
||||
out var retryDate))
|
||||
{
|
||||
// Retry-after was specified as a date
|
||||
delayUntil = retryDate.UtcDateTime;
|
||||
}
|
||||
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
category,
|
||||
failureReason,
|
||||
delayUntil
|
||||
);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Classifies an <see cref="HttpStatusCode"/> as an <see cref="IntegrationFailureCategory"/> to drive
|
||||
/// retry behavior and operator-facing failure reporting.
|
||||
/// </summary>
|
||||
/// <param name="statusCode">The HTTP status code.</param>
|
||||
/// <returns>The corresponding <see cref="IntegrationFailureCategory"/>.</returns>
|
||||
protected static IntegrationFailureCategory ClassifyHttpStatusCode(HttpStatusCode statusCode)
|
||||
{
|
||||
var explicitCategory = statusCode switch
|
||||
{
|
||||
HttpStatusCode.Unauthorized => IntegrationFailureCategory.AuthenticationFailed,
|
||||
HttpStatusCode.Forbidden => IntegrationFailureCategory.AuthenticationFailed,
|
||||
HttpStatusCode.NotFound => IntegrationFailureCategory.ConfigurationError,
|
||||
HttpStatusCode.Gone => IntegrationFailureCategory.ConfigurationError,
|
||||
HttpStatusCode.MovedPermanently => IntegrationFailureCategory.ConfigurationError,
|
||||
HttpStatusCode.TemporaryRedirect => IntegrationFailureCategory.ConfigurationError,
|
||||
HttpStatusCode.PermanentRedirect => IntegrationFailureCategory.ConfigurationError,
|
||||
HttpStatusCode.TooManyRequests => IntegrationFailureCategory.RateLimited,
|
||||
HttpStatusCode.RequestTimeout => IntegrationFailureCategory.TransientError,
|
||||
HttpStatusCode.InternalServerError => IntegrationFailureCategory.TransientError,
|
||||
HttpStatusCode.BadGateway => IntegrationFailureCategory.TransientError,
|
||||
HttpStatusCode.GatewayTimeout => IntegrationFailureCategory.TransientError,
|
||||
HttpStatusCode.ServiceUnavailable => IntegrationFailureCategory.ServiceUnavailable,
|
||||
HttpStatusCode.NotImplemented => IntegrationFailureCategory.PermanentFailure,
|
||||
_ => (IntegrationFailureCategory?)null
|
||||
};
|
||||
|
||||
if (explicitCategory is not null)
|
||||
{
|
||||
return explicitCategory.Value;
|
||||
}
|
||||
|
||||
return (int)statusCode switch
|
||||
{
|
||||
>= 300 and <= 399 => IntegrationFailureCategory.ConfigurationError,
|
||||
>= 400 and <= 499 => IntegrationFailureCategory.ConfigurationError,
|
||||
>= 500 and <= 599 => IntegrationFailureCategory.ServiceUnavailable,
|
||||
_ => IntegrationFailureCategory.ServiceUnavailable
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +85,17 @@ public class AzureServiceBusIntegrationListenerService<TConfiguration> : Backgro
|
||||
{
|
||||
// Non-recoverable failure or exceeded the max number of retries
|
||||
// Return false to indicate this message should be dead-lettered
|
||||
_logger.LogWarning(
|
||||
"Integration failure - non-recoverable error or max retries exceeded. " +
|
||||
"MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " +
|
||||
"FailureCategory: {Category}, Reason: {Reason}, RetryCount: {RetryCount}, MaxRetries: {MaxRetries}",
|
||||
message.MessageId,
|
||||
message.IntegrationType,
|
||||
message.OrganizationId,
|
||||
result.Category,
|
||||
result.FailureReason,
|
||||
message.RetryCount,
|
||||
_maxRetries);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,14 +106,32 @@ public class RabbitMqIntegrationListenerService<TConfiguration> : BackgroundServ
|
||||
{
|
||||
// Exceeded the max number of retries; fail and send to dead letter queue
|
||||
await _rabbitMqService.PublishToDeadLetterAsync(channel, message, cancellationToken);
|
||||
_logger.LogWarning("Max retry attempts reached. Sent to DLQ.");
|
||||
_logger.LogWarning(
|
||||
"Integration failure - max retries exceeded. " +
|
||||
"MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " +
|
||||
"FailureCategory: {Category}, Reason: {Reason}, RetryCount: {RetryCount}, MaxRetries: {MaxRetries}",
|
||||
message.MessageId,
|
||||
message.IntegrationType,
|
||||
message.OrganizationId,
|
||||
result.Category,
|
||||
result.FailureReason,
|
||||
message.RetryCount,
|
||||
_maxRetries);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Fatal error (i.e. not retryable) occurred. Send message to dead letter queue without any retries
|
||||
await _rabbitMqService.PublishToDeadLetterAsync(channel, message, cancellationToken);
|
||||
_logger.LogWarning("Non-retryable failure. Sent to DLQ.");
|
||||
_logger.LogWarning(
|
||||
"Integration failure - non-retryable. " +
|
||||
"MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " +
|
||||
"FailureCategory: {Category}, Reason: {Reason}",
|
||||
message.MessageId,
|
||||
message.IntegrationType,
|
||||
message.OrganizationId,
|
||||
result.Category,
|
||||
result.FailureReason);
|
||||
}
|
||||
|
||||
// Message has been sent to retry or dead letter queues.
|
||||
|
||||
@@ -6,15 +6,6 @@ public class SlackIntegrationHandler(
|
||||
ISlackService slackService)
|
||||
: IntegrationHandlerBase<SlackIntegrationConfigurationDetails>
|
||||
{
|
||||
private static readonly HashSet<string> _retryableErrors = new(StringComparer.Ordinal)
|
||||
{
|
||||
"internal_error",
|
||||
"message_limit_exceeded",
|
||||
"rate_limited",
|
||||
"ratelimited",
|
||||
"service_unavailable"
|
||||
};
|
||||
|
||||
public override async Task<IntegrationHandlerResult> HandleAsync(IntegrationMessage<SlackIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var slackResponse = await slackService.SendSlackMessageByChannelIdAsync(
|
||||
@@ -25,24 +16,61 @@ public class SlackIntegrationHandler(
|
||||
|
||||
if (slackResponse is null)
|
||||
{
|
||||
return new IntegrationHandlerResult(success: false, message: message)
|
||||
{
|
||||
FailureReason = "Slack response was null"
|
||||
};
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.TransientError,
|
||||
"Slack response was null"
|
||||
);
|
||||
}
|
||||
|
||||
if (slackResponse.Ok)
|
||||
{
|
||||
return new IntegrationHandlerResult(success: true, message: message);
|
||||
return IntegrationHandlerResult.Succeed(message);
|
||||
}
|
||||
|
||||
var result = new IntegrationHandlerResult(success: false, message: message) { FailureReason = slackResponse.Error };
|
||||
var category = ClassifySlackError(slackResponse.Error);
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
category,
|
||||
slackResponse.Error
|
||||
);
|
||||
}
|
||||
|
||||
if (_retryableErrors.Contains(slackResponse.Error))
|
||||
/// <summary>
|
||||
/// Classifies a Slack API error code string as an <see cref="IntegrationFailureCategory"/> to drive
|
||||
/// retry behavior and operator-facing failure reporting.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// Slack responses commonly return an <c>error</c> string when <c>ok</c> is false. This method maps
|
||||
/// known Slack error codes to failure categories.
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// Any unrecognized error codes default to <see cref="IntegrationFailureCategory.TransientError"/> to avoid
|
||||
/// incorrectly marking new/unknown Slack failures as non-retryable.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
/// <param name="error">The Slack error code string (e.g. <c>invalid_auth</c>, <c>rate_limited</c>).</param>
|
||||
/// <returns>The corresponding <see cref="IntegrationFailureCategory"/>.</returns>
|
||||
private static IntegrationFailureCategory ClassifySlackError(string error)
|
||||
{
|
||||
return error switch
|
||||
{
|
||||
result.Retryable = true;
|
||||
}
|
||||
|
||||
return result;
|
||||
"invalid_auth" => IntegrationFailureCategory.AuthenticationFailed,
|
||||
"access_denied" => IntegrationFailureCategory.AuthenticationFailed,
|
||||
"token_expired" => IntegrationFailureCategory.AuthenticationFailed,
|
||||
"token_revoked" => IntegrationFailureCategory.AuthenticationFailed,
|
||||
"account_inactive" => IntegrationFailureCategory.AuthenticationFailed,
|
||||
"not_authed" => IntegrationFailureCategory.AuthenticationFailed,
|
||||
"channel_not_found" => IntegrationFailureCategory.ConfigurationError,
|
||||
"is_archived" => IntegrationFailureCategory.ConfigurationError,
|
||||
"rate_limited" => IntegrationFailureCategory.RateLimited,
|
||||
"ratelimited" => IntegrationFailureCategory.RateLimited,
|
||||
"message_limit_exceeded" => IntegrationFailureCategory.RateLimited,
|
||||
"internal_error" => IntegrationFailureCategory.TransientError,
|
||||
"service_unavailable" => IntegrationFailureCategory.ServiceUnavailable,
|
||||
"fatal_error" => IntegrationFailureCategory.ServiceUnavailable,
|
||||
_ => IntegrationFailureCategory.TransientError
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
using System.Text.Json;
|
||||
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
using Microsoft.Rest;
|
||||
|
||||
namespace Bit.Core.Services;
|
||||
@@ -18,24 +19,48 @@ public class TeamsIntegrationHandler(
|
||||
channelId: message.Configuration.ChannelId
|
||||
);
|
||||
|
||||
return new IntegrationHandlerResult(success: true, message: message);
|
||||
return IntegrationHandlerResult.Succeed(message);
|
||||
}
|
||||
catch (HttpOperationException ex)
|
||||
{
|
||||
var result = new IntegrationHandlerResult(success: false, message: message);
|
||||
var statusCode = (int)ex.Response.StatusCode;
|
||||
result.Retryable = statusCode is 429 or >= 500 and < 600;
|
||||
result.FailureReason = ex.Message;
|
||||
|
||||
return result;
|
||||
var category = ClassifyHttpStatusCode(ex.Response.StatusCode);
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
category,
|
||||
ex.Message
|
||||
);
|
||||
}
|
||||
catch (ArgumentException ex)
|
||||
{
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.ConfigurationError,
|
||||
ex.Message
|
||||
);
|
||||
}
|
||||
catch (UriFormatException ex)
|
||||
{
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.ConfigurationError,
|
||||
ex.Message
|
||||
);
|
||||
}
|
||||
catch (JsonException ex)
|
||||
{
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.PermanentFailure,
|
||||
ex.Message
|
||||
);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
var result = new IntegrationHandlerResult(success: false, message: message);
|
||||
result.Retryable = false;
|
||||
result.FailureReason = ex.Message;
|
||||
|
||||
return result;
|
||||
return IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.TransientError,
|
||||
ex.Message
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,37 @@
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Models;
|
||||
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
|
||||
using Bit.Core.Billing.Premium.Queries;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
|
||||
namespace Bit.Core.Auth.UserFeatures.TwoFactorAuth;
|
||||
|
||||
public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFactorIsEnabledQuery
|
||||
public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
|
||||
{
|
||||
private readonly IUserRepository _userRepository = userRepository;
|
||||
private readonly IUserRepository _userRepository;
|
||||
private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery;
|
||||
private readonly IFeatureService _featureService;
|
||||
|
||||
public TwoFactorIsEnabledQuery(
|
||||
IUserRepository userRepository,
|
||||
IHasPremiumAccessQuery hasPremiumAccessQuery,
|
||||
IFeatureService featureService)
|
||||
{
|
||||
_userRepository = userRepository;
|
||||
_hasPremiumAccessQuery = hasPremiumAccessQuery;
|
||||
_featureService = featureService;
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<(Guid userId, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync(IEnumerable<Guid> userIds)
|
||||
{
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
|
||||
{
|
||||
return await TwoFactorIsEnabledVNextAsync(userIds);
|
||||
}
|
||||
|
||||
var result = new List<(Guid userId, bool hasTwoFactor)>();
|
||||
if (userIds == null || !userIds.Any())
|
||||
{
|
||||
@@ -36,6 +57,11 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto
|
||||
|
||||
public async Task<IEnumerable<(T user, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync<T>(IEnumerable<T> users) where T : ITwoFactorProvidersUser
|
||||
{
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
|
||||
{
|
||||
return await TwoFactorIsEnabledVNextAsync(users);
|
||||
}
|
||||
|
||||
var userIds = users
|
||||
.Select(u => u.GetUserId())
|
||||
.Where(u => u.HasValue)
|
||||
@@ -71,13 +97,134 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
|
||||
{
|
||||
var userEntity = user as User ?? await _userRepository.GetByIdAsync(userId.Value);
|
||||
if (userEntity == null)
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
return await TwoFactorIsEnabledVNextAsync(userEntity);
|
||||
}
|
||||
|
||||
return await TwoFactorEnabledAsync(
|
||||
user.GetTwoFactorProviders(),
|
||||
async () =>
|
||||
{
|
||||
var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value);
|
||||
return calcUser?.HasPremiumAccess ?? false;
|
||||
});
|
||||
user.GetTwoFactorProviders(),
|
||||
async () =>
|
||||
{
|
||||
var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value);
|
||||
return calcUser?.HasPremiumAccess ?? false;
|
||||
});
|
||||
}
|
||||
|
||||
private async Task<IEnumerable<(Guid userId, bool twoFactorIsEnabled)>> TwoFactorIsEnabledVNextAsync(IEnumerable<Guid> userIds)
|
||||
{
|
||||
var result = new List<(Guid userId, bool hasTwoFactor)>();
|
||||
if (userIds == null || !userIds.Any())
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
var users = await _userRepository.GetManyAsync([.. userIds]);
|
||||
|
||||
// Get enabled providers for each user
|
||||
var usersTwoFactorProvidersMap = users.ToDictionary(u => u.Id, GetEnabledTwoFactorProviders);
|
||||
|
||||
// Bulk fetch premium status only for users who need it (those with only premium providers)
|
||||
var userIdsNeedingPremium = usersTwoFactorProvidersMap
|
||||
.Where(kvp => kvp.Value.Any() && kvp.Value.All(TwoFactorProvider.RequiresPremium))
|
||||
.Select(kvp => kvp.Key)
|
||||
.ToList();
|
||||
|
||||
var premiumStatusMap = userIdsNeedingPremium.Count > 0
|
||||
? await _hasPremiumAccessQuery.HasPremiumAccessAsync(userIdsNeedingPremium)
|
||||
: new Dictionary<Guid, bool>();
|
||||
|
||||
foreach (var user in users)
|
||||
{
|
||||
var userTwoFactorProviders = usersTwoFactorProvidersMap[user.Id];
|
||||
|
||||
if (!userTwoFactorProviders.Any())
|
||||
{
|
||||
result.Add((user.Id, false));
|
||||
continue;
|
||||
}
|
||||
|
||||
// User has providers. If they're in the premium check map, verify premium status
|
||||
var twoFactorIsEnabled = !premiumStatusMap.TryGetValue(user.Id, out var hasPremium) || hasPremium;
|
||||
result.Add((user.Id, twoFactorIsEnabled));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private async Task<IEnumerable<(T user, bool twoFactorIsEnabled)>> TwoFactorIsEnabledVNextAsync<T>(IEnumerable<T> users)
|
||||
where T : ITwoFactorProvidersUser
|
||||
{
|
||||
var userIds = users
|
||||
.Select(u => u.GetUserId())
|
||||
.Where(u => u.HasValue)
|
||||
.Select(u => u.Value)
|
||||
.ToList();
|
||||
|
||||
var twoFactorResults = await TwoFactorIsEnabledVNextAsync(userIds);
|
||||
|
||||
var result = new List<(T user, bool twoFactorIsEnabled)>();
|
||||
|
||||
foreach (var user in users)
|
||||
{
|
||||
var userId = user.GetUserId();
|
||||
if (userId.HasValue)
|
||||
{
|
||||
var hasTwoFactor = twoFactorResults.FirstOrDefault(res => res.userId == userId.Value).twoFactorIsEnabled;
|
||||
result.Add((user, hasTwoFactor));
|
||||
}
|
||||
else
|
||||
{
|
||||
result.Add((user, false));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private async Task<bool> TwoFactorIsEnabledVNextAsync(User user)
|
||||
{
|
||||
var enabledProviders = GetEnabledTwoFactorProviders(user);
|
||||
|
||||
if (!enabledProviders.Any())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// If all providers require premium, check if user has premium access
|
||||
if (enabledProviders.All(TwoFactorProvider.RequiresPremium))
|
||||
{
|
||||
return await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id);
|
||||
}
|
||||
|
||||
// User has at least one non-premium provider
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets all enabled two-factor provider types for a user.
|
||||
/// </summary>
|
||||
/// <param name="user">user with two factor providers</param>
|
||||
/// <returns>list of enabled provider types</returns>
|
||||
private static IList<TwoFactorProviderType> GetEnabledTwoFactorProviders(User user)
|
||||
{
|
||||
var providers = user.GetTwoFactorProviders();
|
||||
|
||||
if (providers == null || providers.Count == 0)
|
||||
{
|
||||
return Array.Empty<TwoFactorProviderType>();
|
||||
}
|
||||
|
||||
// TODO: PM-21210: In practice we don't save disabled providers to the database, worth looking into.
|
||||
return (from provider in providers
|
||||
where provider.Value?.Enabled ?? false
|
||||
select provider.Key).ToList();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
||||
@@ -6,6 +6,7 @@ using Bit.Core.Billing.Organizations.Queries;
|
||||
using Bit.Core.Billing.Organizations.Services;
|
||||
using Bit.Core.Billing.Payment;
|
||||
using Bit.Core.Billing.Premium.Commands;
|
||||
using Bit.Core.Billing.Premium.Queries;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Billing.Services.Implementations;
|
||||
@@ -31,6 +32,7 @@ public static class ServiceCollectionExtensions
|
||||
services.AddPaymentOperations();
|
||||
services.AddOrganizationLicenseCommandsQueries();
|
||||
services.AddPremiumCommands();
|
||||
services.AddPremiumQueries();
|
||||
services.AddTransient<IGetOrganizationMetadataQuery, GetOrganizationMetadataQuery>();
|
||||
services.AddTransient<IGetOrganizationWarningsQuery, GetOrganizationWarningsQuery>();
|
||||
services.AddTransient<IRestartSubscriptionCommand, RestartSubscriptionCommand>();
|
||||
@@ -50,4 +52,9 @@ public static class ServiceCollectionExtensions
|
||||
services.AddScoped<ICreatePremiumSelfHostedSubscriptionCommand, CreatePremiumSelfHostedSubscriptionCommand>();
|
||||
services.AddTransient<IPreviewPremiumTaxCommand, PreviewPremiumTaxCommand>();
|
||||
}
|
||||
|
||||
private static void AddPremiumQueries(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<IHasPremiumAccessQuery, HasPremiumAccessQuery>();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,10 +61,6 @@ public interface IOrganizationBillingService
|
||||
/// Updates the organization name and email on the Stripe customer entry.
|
||||
/// This only updates Stripe, not the Bitwarden database.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// The caller should ensure that the organization has a GatewayCustomerId before calling this method.
|
||||
/// </remarks>
|
||||
/// <param name="organization">The organization to update in Stripe.</param>
|
||||
/// <exception cref="BillingException">Thrown when the organization does not have a GatewayCustomerId.</exception>
|
||||
Task UpdateOrganizationNameAndEmail(Organization organization);
|
||||
}
|
||||
|
||||
@@ -177,13 +177,25 @@ public class OrganizationBillingService(
|
||||
|
||||
public async Task UpdateOrganizationNameAndEmail(Organization organization)
|
||||
{
|
||||
if (organization.GatewayCustomerId is null)
|
||||
if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId))
|
||||
{
|
||||
throw new BillingException("Cannot update an organization in Stripe without a GatewayCustomerId.");
|
||||
logger.LogWarning(
|
||||
"Organization ({OrganizationId}) has no Stripe customer to update",
|
||||
organization.Id);
|
||||
return;
|
||||
}
|
||||
|
||||
var newDisplayName = organization.DisplayName();
|
||||
|
||||
// Organization.DisplayName() can return null - handle gracefully
|
||||
if (string.IsNullOrWhiteSpace(newDisplayName))
|
||||
{
|
||||
logger.LogWarning(
|
||||
"Organization ({OrganizationId}) has no name to update in Stripe",
|
||||
organization.Id);
|
||||
return;
|
||||
}
|
||||
|
||||
await stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId,
|
||||
new CustomerUpdateOptions
|
||||
{
|
||||
@@ -196,9 +208,7 @@ public class OrganizationBillingService(
|
||||
new CustomerInvoiceSettingsCustomFieldOptions
|
||||
{
|
||||
Name = organization.SubscriberType(),
|
||||
Value = newDisplayName.Length <= 30
|
||||
? newDisplayName
|
||||
: newDisplayName[..30]
|
||||
Value = newDisplayName
|
||||
}]
|
||||
},
|
||||
});
|
||||
|
||||
29
src/Core/Billing/Premium/Models/UserPremiumAccess.cs
Normal file
29
src/Core/Billing/Premium/Models/UserPremiumAccess.cs
Normal file
@@ -0,0 +1,29 @@
|
||||
namespace Bit.Core.Billing.Premium.Models;
|
||||
|
||||
/// <summary>
|
||||
/// Represents user premium access status from personal subscriptions and organization memberships.
|
||||
/// </summary>
|
||||
public class UserPremiumAccess
|
||||
{
|
||||
/// <summary>
|
||||
/// The unique identifier for the user.
|
||||
/// </summary>
|
||||
public Guid Id { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Indicates whether the user has a personal premium subscription.
|
||||
/// This does NOT include premium access from organizations.
|
||||
/// </summary>
|
||||
public bool PersonalPremium { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Indicates whether the user has premium access through any organization membership.
|
||||
/// This is true if the user is a member of at least one enabled organization that grants premium access to users.
|
||||
/// </summary>
|
||||
public bool OrganizationPremium { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Indicates whether the user has premium access from any source (personal subscription or organization).
|
||||
/// </summary>
|
||||
public bool HasPremiumAccess => PersonalPremium || OrganizationPremium;
|
||||
}
|
||||
49
src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs
Normal file
49
src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs
Normal file
@@ -0,0 +1,49 @@
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Repositories;
|
||||
|
||||
namespace Bit.Core.Billing.Premium.Queries;
|
||||
|
||||
public class HasPremiumAccessQuery : IHasPremiumAccessQuery
|
||||
{
|
||||
private readonly IUserRepository _userRepository;
|
||||
|
||||
public HasPremiumAccessQuery(IUserRepository userRepository)
|
||||
{
|
||||
_userRepository = userRepository;
|
||||
}
|
||||
|
||||
public async Task<bool> HasPremiumAccessAsync(Guid userId)
|
||||
{
|
||||
var user = await _userRepository.GetPremiumAccessAsync(userId);
|
||||
if (user == null)
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
return user.HasPremiumAccess;
|
||||
}
|
||||
|
||||
public async Task<Dictionary<Guid, bool>> HasPremiumAccessAsync(IEnumerable<Guid> userIds)
|
||||
{
|
||||
var distinctUserIds = userIds.Distinct().ToList();
|
||||
var usersWithPremium = await _userRepository.GetPremiumAccessByIdsAsync(distinctUserIds);
|
||||
|
||||
if (usersWithPremium.Count() != distinctUserIds.Count)
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
return usersWithPremium.ToDictionary(u => u.Id, u => u.HasPremiumAccess);
|
||||
}
|
||||
|
||||
public async Task<bool> HasPremiumFromOrganizationAsync(Guid userId)
|
||||
{
|
||||
var user = await _userRepository.GetPremiumAccessAsync(userId);
|
||||
if (user == null)
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
return user.OrganizationPremium;
|
||||
}
|
||||
}
|
||||
30
src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs
Normal file
30
src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs
Normal file
@@ -0,0 +1,30 @@
|
||||
namespace Bit.Core.Billing.Premium.Queries;
|
||||
|
||||
/// <summary>
|
||||
/// Centralized query for checking if users have premium access through personal subscriptions or organizations.
|
||||
/// Note: Different from User.Premium which only checks personal subscriptions.
|
||||
/// </summary>
|
||||
public interface IHasPremiumAccessQuery
|
||||
{
|
||||
/// <summary>
|
||||
/// Checks if a user has premium access (personal or organization).
|
||||
/// </summary>
|
||||
/// <param name="userId">The user ID to check</param>
|
||||
/// <returns>True if user can access premium features</returns>
|
||||
Task<bool> HasPremiumAccessAsync(Guid userId);
|
||||
|
||||
/// <summary>
|
||||
/// Checks premium access for multiple users.
|
||||
/// </summary>
|
||||
/// <param name="userIds">The user IDs to check</param>
|
||||
/// <returns>Dictionary mapping user IDs to their premium access status</returns>
|
||||
Task<Dictionary<Guid, bool>> HasPremiumAccessAsync(IEnumerable<Guid> userIds);
|
||||
|
||||
/// <summary>
|
||||
/// Checks if a user belongs to any organization that grants premium (enabled org with UsersGetPremium).
|
||||
/// Returns true regardless of personal subscription. Useful for UI decisions like showing subscription options.
|
||||
/// </summary>
|
||||
/// <param name="userId">The user ID to check</param>
|
||||
/// <returns>True if user is in any organization that grants premium</returns>
|
||||
Task<bool> HasPremiumFromOrganizationAsync(Guid userId);
|
||||
}
|
||||
@@ -113,4 +113,11 @@ public interface IProviderBillingService
|
||||
TaxInformation taxInformation);
|
||||
|
||||
Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command);
|
||||
|
||||
/// <summary>
|
||||
/// Updates the provider name and email on the Stripe customer entry.
|
||||
/// This only updates Stripe, not the Bitwarden database.
|
||||
/// </summary>
|
||||
/// <param name="provider">The provider to update in Stripe.</param>
|
||||
Task UpdateProviderNameAndEmail(Provider provider);
|
||||
}
|
||||
|
||||
@@ -143,6 +143,7 @@ public static class FeatureFlagKeys
|
||||
public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration";
|
||||
public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud";
|
||||
public const string BulkRevokeUsersV2 = "pm-28456-bulk-revoke-users-v2";
|
||||
public const string PremiumAccessQuery = "pm-21411-premium-access-query";
|
||||
|
||||
/* Architecture */
|
||||
public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1";
|
||||
|
||||
@@ -70,6 +70,11 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
|
||||
/// The security state is a signed object attesting to the version of the user's account.
|
||||
/// </summary>
|
||||
public string? SecurityState { get; set; }
|
||||
/// <summary>
|
||||
/// Indicates whether the user has a personal premium subscription.
|
||||
/// Does not include premium access from organizations -
|
||||
/// do not use this to check whether the user can access premium features.
|
||||
/// </summary>
|
||||
public bool Premium { get; set; }
|
||||
public DateTime? PremiumExpirationDate { get; set; }
|
||||
public DateTime? RenewalReminderDate { get; set; }
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Billing.Premium.Models;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.KeyManagement.UserKey;
|
||||
using Bit.Core.Models.Data;
|
||||
@@ -24,6 +25,7 @@ public interface IUserRepository : IRepository<User, Guid>
|
||||
/// Retrieves the data for the requested user IDs and includes an additional property indicating
|
||||
/// whether the user has premium access directly or through an organization.
|
||||
/// </summary>
|
||||
[Obsolete("Use GetPremiumAccessByIdsAsync instead. This method will be removed in a future version.")]
|
||||
Task<IEnumerable<UserWithCalculatedPremium>> GetManyWithCalculatedPremiumAsync(IEnumerable<Guid> ids);
|
||||
/// <summary>
|
||||
/// Retrieves the data for the requested user ID and includes additional property indicating
|
||||
@@ -34,8 +36,23 @@ public interface IUserRepository : IRepository<User, Guid>
|
||||
/// </summary>
|
||||
/// <param name="userId">The user ID to retrieve data for.</param>
|
||||
/// <returns>User data with calculated premium access; null if nothing is found</returns>
|
||||
[Obsolete("Use GetPremiumAccessAsync instead. This method will be removed in a future version.")]
|
||||
Task<UserWithCalculatedPremium?> GetCalculatedPremiumAsync(Guid userId);
|
||||
/// <summary>
|
||||
/// Retrieves premium access status for multiple users.
|
||||
/// For internal use - consumers should use IHasPremiumAccessQuery instead.
|
||||
/// </summary>
|
||||
/// <param name="ids">The user IDs to check</param>
|
||||
/// <returns>Collection of UserPremiumAccess objects containing premium status information</returns>
|
||||
Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids);
|
||||
/// <summary>
|
||||
/// Retrieves premium access status for a single user.
|
||||
/// For internal use - consumers should use IHasPremiumAccessQuery instead.
|
||||
/// </summary>
|
||||
/// <param name="userId">The user ID to check</param>
|
||||
/// <returns>UserPremiumAccess object containing premium status information, or null if user not found</returns>
|
||||
Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId);
|
||||
/// <summary>
|
||||
/// Sets a new user key and updates all encrypted data.
|
||||
/// <para>Warning: Any user key encrypted data not included will be lost.</para>
|
||||
/// </summary>
|
||||
|
||||
@@ -60,7 +60,7 @@ public interface IUserService
|
||||
/// <summary>
|
||||
/// Checks if the user has access to premium features, either through a personal subscription or through an organization.
|
||||
///
|
||||
/// This is the preferred way to definitively know if a user has access to premium features.
|
||||
/// This is the preferred way to definitively know if a user has access to premium features when you already have the User object.
|
||||
/// </summary>
|
||||
/// <param name="user">user being acted on</param>
|
||||
/// <returns>true if they can access premium; false otherwise.</returns>
|
||||
@@ -74,6 +74,7 @@ public interface IUserService
|
||||
/// </summary>
|
||||
/// <param name="user">user being acted on</param>
|
||||
/// <returns>true if they can access premium because of organization membership; false otherwise.</returns>
|
||||
[Obsolete("Use IHasPremiumAccessQuery.HasPremiumFromOrganizationAsync instead. This method will be removed in a future version.")]
|
||||
Task<bool> HasPremiumFromOrganization(User user);
|
||||
Task<string> GenerateSignInTokenAsync(User user, string purpose);
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Models.Business;
|
||||
using Bit.Core.Billing.Models.Sales;
|
||||
using Bit.Core.Billing.Premium.Queries;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
@@ -73,6 +74,7 @@ public class UserService : UserManager<User>, IUserService
|
||||
private readonly IDistributedCache _distributedCache;
|
||||
private readonly IPolicyRequirementQuery _policyRequirementQuery;
|
||||
private readonly IPricingClient _pricingClient;
|
||||
private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery;
|
||||
|
||||
public UserService(
|
||||
IUserRepository userRepository,
|
||||
@@ -108,7 +110,8 @@ public class UserService : UserManager<User>, IUserService
|
||||
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
|
||||
IDistributedCache distributedCache,
|
||||
IPolicyRequirementQuery policyRequirementQuery,
|
||||
IPricingClient pricingClient)
|
||||
IPricingClient pricingClient,
|
||||
IHasPremiumAccessQuery hasPremiumAccessQuery)
|
||||
: base(
|
||||
store,
|
||||
optionsAccessor,
|
||||
@@ -149,6 +152,7 @@ public class UserService : UserManager<User>, IUserService
|
||||
_distributedCache = distributedCache;
|
||||
_policyRequirementQuery = policyRequirementQuery;
|
||||
_pricingClient = pricingClient;
|
||||
_hasPremiumAccessQuery = hasPremiumAccessQuery;
|
||||
}
|
||||
|
||||
public Guid? GetProperUserId(ClaimsPrincipal principal)
|
||||
@@ -1112,6 +1116,11 @@ public class UserService : UserManager<User>, IUserService
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
|
||||
{
|
||||
return user.Premium || await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value);
|
||||
}
|
||||
|
||||
return user.Premium || await HasPremiumFromOrganization(user);
|
||||
}
|
||||
|
||||
@@ -1123,6 +1132,11 @@ public class UserService : UserManager<User>, IUserService
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
|
||||
{
|
||||
return await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value);
|
||||
}
|
||||
|
||||
// orgUsers in the Invited status are not associated with a userId yet, so this will get
|
||||
// orgUsers in Accepted and Confirmed states only
|
||||
var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value);
|
||||
|
||||
@@ -703,7 +703,6 @@ public abstract class BaseRequestValidator<T> where T : class
|
||||
|
||||
customResponse.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user));
|
||||
customResponse.Add("ForcePasswordReset", user.ForcePasswordReset);
|
||||
customResponse.Add("ResetMasterPassword", string.IsNullOrWhiteSpace(user.MasterPassword));
|
||||
customResponse.Add("Kdf", (byte)user.Kdf);
|
||||
customResponse.Add("KdfIterations", user.KdfIterations);
|
||||
customResponse.Add("KdfMemory", user.KdfMemory);
|
||||
|
||||
@@ -4,7 +4,6 @@ using Bit.Core;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
|
||||
using Bit.Core.AdminConsole.Services;
|
||||
using Bit.Core.Auth.IdentityServer;
|
||||
using Bit.Core.Auth.Models.Api.Response;
|
||||
using Bit.Core.Auth.Repositories;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Entities;
|
||||
@@ -154,23 +153,7 @@ public class CustomTokenRequestValidator : BaseRequestValidator<CustomTokenReque
|
||||
{
|
||||
// KeyConnectorUrl is configured in the CLI client, we just need to tell the client to use it
|
||||
context.Result.CustomResponse["ApiUseKeyConnector"] = true;
|
||||
context.Result.CustomResponse["ResetMasterPassword"] = false;
|
||||
}
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
// Key connector data should have already been set in the decryption options
|
||||
// for backwards compatibility we set them this way too. We can eventually get rid of this once we clean up
|
||||
// ResetMasterPassword
|
||||
if (!context.Result.CustomResponse.TryGetValue("UserDecryptionOptions", out var userDecryptionOptionsObj) ||
|
||||
userDecryptionOptionsObj is not UserDecryptionOptions userDecryptionOptions)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
if (userDecryptionOptions is { KeyConnectorOption: { } })
|
||||
{
|
||||
context.Result.CustomResponse["ResetMasterPassword"] = false;
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
|
||||
@@ -226,7 +226,6 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
||||
{
|
||||
obj.SetNewId();
|
||||
|
||||
|
||||
var objWithGroupsAndUsers = JsonSerializer.Deserialize<CollectionWithGroupsAndUsers>(JsonSerializer.Serialize(obj))!;
|
||||
|
||||
objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
|
||||
@@ -243,18 +242,52 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
||||
|
||||
public async Task ReplaceAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
|
||||
{
|
||||
var objWithGroupsAndUsers = JsonSerializer.Deserialize<CollectionWithGroupsAndUsers>(JsonSerializer.Serialize(obj))!;
|
||||
|
||||
objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
|
||||
objWithGroupsAndUsers.Users = users != null ? users.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
|
||||
|
||||
using (var connection = new SqlConnection(ConnectionString))
|
||||
await using var connection = new SqlConnection(ConnectionString);
|
||||
await connection.OpenAsync();
|
||||
await using var transaction = await connection.BeginTransactionAsync();
|
||||
try
|
||||
{
|
||||
var results = await connection.ExecuteAsync(
|
||||
$"[{Schema}].[Collection_UpdateWithGroupsAndUsers]",
|
||||
objWithGroupsAndUsers,
|
||||
commandType: CommandType.StoredProcedure);
|
||||
if (groups == null && users == null)
|
||||
{
|
||||
await connection.ExecuteAsync(
|
||||
$"[{Schema}].[Collection_Update]",
|
||||
obj,
|
||||
commandType: CommandType.StoredProcedure,
|
||||
transaction: transaction);
|
||||
}
|
||||
else if (groups != null && users == null)
|
||||
{
|
||||
await connection.ExecuteAsync(
|
||||
$"[{Schema}].[Collection_UpdateWithGroups]",
|
||||
new CollectionWithGroups(obj, groups),
|
||||
commandType: CommandType.StoredProcedure,
|
||||
transaction: transaction);
|
||||
}
|
||||
else if (groups == null && users != null)
|
||||
{
|
||||
await connection.ExecuteAsync(
|
||||
$"[{Schema}].[Collection_UpdateWithUsers]",
|
||||
new CollectionWithUsers(obj, users),
|
||||
commandType: CommandType.StoredProcedure,
|
||||
transaction: transaction);
|
||||
}
|
||||
else if (groups != null && users != null)
|
||||
{
|
||||
await connection.ExecuteAsync(
|
||||
$"[{Schema}].[Collection_UpdateWithGroupsAndUsers]",
|
||||
new CollectionWithGroupsAndUsers(obj, groups, users),
|
||||
commandType: CommandType.StoredProcedure,
|
||||
transaction: transaction);
|
||||
}
|
||||
|
||||
await transaction.CommitAsync();
|
||||
}
|
||||
catch
|
||||
{
|
||||
await transaction.RollbackAsync();
|
||||
throw;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public async Task DeleteManyAsync(IEnumerable<Guid> collectionIds)
|
||||
@@ -424,9 +457,70 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
||||
|
||||
public class CollectionWithGroupsAndUsers : Collection
|
||||
{
|
||||
public CollectionWithGroupsAndUsers() { }
|
||||
|
||||
public CollectionWithGroupsAndUsers(Collection collection,
|
||||
IEnumerable<CollectionAccessSelection> groups,
|
||||
IEnumerable<CollectionAccessSelection> users)
|
||||
{
|
||||
Id = collection.Id;
|
||||
Name = collection.Name;
|
||||
OrganizationId = collection.OrganizationId;
|
||||
CreationDate = collection.CreationDate;
|
||||
RevisionDate = collection.RevisionDate;
|
||||
Type = collection.Type;
|
||||
ExternalId = collection.ExternalId;
|
||||
DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail;
|
||||
Groups = groups.ToArrayTVP();
|
||||
Users = users.ToArrayTVP();
|
||||
}
|
||||
|
||||
[DisallowNull]
|
||||
public DataTable? Groups { get; set; }
|
||||
[DisallowNull]
|
||||
public DataTable? Users { get; set; }
|
||||
}
|
||||
|
||||
public class CollectionWithGroups : Collection
|
||||
{
|
||||
public CollectionWithGroups() { }
|
||||
|
||||
public CollectionWithGroups(Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
||||
{
|
||||
Id = collection.Id;
|
||||
Name = collection.Name;
|
||||
OrganizationId = collection.OrganizationId;
|
||||
CreationDate = collection.CreationDate;
|
||||
RevisionDate = collection.RevisionDate;
|
||||
Type = collection.Type;
|
||||
ExternalId = collection.ExternalId;
|
||||
DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail;
|
||||
Groups = groups.ToArrayTVP();
|
||||
}
|
||||
|
||||
[DisallowNull]
|
||||
public DataTable? Groups { get; set; }
|
||||
}
|
||||
|
||||
public class CollectionWithUsers : Collection
|
||||
{
|
||||
public CollectionWithUsers() { }
|
||||
|
||||
public CollectionWithUsers(Collection collection, IEnumerable<CollectionAccessSelection> users)
|
||||
{
|
||||
|
||||
Id = collection.Id;
|
||||
Name = collection.Name;
|
||||
OrganizationId = collection.OrganizationId;
|
||||
CreationDate = collection.CreationDate;
|
||||
RevisionDate = collection.RevisionDate;
|
||||
Type = collection.Type;
|
||||
ExternalId = collection.ExternalId;
|
||||
DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail;
|
||||
Users = users.ToArrayTVP();
|
||||
}
|
||||
|
||||
[DisallowNull]
|
||||
public DataTable? Users { get; set; }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
using System.Data;
|
||||
using System.Text.Json;
|
||||
using Bit.Core;
|
||||
using Bit.Core.Billing.Premium.Models;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.KeyManagement.UserKey;
|
||||
@@ -381,6 +382,25 @@ public class UserRepository : Repository<User, Guid>, IUserRepository
|
||||
return result.SingleOrDefault();
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids)
|
||||
{
|
||||
using (var connection = new SqlConnection(ReadOnlyConnectionString))
|
||||
{
|
||||
var results = await connection.QueryAsync<UserPremiumAccess>(
|
||||
$"[{Schema}].[{Table}_ReadPremiumAccessByIds]",
|
||||
new { Ids = ids.ToGuidIdArrayTVP() },
|
||||
commandType: CommandType.StoredProcedure);
|
||||
|
||||
return results.ToList();
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId)
|
||||
{
|
||||
var result = await GetPremiumAccessByIdsAsync([userId]);
|
||||
return result.SingleOrDefault();
|
||||
}
|
||||
|
||||
private async Task ProtectDataAndSaveAsync(User user, Func<Task> saveTask)
|
||||
{
|
||||
if (user == null)
|
||||
|
||||
@@ -18,7 +18,7 @@ public class OrganizationEntityTypeConfiguration : IEntityTypeConfiguration<Orga
|
||||
|
||||
NpgsqlIndexBuilderExtensions.IncludeProperties(
|
||||
builder.HasIndex(o => new { o.Id, o.Enabled }),
|
||||
o => o.UseTotp);
|
||||
o => new { o.UseTotp, o.UsersGetPremium });
|
||||
|
||||
builder.ToTable(nameof(Organization));
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
using AutoMapper;
|
||||
using Bit.Core.Billing.Premium.Models;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.KeyManagement.UserKey;
|
||||
using Bit.Core.Models.Data;
|
||||
@@ -350,6 +351,36 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
|
||||
return result.FirstOrDefault();
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids)
|
||||
{
|
||||
using (var scope = ServiceScopeFactory.CreateScope())
|
||||
{
|
||||
var dbContext = GetDatabaseContext(scope);
|
||||
|
||||
var users = await dbContext.Users
|
||||
.Where(x => ids.Contains(x.Id))
|
||||
.Include(u => u.OrganizationUsers)
|
||||
.ThenInclude(ou => ou.Organization)
|
||||
.ToListAsync();
|
||||
|
||||
return users.Select(user => new UserPremiumAccess
|
||||
{
|
||||
Id = user.Id,
|
||||
PersonalPremium = user.Premium,
|
||||
OrganizationPremium = user.OrganizationUsers
|
||||
.Any(ou => ou.Organization != null &&
|
||||
ou.Organization.Enabled == true &&
|
||||
ou.Organization.UsersGetPremium == true)
|
||||
}).ToList();
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId)
|
||||
{
|
||||
var result = await GetPremiumAccessByIdsAsync([userId]);
|
||||
return result.FirstOrDefault();
|
||||
}
|
||||
|
||||
public override async Task DeleteAsync(Core.Entities.User user)
|
||||
{
|
||||
using (var scope = ServiceScopeFactory.CreateScope())
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
CREATE PROCEDURE [dbo].[Collection_UpdateWithGroups]
|
||||
@Id UNIQUEIDENTIFIER,
|
||||
@OrganizationId UNIQUEIDENTIFIER,
|
||||
@Name VARCHAR(MAX),
|
||||
@ExternalId NVARCHAR(300),
|
||||
@CreationDate DATETIME2(7),
|
||||
@RevisionDate DATETIME2(7),
|
||||
@Groups AS [dbo].[CollectionAccessSelectionType] READONLY,
|
||||
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
|
||||
@Type TINYINT = 0
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
|
||||
|
||||
-- Groups
|
||||
-- Delete groups that are no longer in source
|
||||
DELETE
|
||||
cg
|
||||
FROM
|
||||
[dbo].[CollectionGroup] cg
|
||||
LEFT JOIN
|
||||
@Groups g ON cg.GroupId = g.Id
|
||||
WHERE
|
||||
cg.CollectionId = @Id
|
||||
AND g.Id IS NULL;
|
||||
|
||||
-- Update existing groups
|
||||
UPDATE
|
||||
cg
|
||||
SET
|
||||
cg.ReadOnly = g.ReadOnly,
|
||||
cg.HidePasswords = g.HidePasswords,
|
||||
cg.Manage = g.Manage
|
||||
FROM
|
||||
[dbo].[CollectionGroup] cg
|
||||
INNER JOIN
|
||||
@Groups g ON cg.GroupId = g.Id
|
||||
WHERE
|
||||
cg.CollectionId = @Id
|
||||
AND (
|
||||
cg.ReadOnly != g.ReadOnly
|
||||
OR cg.HidePasswords != g.HidePasswords
|
||||
OR cg.Manage != g.Manage
|
||||
);
|
||||
|
||||
-- Insert new groups
|
||||
INSERT INTO [dbo].[CollectionGroup]
|
||||
(
|
||||
[CollectionId],
|
||||
[GroupId],
|
||||
[ReadOnly],
|
||||
[HidePasswords],
|
||||
[Manage]
|
||||
)
|
||||
SELECT
|
||||
@Id,
|
||||
g.Id,
|
||||
g.ReadOnly,
|
||||
g.HidePasswords,
|
||||
g.Manage
|
||||
FROM
|
||||
@Groups g
|
||||
INNER JOIN
|
||||
[dbo].[Group] grp ON grp.Id = g.Id
|
||||
LEFT JOIN
|
||||
[dbo].[CollectionGroup] cg ON cg.CollectionId = @Id AND cg.GroupId = g.Id
|
||||
WHERE
|
||||
grp.OrganizationId = @OrganizationId
|
||||
AND cg.CollectionId IS NULL;
|
||||
|
||||
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
|
||||
END
|
||||
74
src/Sql/dbo/Stored Procedures/Collection_UpdateWithUsers.sql
Normal file
74
src/Sql/dbo/Stored Procedures/Collection_UpdateWithUsers.sql
Normal file
@@ -0,0 +1,74 @@
|
||||
CREATE PROCEDURE [dbo].[Collection_UpdateWithUsers]
|
||||
@Id UNIQUEIDENTIFIER,
|
||||
@OrganizationId UNIQUEIDENTIFIER,
|
||||
@Name VARCHAR(MAX),
|
||||
@ExternalId NVARCHAR(300),
|
||||
@CreationDate DATETIME2(7),
|
||||
@RevisionDate DATETIME2(7),
|
||||
@Users AS [dbo].[CollectionAccessSelectionType] READONLY,
|
||||
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
|
||||
@Type TINYINT = 0
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
|
||||
|
||||
-- Users
|
||||
-- Delete users that are no longer in source
|
||||
DELETE
|
||||
cu
|
||||
FROM
|
||||
[dbo].[CollectionUser] cu
|
||||
LEFT JOIN
|
||||
@Users u ON cu.OrganizationUserId = u.Id
|
||||
WHERE
|
||||
cu.CollectionId = @Id
|
||||
AND u.Id IS NULL;
|
||||
|
||||
-- Update existing users
|
||||
UPDATE
|
||||
cu
|
||||
SET
|
||||
cu.ReadOnly = u.ReadOnly,
|
||||
cu.HidePasswords = u.HidePasswords,
|
||||
cu.Manage = u.Manage
|
||||
FROM
|
||||
[dbo].[CollectionUser] cu
|
||||
INNER JOIN
|
||||
@Users u ON cu.OrganizationUserId = u.Id
|
||||
WHERE
|
||||
cu.CollectionId = @Id
|
||||
AND (
|
||||
cu.ReadOnly != u.ReadOnly
|
||||
OR cu.HidePasswords != u.HidePasswords
|
||||
OR cu.Manage != u.Manage
|
||||
);
|
||||
|
||||
-- Insert new users
|
||||
INSERT INTO [dbo].[CollectionUser]
|
||||
(
|
||||
[CollectionId],
|
||||
[OrganizationUserId],
|
||||
[ReadOnly],
|
||||
[HidePasswords],
|
||||
[Manage]
|
||||
)
|
||||
SELECT
|
||||
@Id,
|
||||
u.Id,
|
||||
u.ReadOnly,
|
||||
u.HidePasswords,
|
||||
u.Manage
|
||||
FROM
|
||||
@Users u
|
||||
INNER JOIN
|
||||
[dbo].[OrganizationUser] ou ON ou.Id = u.Id
|
||||
LEFT JOIN
|
||||
[dbo].[CollectionUser] cu ON cu.CollectionId = @Id AND cu.OrganizationUserId = u.Id
|
||||
WHERE
|
||||
ou.OrganizationId = @OrganizationId
|
||||
AND cu.CollectionId IS NULL;
|
||||
|
||||
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
|
||||
END
|
||||
@@ -0,0 +1,15 @@
|
||||
CREATE PROCEDURE [dbo].[User_ReadPremiumAccessByIds]
|
||||
@Ids [dbo].[GuidIdArray] READONLY
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
SELECT
|
||||
UPA.[Id],
|
||||
UPA.[PersonalPremium],
|
||||
UPA.[OrganizationPremium]
|
||||
FROM
|
||||
[dbo].[UserPremiumAccessView] UPA
|
||||
WHERE
|
||||
UPA.[Id] IN (SELECT [Id] FROM @Ids)
|
||||
END
|
||||
@@ -69,7 +69,7 @@ CREATE TABLE [dbo].[Organization] (
|
||||
GO
|
||||
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
|
||||
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
|
||||
INCLUDE ([UseTotp]);
|
||||
INCLUDE ([UseTotp], [UsersGetPremium]);
|
||||
|
||||
GO
|
||||
CREATE UNIQUE NONCLUSTERED INDEX [IX_Organization_Identifier]
|
||||
|
||||
21
src/Sql/dbo/Views/UserPremiumAccessView.sql
Normal file
21
src/Sql/dbo/Views/UserPremiumAccessView.sql
Normal file
@@ -0,0 +1,21 @@
|
||||
CREATE VIEW [dbo].[UserPremiumAccessView]
|
||||
AS
|
||||
SELECT
|
||||
U.[Id],
|
||||
U.[Premium] AS [PersonalPremium],
|
||||
CAST(
|
||||
MAX(CASE
|
||||
WHEN O.[Id] IS NOT NULL THEN 1
|
||||
ELSE 0
|
||||
END) AS BIT
|
||||
) AS [OrganizationPremium]
|
||||
FROM
|
||||
[dbo].[User] U
|
||||
LEFT JOIN
|
||||
[dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id]
|
||||
LEFT JOIN
|
||||
[dbo].[Organization] O ON O.[Id] = OU.[OrganizationId]
|
||||
AND O.[UsersGetPremium] = 1
|
||||
AND O.[Enabled] = 1
|
||||
GROUP BY
|
||||
U.[Id], U.[Premium];
|
||||
@@ -0,0 +1,117 @@
|
||||
using Bit.Api.AdminConsole.Public.Models.Request;
|
||||
using Bit.Api.IntegrationTest.Factories;
|
||||
using Bit.Api.IntegrationTest.Helpers;
|
||||
using Bit.Api.Models.Public.Request;
|
||||
using Bit.Api.Models.Public.Response;
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Models.Data;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Xunit;
|
||||
|
||||
namespace Bit.Api.IntegrationTest.Controllers.Public;
|
||||
|
||||
public class CollectionsControllerTests : IClassFixture<ApiApplicationFactory>, IAsyncLifetime
|
||||
{
|
||||
|
||||
private readonly HttpClient _client;
|
||||
private readonly ApiApplicationFactory _factory;
|
||||
private readonly LoginHelper _loginHelper;
|
||||
|
||||
private string _ownerEmail = null!;
|
||||
private Organization _organization = null!;
|
||||
|
||||
public CollectionsControllerTests(ApiApplicationFactory factory)
|
||||
{
|
||||
_factory = factory;
|
||||
_factory.SubstituteService<IPushNotificationService>(_ => { });
|
||||
_factory.SubstituteService<IFeatureService>(_ => { });
|
||||
_client = factory.CreateClient();
|
||||
_loginHelper = new LoginHelper(_factory, _client);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com";
|
||||
await _factory.LoginWithNewAccount(_ownerEmail);
|
||||
|
||||
(_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory,
|
||||
plan: PlanType.EnterpriseAnnually,
|
||||
ownerEmail: _ownerEmail,
|
||||
passwordManagerSeats: 10,
|
||||
paymentMethod: PaymentMethodType.Card);
|
||||
|
||||
await _loginHelper.LoginWithOrganizationApiKeyAsync(_organization.Id);
|
||||
}
|
||||
|
||||
public Task DisposeAsync()
|
||||
{
|
||||
_client.Dispose();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CreateCollectionWithMultipleUsersAndVariedPermissions_Success()
|
||||
{
|
||||
// Arrange
|
||||
_organization.AllowAdminAccessToAllCollectionItems = true;
|
||||
await _factory.GetService<IOrganizationRepository>().UpsertAsync(_organization);
|
||||
|
||||
var groupRepository = _factory.GetService<IGroupRepository>();
|
||||
var group = await groupRepository.CreateAsync(new Group
|
||||
{
|
||||
OrganizationId = _organization.Id,
|
||||
Name = "CollectionControllerTests.CreateCollectionWithMultipleUsersAndVariedPermissions_Success",
|
||||
ExternalId = $"CollectionControllerTests.CreateCollectionWithMultipleUsersAndVariedPermissions_Success{Guid.NewGuid()}",
|
||||
});
|
||||
|
||||
var (_, user) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
|
||||
_factory,
|
||||
_organization.Id,
|
||||
OrganizationUserType.User);
|
||||
|
||||
var collection = await OrganizationTestHelpers.CreateCollectionAsync(
|
||||
_factory,
|
||||
_organization.Id,
|
||||
"Shared Collection with a group",
|
||||
externalId: "shared-collection-with-group",
|
||||
groups:
|
||||
[
|
||||
new CollectionAccessSelection { Id = group.Id, ReadOnly = false, HidePasswords = false, Manage = true }
|
||||
],
|
||||
users:
|
||||
[
|
||||
new CollectionAccessSelection { Id = user.Id, ReadOnly = false, HidePasswords = false, Manage = true }
|
||||
]);
|
||||
|
||||
var getCollectionsResponse = await _client.GetFromJsonAsync<ListResponseModel<CollectionResponseModel>>("public/collections");
|
||||
var getCollectionResponse = await _client.GetFromJsonAsync<CollectionResponseModel>($"public/collections/{collection.Id}");
|
||||
|
||||
var firstCollection = getCollectionsResponse.Data.First(x => x.ExternalId == "shared-collection-with-group");
|
||||
|
||||
var update = new CollectionUpdateRequestModel
|
||||
{
|
||||
ExternalId = firstCollection.ExternalId,
|
||||
Groups = firstCollection.Groups?.Select(x => new AssociationWithPermissionsRequestModel
|
||||
{
|
||||
Id = x.Id,
|
||||
ReadOnly = x.ReadOnly,
|
||||
HidePasswords = x.HidePasswords,
|
||||
Manage = x.Manage
|
||||
}),
|
||||
};
|
||||
|
||||
await _client.PutAsJsonAsync($"public/collections/{firstCollection.Id}", update);
|
||||
|
||||
var result = await _factory.GetService<ICollectionRepository>()
|
||||
.GetByIdWithAccessAsync(firstCollection.Id);
|
||||
|
||||
Assert.NotNull(result);
|
||||
Assert.NotEmpty(result.Item2.Groups);
|
||||
Assert.NotEmpty(result.Item2.Users);
|
||||
}
|
||||
}
|
||||
@@ -159,14 +159,16 @@ public static class OrganizationTestHelpers
|
||||
Guid organizationId,
|
||||
string name,
|
||||
IEnumerable<CollectionAccessSelection>? users = null,
|
||||
IEnumerable<CollectionAccessSelection>? groups = null)
|
||||
IEnumerable<CollectionAccessSelection>? groups = null,
|
||||
string? externalId = null)
|
||||
{
|
||||
var collectionRepository = factory.GetService<ICollectionRepository>();
|
||||
var collection = new Collection
|
||||
{
|
||||
OrganizationId = organizationId,
|
||||
Name = name,
|
||||
Type = CollectionType.SharedCollection
|
||||
Type = CollectionType.SharedCollection,
|
||||
ExternalId = externalId
|
||||
};
|
||||
|
||||
await collectionRepository.CreateAsync(collection, groups, users);
|
||||
|
||||
@@ -62,7 +62,7 @@ public class ReconcileAdditionalStorageJobTests
|
||||
|
||||
// Assert
|
||||
_stripeFacade.Received(3).ListSubscriptionsAutoPagingAsync(
|
||||
Arg.Is<SubscriptionListOptions>(o => o.Status == "active"));
|
||||
Arg.Is<SubscriptionListOptions>(o => o.Limit == 100));
|
||||
}
|
||||
|
||||
#endregion
|
||||
@@ -553,6 +553,152 @@ public class ReconcileAdditionalStorageJobTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region Subscription Status Filtering Tests
|
||||
|
||||
[Fact]
|
||||
public async Task Execute_ActiveStatusSubscription_ProcessesSubscription()
|
||||
{
|
||||
// Arrange
|
||||
var context = CreateJobExecutionContext();
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
|
||||
|
||||
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active);
|
||||
|
||||
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
|
||||
.Returns(AsyncEnumerable.Create(subscription));
|
||||
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
|
||||
.Returns(subscription);
|
||||
|
||||
// Act
|
||||
await _sut.Execute(context);
|
||||
|
||||
// Assert
|
||||
await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any<SubscriptionUpdateOptions>());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Execute_TrialingStatusSubscription_ProcessesSubscription()
|
||||
{
|
||||
// Arrange
|
||||
var context = CreateJobExecutionContext();
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
|
||||
|
||||
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Trialing);
|
||||
|
||||
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
|
||||
.Returns(AsyncEnumerable.Create(subscription));
|
||||
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
|
||||
.Returns(subscription);
|
||||
|
||||
// Act
|
||||
await _sut.Execute(context);
|
||||
|
||||
// Assert
|
||||
await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any<SubscriptionUpdateOptions>());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Execute_PastDueStatusSubscription_ProcessesSubscription()
|
||||
{
|
||||
// Arrange
|
||||
var context = CreateJobExecutionContext();
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
|
||||
|
||||
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.PastDue);
|
||||
|
||||
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
|
||||
.Returns(AsyncEnumerable.Create(subscription));
|
||||
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
|
||||
.Returns(subscription);
|
||||
|
||||
// Act
|
||||
await _sut.Execute(context);
|
||||
|
||||
// Assert
|
||||
await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any<SubscriptionUpdateOptions>());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Execute_CanceledStatusSubscription_SkipsSubscription()
|
||||
{
|
||||
// Arrange
|
||||
var context = CreateJobExecutionContext();
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
|
||||
|
||||
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Canceled);
|
||||
|
||||
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
|
||||
.Returns(AsyncEnumerable.Create(subscription));
|
||||
|
||||
// Act
|
||||
await _sut.Execute(context);
|
||||
|
||||
// Assert
|
||||
await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Execute_IncompleteStatusSubscription_SkipsSubscription()
|
||||
{
|
||||
// Arrange
|
||||
var context = CreateJobExecutionContext();
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
|
||||
|
||||
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Incomplete);
|
||||
|
||||
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
|
||||
.Returns(AsyncEnumerable.Create(subscription));
|
||||
|
||||
// Act
|
||||
await _sut.Execute(context);
|
||||
|
||||
// Assert
|
||||
await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Execute_MixedSubscriptionStatuses_OnlyProcessesValidStatuses()
|
||||
{
|
||||
// Arrange
|
||||
var context = CreateJobExecutionContext();
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
|
||||
|
||||
var activeSubscription = CreateSubscription("sub_active", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active);
|
||||
var trialingSubscription = CreateSubscription("sub_trialing", "storage-gb-monthly", quantity: 8, status: StripeConstants.SubscriptionStatus.Trialing);
|
||||
var pastDueSubscription = CreateSubscription("sub_pastdue", "storage-gb-monthly", quantity: 6, status: StripeConstants.SubscriptionStatus.PastDue);
|
||||
var canceledSubscription = CreateSubscription("sub_canceled", "storage-gb-monthly", quantity: 5, status: StripeConstants.SubscriptionStatus.Canceled);
|
||||
var incompleteSubscription = CreateSubscription("sub_incomplete", "storage-gb-monthly", quantity: 4, status: StripeConstants.SubscriptionStatus.Incomplete);
|
||||
|
||||
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
|
||||
.Returns(AsyncEnumerable.Create(activeSubscription, trialingSubscription, pastDueSubscription, canceledSubscription, incompleteSubscription));
|
||||
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
|
||||
.Returns(callInfo => callInfo.Arg<string>() switch
|
||||
{
|
||||
"sub_active" => activeSubscription,
|
||||
"sub_trialing" => trialingSubscription,
|
||||
"sub_pastdue" => pastDueSubscription,
|
||||
_ => null
|
||||
});
|
||||
|
||||
// Act
|
||||
await _sut.Execute(context);
|
||||
|
||||
// Assert
|
||||
await _stripeFacade.Received(1).UpdateSubscription("sub_active", Arg.Any<SubscriptionUpdateOptions>());
|
||||
await _stripeFacade.Received(1).UpdateSubscription("sub_trialing", Arg.Any<SubscriptionUpdateOptions>());
|
||||
await _stripeFacade.Received(1).UpdateSubscription("sub_pastdue", Arg.Any<SubscriptionUpdateOptions>());
|
||||
await _stripeFacade.DidNotReceive().UpdateSubscription("sub_canceled", Arg.Any<SubscriptionUpdateOptions>());
|
||||
await _stripeFacade.DidNotReceive().UpdateSubscription("sub_incomplete", Arg.Any<SubscriptionUpdateOptions>());
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Cancellation Tests
|
||||
|
||||
[Fact]
|
||||
@@ -598,7 +744,8 @@ public class ReconcileAdditionalStorageJobTests
|
||||
string id,
|
||||
string priceId,
|
||||
long? quantity = null,
|
||||
Dictionary<string, string>? metadata = null)
|
||||
Dictionary<string, string>? metadata = null,
|
||||
string status = StripeConstants.SubscriptionStatus.Active)
|
||||
{
|
||||
var price = new Price { Id = priceId };
|
||||
var item = new SubscriptionItem
|
||||
@@ -611,6 +758,7 @@ public class ReconcileAdditionalStorageJobTests
|
||||
return new Subscription
|
||||
{
|
||||
Id = id,
|
||||
Status = status,
|
||||
Metadata = metadata,
|
||||
Items = new StripeList<SubscriptionItem>
|
||||
{
|
||||
|
||||
@@ -200,7 +200,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.True(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
|
||||
@@ -214,7 +215,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = null,
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
|
||||
@@ -228,7 +230,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = null,
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
|
||||
@@ -242,21 +245,38 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = null,
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IsRabbitMqEnabled_MissingExchangeName_ReturnsFalse()
|
||||
public void IsRabbitMqEnabled_MissingEventExchangeName_ReturnsFalse()
|
||||
{
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = null
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = null,
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IsRabbitMqEnabled_MissingIntegrationExchangeName_ReturnsFalse()
|
||||
{
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = null
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings));
|
||||
@@ -268,7 +288,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.True(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
|
||||
@@ -280,19 +301,34 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = null,
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IsAzureServiceBusEnabled_MissingTopicName_ReturnsFalse()
|
||||
public void IsAzureServiceBusEnabled_MissingEventTopicName_ReturnsFalse()
|
||||
{
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = null
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = null,
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IsAzureServiceBusEnabled_MissingIntegrationTopicName_ReturnsFalse()
|
||||
{
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = null
|
||||
});
|
||||
|
||||
Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings));
|
||||
@@ -601,7 +637,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
// Add prerequisites
|
||||
@@ -624,7 +661,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
|
||||
});
|
||||
|
||||
// Add prerequisites
|
||||
@@ -650,8 +688,10 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
|
||||
});
|
||||
|
||||
// Add prerequisites
|
||||
@@ -694,7 +734,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
|
||||
});
|
||||
|
||||
services.AddEventWriteServices(globalSettings);
|
||||
@@ -712,7 +753,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
services.AddEventWriteServices(globalSettings);
|
||||
@@ -769,10 +811,12 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration",
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
services.AddEventWriteServices(globalSettings);
|
||||
@@ -789,7 +833,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
var globalSettings = CreateGlobalSettings(new Dictionary<string, string?>
|
||||
{
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events"
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events",
|
||||
["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration"
|
||||
});
|
||||
|
||||
// Add prerequisites
|
||||
@@ -826,7 +871,8 @@ public class EventIntegrationServiceCollectionExtensionsTests
|
||||
["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Username"] = "user",
|
||||
["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass",
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange"
|
||||
["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange",
|
||||
["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration"
|
||||
});
|
||||
|
||||
// Add prerequisites
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
using Xunit;
|
||||
|
||||
namespace Bit.Core.Test.AdminConsole.Models.Data.EventIntegrations;
|
||||
|
||||
public class IntegrationHandlerResultTests
|
||||
{
|
||||
[Theory, BitAutoData]
|
||||
public void Succeed_SetsSuccessTrue_CategoryNull(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Succeed(message);
|
||||
|
||||
Assert.True(result.Success);
|
||||
Assert.Null(result.Category);
|
||||
Assert.Equal(message, result.Message);
|
||||
Assert.Null(result.FailureReason);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Fail_WithCategory_SetsSuccessFalse_CategorySet(IntegrationMessage message)
|
||||
{
|
||||
var category = IntegrationFailureCategory.AuthenticationFailed;
|
||||
var failureReason = "Invalid credentials";
|
||||
|
||||
var result = IntegrationHandlerResult.Fail(message, category, failureReason);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.Equal(category, result.Category);
|
||||
Assert.Equal(failureReason, result.FailureReason);
|
||||
Assert.Equal(message, result.Message);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Fail_WithDelayUntil_SetsDelayUntilDate(IntegrationMessage message)
|
||||
{
|
||||
var delayUntil = DateTime.UtcNow.AddMinutes(5);
|
||||
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.RateLimited,
|
||||
"Rate limited",
|
||||
delayUntil
|
||||
);
|
||||
|
||||
Assert.Equal(delayUntil, result.DelayUntilDate);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Retryable_RateLimited_ReturnsTrue(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.RateLimited,
|
||||
"Rate limited"
|
||||
);
|
||||
|
||||
Assert.True(result.Retryable);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Retryable_TransientError_ReturnsTrue(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.TransientError,
|
||||
"Temporary network issue"
|
||||
);
|
||||
|
||||
Assert.True(result.Retryable);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Retryable_AuthenticationFailed_ReturnsFalse(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.AuthenticationFailed,
|
||||
"Invalid token"
|
||||
);
|
||||
|
||||
Assert.False(result.Retryable);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Retryable_ConfigurationError_ReturnsFalse(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.ConfigurationError,
|
||||
"Channel not found"
|
||||
);
|
||||
|
||||
Assert.False(result.Retryable);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Retryable_ServiceUnavailable_ReturnsTrue(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.ServiceUnavailable,
|
||||
"Service is down"
|
||||
);
|
||||
|
||||
Assert.True(result.Retryable);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Retryable_PermanentFailure_ReturnsFalse(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message,
|
||||
IntegrationFailureCategory.PermanentFailure,
|
||||
"Permanent failure"
|
||||
);
|
||||
|
||||
Assert.False(result.Retryable);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public void Retryable_SuccessCase_ReturnsFalse(IntegrationMessage message)
|
||||
{
|
||||
var result = IntegrationHandlerResult.Succeed(message);
|
||||
|
||||
Assert.False(result.Retryable);
|
||||
}
|
||||
}
|
||||
@@ -30,7 +30,7 @@ public class OrganizationUpdateCommandTests
|
||||
var organizationBillingService = sutProvider.GetDependency<IOrganizationBillingService>();
|
||||
|
||||
organization.Id = organizationId;
|
||||
organization.GatewayCustomerId = null; // No Stripe customer, so no billing update
|
||||
organization.GatewayCustomerId = null; // No Stripe customer, but billing update is still called
|
||||
|
||||
organizationRepository
|
||||
.GetByIdAsync(organizationId)
|
||||
@@ -61,8 +61,8 @@ public class OrganizationUpdateCommandTests
|
||||
result,
|
||||
EventType.Organization_Updated);
|
||||
await organizationBillingService
|
||||
.DidNotReceiveWithAnyArgs()
|
||||
.UpdateOrganizationNameAndEmail(Arg.Any<Organization>());
|
||||
.Received(1)
|
||||
.UpdateOrganizationNameAndEmail(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
@@ -93,7 +93,7 @@ public class OrganizationUpdateCommandTests
|
||||
[Theory]
|
||||
[BitAutoData("")]
|
||||
[BitAutoData((string)null)]
|
||||
public async Task UpdateAsync_WhenGatewayCustomerIdIsNullOrEmpty_SkipsBillingUpdate(
|
||||
public async Task UpdateAsync_WhenGatewayCustomerIdIsNullOrEmpty_CallsBillingUpdateButHandledGracefully(
|
||||
string gatewayCustomerId,
|
||||
Guid organizationId,
|
||||
Organization organization,
|
||||
@@ -133,8 +133,8 @@ public class OrganizationUpdateCommandTests
|
||||
result,
|
||||
EventType.Organization_Updated);
|
||||
await organizationBillingService
|
||||
.DidNotReceiveWithAnyArgs()
|
||||
.UpdateOrganizationNameAndEmail(Arg.Any<Organization>());
|
||||
.Received(1)
|
||||
.UpdateOrganizationNameAndEmail(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
|
||||
@@ -78,8 +78,10 @@ public class AzureServiceBusIntegrationListenerServiceTests
|
||||
var sutProvider = GetSutProvider();
|
||||
message.RetryCount = 0;
|
||||
|
||||
var result = new IntegrationHandlerResult(false, message);
|
||||
result.Retryable = false;
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message: message,
|
||||
category: IntegrationFailureCategory.AuthenticationFailed, // NOT retryable
|
||||
failureReason: "403");
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
|
||||
@@ -89,6 +91,12 @@ public class AzureServiceBusIntegrationListenerServiceTests
|
||||
|
||||
await _handler.Received(1).HandleAsync(Arg.Is(expected.ToJson()));
|
||||
await _serviceBusService.DidNotReceiveWithAnyArgs().PublishToRetryAsync(Arg.Any<IIntegrationMessage>());
|
||||
_logger.Received().Log(
|
||||
LogLevel.Warning,
|
||||
Arg.Any<EventId>(),
|
||||
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - non-recoverable error or max retries exceeded.")),
|
||||
Arg.Any<Exception?>(),
|
||||
Arg.Any<Func<object, Exception?, string>>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
@@ -96,9 +104,10 @@ public class AzureServiceBusIntegrationListenerServiceTests
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
message.RetryCount = _config.MaxRetries;
|
||||
var result = new IntegrationHandlerResult(false, message);
|
||||
result.Retryable = true;
|
||||
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message: message,
|
||||
category: IntegrationFailureCategory.TransientError, // Retryable
|
||||
failureReason: "403");
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
|
||||
@@ -108,6 +117,12 @@ public class AzureServiceBusIntegrationListenerServiceTests
|
||||
|
||||
await _handler.Received(1).HandleAsync(Arg.Is(expected.ToJson()));
|
||||
await _serviceBusService.DidNotReceiveWithAnyArgs().PublishToRetryAsync(Arg.Any<IIntegrationMessage>());
|
||||
_logger.Received().Log(
|
||||
LogLevel.Warning,
|
||||
Arg.Any<EventId>(),
|
||||
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - non-recoverable error or max retries exceeded.")),
|
||||
Arg.Any<Exception?>(),
|
||||
Arg.Any<Func<object, Exception?, string>>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
@@ -116,8 +131,10 @@ public class AzureServiceBusIntegrationListenerServiceTests
|
||||
var sutProvider = GetSutProvider();
|
||||
message.RetryCount = 0;
|
||||
|
||||
var result = new IntegrationHandlerResult(false, message);
|
||||
result.Retryable = true;
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message: message,
|
||||
category: IntegrationFailureCategory.TransientError, // Retryable
|
||||
failureReason: "403");
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
|
||||
@@ -133,7 +150,7 @@ public class AzureServiceBusIntegrationListenerServiceTests
|
||||
public async Task HandleMessageAsync_SuccessfulResult_Succeeds(IntegrationMessage<WebhookIntegrationConfiguration> message)
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
var result = new IntegrationHandlerResult(true, message);
|
||||
var result = IntegrationHandlerResult.Succeed(message);
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
|
||||
@@ -156,7 +173,7 @@ public class AzureServiceBusIntegrationListenerServiceTests
|
||||
_logger.Received(1).Log(
|
||||
LogLevel.Error,
|
||||
Arg.Any<EventId>(),
|
||||
Arg.Any<object>(),
|
||||
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Unhandled error processing ASB message")),
|
||||
Arg.Any<Exception>(),
|
||||
Arg.Any<Func<object, Exception?, string>>());
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ public class DatadogIntegrationHandlerTests
|
||||
|
||||
Assert.True(result.Success);
|
||||
Assert.Equal(result.Message, message);
|
||||
Assert.Empty(result.FailureReason);
|
||||
Assert.Null(result.FailureReason);
|
||||
|
||||
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(DatadogIntegrationHandler.HttpClientName))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
using System.Net;
|
||||
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Services;
|
||||
using Xunit;
|
||||
@@ -7,7 +8,6 @@ namespace Bit.Core.Test.Services;
|
||||
|
||||
public class IntegrationHandlerTests
|
||||
{
|
||||
|
||||
[Fact]
|
||||
public async Task HandleAsync_ConvertsJsonToTypedIntegrationMessage()
|
||||
{
|
||||
@@ -33,13 +33,113 @@ public class IntegrationHandlerTests
|
||||
Assert.Equal(expected.IntegrationType, typedResult.IntegrationType);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(HttpStatusCode.Unauthorized)]
|
||||
[InlineData(HttpStatusCode.Forbidden)]
|
||||
public void ClassifyHttpStatusCode_AuthenticationFailed(HttpStatusCode code)
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.AuthenticationFailed,
|
||||
TestIntegrationHandler.Classify(code));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(HttpStatusCode.NotFound)]
|
||||
[InlineData(HttpStatusCode.Gone)]
|
||||
[InlineData(HttpStatusCode.MovedPermanently)]
|
||||
[InlineData(HttpStatusCode.TemporaryRedirect)]
|
||||
[InlineData(HttpStatusCode.PermanentRedirect)]
|
||||
public void ClassifyHttpStatusCode_ConfigurationError(HttpStatusCode code)
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.ConfigurationError,
|
||||
TestIntegrationHandler.Classify(code));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassifyHttpStatusCode_TooManyRequests_IsRateLimited()
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.RateLimited,
|
||||
TestIntegrationHandler.Classify(HttpStatusCode.TooManyRequests));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassifyHttpStatusCode_RequestTimeout_IsTransient()
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.TransientError,
|
||||
TestIntegrationHandler.Classify(HttpStatusCode.RequestTimeout));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(HttpStatusCode.InternalServerError)]
|
||||
[InlineData(HttpStatusCode.BadGateway)]
|
||||
[InlineData(HttpStatusCode.GatewayTimeout)]
|
||||
public void ClassifyHttpStatusCode_Common5xx_AreTransient(HttpStatusCode code)
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.TransientError,
|
||||
TestIntegrationHandler.Classify(code));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassifyHttpStatusCode_ServiceUnavailable_IsServiceUnavailable()
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.ServiceUnavailable,
|
||||
TestIntegrationHandler.Classify(HttpStatusCode.ServiceUnavailable));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassifyHttpStatusCode_NotImplemented_IsPermanentFailure()
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.PermanentFailure,
|
||||
TestIntegrationHandler.Classify(HttpStatusCode.NotImplemented));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void FClassifyHttpStatusCode_Unhandled3xx_IsConfigurationError()
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.ConfigurationError,
|
||||
TestIntegrationHandler.Classify(HttpStatusCode.Found));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassifyHttpStatusCode_Unhandled4xx_IsConfigurationError()
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.ConfigurationError,
|
||||
TestIntegrationHandler.Classify(HttpStatusCode.BadRequest));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassifyHttpStatusCode_Unhandled5xx_IsServiceUnavailable()
|
||||
{
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.ServiceUnavailable,
|
||||
TestIntegrationHandler.Classify(HttpStatusCode.HttpVersionNotSupported));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ClassifyHttpStatusCode_UnknownCode_DefaultsToServiceUnavailable()
|
||||
{
|
||||
// cast an out-of-range value to ensure default path is stable
|
||||
Assert.Equal(
|
||||
IntegrationFailureCategory.ServiceUnavailable,
|
||||
TestIntegrationHandler.Classify((HttpStatusCode)799));
|
||||
}
|
||||
|
||||
private class TestIntegrationHandler : IntegrationHandlerBase<WebhookIntegrationConfigurationDetails>
|
||||
{
|
||||
public override Task<IntegrationHandlerResult> HandleAsync(
|
||||
IntegrationMessage<WebhookIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var result = new IntegrationHandlerResult(success: true, message: message);
|
||||
return Task.FromResult(result);
|
||||
return Task.FromResult(IntegrationHandlerResult.Succeed(message: message));
|
||||
}
|
||||
|
||||
public static IntegrationFailureCategory Classify(HttpStatusCode code) => ClassifyHttpStatusCode(code);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,8 +86,10 @@ public class RabbitMqIntegrationListenerServiceTests
|
||||
new BasicProperties(),
|
||||
body: Encoding.UTF8.GetBytes(message.ToJson())
|
||||
);
|
||||
var result = new IntegrationHandlerResult(false, message);
|
||||
result.Retryable = false;
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message: message,
|
||||
category: IntegrationFailureCategory.AuthenticationFailed, // NOT retryable
|
||||
failureReason: "403");
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
|
||||
@@ -105,7 +107,7 @@ public class RabbitMqIntegrationListenerServiceTests
|
||||
_logger.Received().Log(
|
||||
LogLevel.Warning,
|
||||
Arg.Any<EventId>(),
|
||||
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Non-retryable failure")),
|
||||
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - non-retryable.")),
|
||||
Arg.Any<Exception?>(),
|
||||
Arg.Any<Func<object, Exception?, string>>());
|
||||
|
||||
@@ -133,8 +135,10 @@ public class RabbitMqIntegrationListenerServiceTests
|
||||
new BasicProperties(),
|
||||
body: Encoding.UTF8.GetBytes(message.ToJson())
|
||||
);
|
||||
var result = new IntegrationHandlerResult(false, message);
|
||||
result.Retryable = true;
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message: message,
|
||||
category: IntegrationFailureCategory.TransientError, // Retryable
|
||||
failureReason: "403");
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
|
||||
@@ -151,7 +155,7 @@ public class RabbitMqIntegrationListenerServiceTests
|
||||
_logger.Received().Log(
|
||||
LogLevel.Warning,
|
||||
Arg.Any<EventId>(),
|
||||
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Max retry attempts reached")),
|
||||
Arg.Is<object>(o => (o.ToString() ?? "").Contains("Integration failure - max retries exceeded.")),
|
||||
Arg.Any<Exception?>(),
|
||||
Arg.Any<Func<object, Exception?, string>>());
|
||||
|
||||
@@ -179,9 +183,10 @@ public class RabbitMqIntegrationListenerServiceTests
|
||||
new BasicProperties(),
|
||||
body: Encoding.UTF8.GetBytes(message.ToJson())
|
||||
);
|
||||
var result = new IntegrationHandlerResult(false, message);
|
||||
result.Retryable = true;
|
||||
result.DelayUntilDate = _now.AddMinutes(1);
|
||||
var result = IntegrationHandlerResult.Fail(
|
||||
message: message,
|
||||
category: IntegrationFailureCategory.TransientError, // Retryable
|
||||
failureReason: "403");
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
var expected = IntegrationMessage<WebhookIntegrationConfiguration>.FromJson(message.ToJson());
|
||||
@@ -220,7 +225,7 @@ public class RabbitMqIntegrationListenerServiceTests
|
||||
new BasicProperties(),
|
||||
body: Encoding.UTF8.GetBytes(message.ToJson())
|
||||
);
|
||||
var result = new IntegrationHandlerResult(true, message);
|
||||
var result = IntegrationHandlerResult.Succeed(message);
|
||||
_handler.HandleAsync(Arg.Any<string>()).Returns(result);
|
||||
|
||||
await sutProvider.Sut.ProcessReceivedMessageAsync(eventArgs, cancellationToken);
|
||||
|
||||
@@ -110,7 +110,7 @@ public class SlackIntegrationHandlerTests
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HandleAsync_NullResponse_ReturnsNonRetryableFailure()
|
||||
public async Task HandleAsync_NullResponse_ReturnsRetryableFailure()
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
var message = new IntegrationMessage<SlackIntegrationConfigurationDetails>()
|
||||
@@ -126,7 +126,7 @@ public class SlackIntegrationHandlerTests
|
||||
var result = await sutProvider.Sut.HandleAsync(message);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.False(result.Retryable);
|
||||
Assert.True(result.Retryable); // Null response is classified as TransientError (retryable)
|
||||
Assert.Equal("Slack response was null", result.FailureReason);
|
||||
Assert.Equal(result.Message, message);
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
using System.Text.Json;
|
||||
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Test.Common.AutoFixture;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
@@ -42,9 +43,77 @@ public class TeamsIntegrationHandlerTests
|
||||
);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HandleAsync_ArgumentException_ReturnsConfigurationError(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
|
||||
|
||||
sutProvider.GetDependency<ITeamsService>()
|
||||
.SendMessageToChannelAsync(Arg.Any<Uri>(), Arg.Any<string>(), Arg.Any<string>())
|
||||
.ThrowsAsync(new ArgumentException("argument error"));
|
||||
var result = await sutProvider.Sut.HandleAsync(message);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.Equal(IntegrationFailureCategory.ConfigurationError, result.Category);
|
||||
Assert.False(result.Retryable);
|
||||
Assert.Equal(result.Message, message);
|
||||
|
||||
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)),
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)),
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate))
|
||||
);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HandleAsync_HttpExceptionNonRetryable_ReturnsFalseAndNotRetryable(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
public async Task HandleAsync_JsonException_ReturnsPermanentFailure(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
|
||||
|
||||
sutProvider.GetDependency<ITeamsService>()
|
||||
.SendMessageToChannelAsync(Arg.Any<Uri>(), Arg.Any<string>(), Arg.Any<string>())
|
||||
.ThrowsAsync(new JsonException("JSON error"));
|
||||
var result = await sutProvider.Sut.HandleAsync(message);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.Equal(IntegrationFailureCategory.PermanentFailure, result.Category);
|
||||
Assert.False(result.Retryable);
|
||||
Assert.Equal(result.Message, message);
|
||||
|
||||
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)),
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)),
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate))
|
||||
);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HandleAsync_UriFormatException_ReturnsConfigurationError(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
|
||||
|
||||
sutProvider.GetDependency<ITeamsService>()
|
||||
.SendMessageToChannelAsync(Arg.Any<Uri>(), Arg.Any<string>(), Arg.Any<string>())
|
||||
.ThrowsAsync(new UriFormatException("Bad URI"));
|
||||
var result = await sutProvider.Sut.HandleAsync(message);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.Equal(IntegrationFailureCategory.ConfigurationError, result.Category);
|
||||
Assert.False(result.Retryable);
|
||||
Assert.Equal(result.Message, message);
|
||||
|
||||
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)),
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)),
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate))
|
||||
);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HandleAsync_HttpExceptionForbidden_ReturnsAuthenticationFailed(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
|
||||
@@ -62,6 +131,7 @@ public class TeamsIntegrationHandlerTests
|
||||
var result = await sutProvider.Sut.HandleAsync(message);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.Equal(IntegrationFailureCategory.AuthenticationFailed, result.Category);
|
||||
Assert.False(result.Retryable);
|
||||
Assert.Equal(result.Message, message);
|
||||
|
||||
@@ -73,7 +143,7 @@ public class TeamsIntegrationHandlerTests
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HandleAsync_HttpExceptionRetryable_ReturnsFalseAndRetryable(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
public async Task HandleAsync_HttpExceptionTooManyRequests_ReturnsRateLimited(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
|
||||
@@ -92,6 +162,7 @@ public class TeamsIntegrationHandlerTests
|
||||
var result = await sutProvider.Sut.HandleAsync(message);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.Equal(IntegrationFailureCategory.RateLimited, result.Category);
|
||||
Assert.True(result.Retryable);
|
||||
Assert.Equal(result.Message, message);
|
||||
|
||||
@@ -103,7 +174,7 @@ public class TeamsIntegrationHandlerTests
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HandleAsync_UnknownException_ReturnsFalseAndNotRetryable(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
public async Task HandleAsync_UnknownException_ReturnsTransientError(IntegrationMessage<TeamsIntegrationConfigurationDetails> message)
|
||||
{
|
||||
var sutProvider = GetSutProvider();
|
||||
message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl);
|
||||
@@ -114,7 +185,8 @@ public class TeamsIntegrationHandlerTests
|
||||
var result = await sutProvider.Sut.HandleAsync(message);
|
||||
|
||||
Assert.False(result.Success);
|
||||
Assert.False(result.Retryable);
|
||||
Assert.Equal(IntegrationFailureCategory.TransientError, result.Category);
|
||||
Assert.True(result.Retryable);
|
||||
Assert.Equal(result.Message, message);
|
||||
|
||||
await sutProvider.GetDependency<ITeamsService>().Received(1).SendMessageToChannelAsync(
|
||||
|
||||
@@ -51,7 +51,7 @@ public class WebhookIntegrationHandlerTests
|
||||
|
||||
Assert.True(result.Success);
|
||||
Assert.Equal(result.Message, message);
|
||||
Assert.Empty(result.FailureReason);
|
||||
Assert.Null(result.FailureReason);
|
||||
|
||||
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(WebhookIntegrationHandler.HttpClientName))
|
||||
@@ -79,7 +79,7 @@ public class WebhookIntegrationHandlerTests
|
||||
|
||||
Assert.True(result.Success);
|
||||
Assert.Equal(result.Message, message);
|
||||
Assert.Empty(result.FailureReason);
|
||||
Assert.Null(result.FailureReason);
|
||||
|
||||
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
|
||||
Arg.Is(AssertHelper.AssertPropertyEqual(WebhookIntegrationHandler.HttpClientName))
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Models;
|
||||
using Bit.Core.Auth.UserFeatures.TwoFactorAuth;
|
||||
using Bit.Core.Billing.Premium.Queries;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Data;
|
||||
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Utilities;
|
||||
using Bit.Test.Common.AutoFixture;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
@@ -404,6 +407,277 @@ public class TwoFactorIsEnabledQueryTests
|
||||
.GetCalculatedPremiumAsync(default);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData((IEnumerable<Guid>)null)]
|
||||
[BitAutoData([])]
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsEmpty(
|
||||
IEnumerable<Guid> userIds,
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
|
||||
|
||||
// Assert
|
||||
Assert.Empty(result);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(TwoFactorProviderType.Duo)]
|
||||
[BitAutoData(TwoFactorProviderType.YubiKey)]
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithMixedScenarios_ReturnsCorrectResults(
|
||||
TwoFactorProviderType premiumProviderType,
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
|
||||
User user1,
|
||||
User user2,
|
||||
User user3)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
var users = new List<User> { user1, user2, user3 };
|
||||
var userIds = users.Select(u => u.Id).ToList();
|
||||
|
||||
// User 1: Non-premium provider → 2FA enabled
|
||||
user1.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } }
|
||||
});
|
||||
|
||||
// User 2: Premium provider + has premium → 2FA enabled
|
||||
user2.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
|
||||
});
|
||||
|
||||
// User 3: Premium provider + no premium → 2FA disabled
|
||||
user3.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
|
||||
});
|
||||
|
||||
var premiumStatus = new Dictionary<Guid, bool>
|
||||
{
|
||||
{ user2.Id, true },
|
||||
{ user3.Id, false }
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
|
||||
.Returns(users);
|
||||
|
||||
sutProvider.GetDependency<IHasPremiumAccessQuery>()
|
||||
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
|
||||
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)))
|
||||
.Returns(premiumStatus);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
|
||||
|
||||
// Assert
|
||||
Assert.Contains(result, res => res.userId == user1.Id && res.twoFactorIsEnabled == true); // Non-premium provider
|
||||
Assert.Contains(result, res => res.userId == user2.Id && res.twoFactorIsEnabled == true); // Premium + has premium
|
||||
Assert.Contains(result, res => res.userId == user3.Id && res.twoFactorIsEnabled == false); // Premium + no premium
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(TwoFactorProviderType.Duo)]
|
||||
[BitAutoData(TwoFactorProviderType.YubiKey)]
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_OnlyChecksPremiumAccessForUsersWhoNeedIt(
|
||||
TwoFactorProviderType premiumProviderType,
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
|
||||
User user1,
|
||||
User user2,
|
||||
User user3)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
var users = new List<User> { user1, user2, user3 };
|
||||
var userIds = users.Select(u => u.Id).ToList();
|
||||
|
||||
// User 1: Has non-premium provider - should NOT trigger premium check
|
||||
user1.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } }
|
||||
});
|
||||
|
||||
// User 2 & 3: Have only premium providers - SHOULD trigger premium check
|
||||
user2.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
|
||||
});
|
||||
user3.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
|
||||
});
|
||||
|
||||
var premiumStatus = new Dictionary<Guid, bool>
|
||||
{
|
||||
{ user2.Id, true },
|
||||
{ user3.Id, false }
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
|
||||
.Returns(users);
|
||||
|
||||
sutProvider.GetDependency<IHasPremiumAccessQuery>()
|
||||
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
|
||||
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)))
|
||||
.Returns(premiumStatus);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
|
||||
|
||||
// Assert - Verify optimization: premium checked ONLY for users 2 and 3 (not user 1)
|
||||
await sutProvider.GetDependency<IHasPremiumAccessQuery>()
|
||||
.Received(1)
|
||||
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
|
||||
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsAllTwoFactorDisabled(
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
|
||||
List<OrganizationUserUserDetails> users)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
foreach (var user in users)
|
||||
{
|
||||
user.UserId = null;
|
||||
}
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(users);
|
||||
|
||||
// Assert
|
||||
foreach (var user in users)
|
||||
{
|
||||
Assert.Contains(result, res => res.user.Equals(user) && res.twoFactorIsEnabled == false);
|
||||
}
|
||||
|
||||
// No UserIds were supplied so no calls to the UserRepository should have been made
|
||||
await sutProvider.GetDependency<IUserRepository>()
|
||||
.DidNotReceiveWithAnyArgs()
|
||||
.GetManyAsync(default);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(TwoFactorProviderType.Authenticator, true)] // Non-premium provider
|
||||
[BitAutoData(TwoFactorProviderType.Duo, true)] // Premium provider with premium access
|
||||
[BitAutoData(TwoFactorProviderType.YubiKey, false)] // Premium provider without premium access
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_SingleUser_VariousScenarios(
|
||||
TwoFactorProviderType providerType,
|
||||
bool hasPremiumAccess,
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
|
||||
User user)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
user.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ providerType, new TwoFactorProvider { Enabled = true } }
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IHasPremiumAccessQuery>()
|
||||
.HasPremiumAccessAsync(user.Id)
|
||||
.Returns(hasPremiumAccess);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
|
||||
|
||||
// Assert
|
||||
var requiresPremium = TwoFactorProvider.RequiresPremium(providerType);
|
||||
var expectedResult = !requiresPremium || hasPremiumAccess;
|
||||
Assert.Equal(expectedResult, result);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoEnabledProviders_ReturnsFalse(
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
|
||||
User user)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
user.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
|
||||
{
|
||||
{ TwoFactorProviderType.Email, new TwoFactorProvider { Enabled = false } }
|
||||
});
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
|
||||
|
||||
// Assert
|
||||
Assert.False(result);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNullProviders_ReturnsFalse(
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
|
||||
User user)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
user.TwoFactorProviders = null;
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
|
||||
|
||||
// Assert
|
||||
Assert.False(result);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_UserNotFound_ThrowsNotFoundException(
|
||||
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
|
||||
Guid userId)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
|
||||
.Returns(true);
|
||||
|
||||
var testUser = new TestTwoFactorProviderUser
|
||||
{
|
||||
Id = userId,
|
||||
TwoFactorProviders = null
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(userId)
|
||||
.Returns((User)null);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<NotFoundException>(
|
||||
async () => await sutProvider.Sut.TwoFactorIsEnabledAsync(testUser));
|
||||
}
|
||||
|
||||
private class TestTwoFactorProviderUser : ITwoFactorProvidersUser
|
||||
{
|
||||
public Guid? Id { get; set; }
|
||||
@@ -418,10 +692,5 @@ public class TwoFactorIsEnabledQueryTests
|
||||
{
|
||||
return Id;
|
||||
}
|
||||
|
||||
public bool GetPremium()
|
||||
{
|
||||
return Premium;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
using Bit.Core.Billing.Premium.Models;
|
||||
using Bit.Core.Billing.Premium.Queries;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Test.Common.AutoFixture;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
using NSubstitute;
|
||||
using Xunit;
|
||||
|
||||
namespace Bit.Core.Test.Billing.Premium.Queries;
|
||||
|
||||
[SutProviderCustomize]
|
||||
public class HasPremiumAccessQueryTests
|
||||
{
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumAccessAsync_WhenUserHasPersonalPremium_ReturnsTrue(
|
||||
UserPremiumAccess user,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user.PersonalPremium = true;
|
||||
user.OrganizationPremium = false;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.True(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumButHasOrgPremium_ReturnsTrue(
|
||||
UserPremiumAccess user,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user.PersonalPremium = false;
|
||||
user.OrganizationPremium = true; // Has org premium
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.True(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumAndNoOrgPremium_ReturnsFalse(
|
||||
UserPremiumAccess user,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user.PersonalPremium = false;
|
||||
user.OrganizationPremium = false;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.False(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumAccessAsync_WhenUserNotFound_ThrowsNotFoundException(
|
||||
Guid userId,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(userId)
|
||||
.Returns((UserPremiumAccess?)null);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<NotFoundException>(
|
||||
() => sutProvider.Sut.HasPremiumAccessAsync(userId));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumFromOrganizationAsync_WhenUserHasNoOrganizations_ReturnsFalse(
|
||||
UserPremiumAccess user,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user.PersonalPremium = false;
|
||||
user.OrganizationPremium = false; // No premium from anywhere
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.False(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumFromOrganizationAsync_WhenUserHasPremiumFromOrg_ReturnsTrue(
|
||||
UserPremiumAccess user,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user.PersonalPremium = false; // No personal premium
|
||||
user.OrganizationPremium = true; // But has premium from org
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.True(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumFromOrganizationAsync_WhenUserHasOnlyPersonalPremium_ReturnsFalse(
|
||||
UserPremiumAccess user,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user.PersonalPremium = true; // Has personal premium
|
||||
user.OrganizationPremium = false; // Not in any org that grants premium
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.False(result); // Should return false because user is not in an org that grants premium
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumFromOrganizationAsync_WhenUserHasBothPersonalAndOrgPremium_ReturnsTrue(
|
||||
UserPremiumAccess user,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user.PersonalPremium = true; // Has personal premium
|
||||
user.OrganizationPremium = true; // Also in an org that grants premium
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.True(result); // Should return true because user IS in an org that grants premium (regardless of personal premium)
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumFromOrganizationAsync_WhenUserNotFound_ThrowsNotFoundException(
|
||||
Guid userId,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessAsync(userId)
|
||||
.Returns((UserPremiumAccess?)null);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<NotFoundException>(
|
||||
() => sutProvider.Sut.HasPremiumFromOrganizationAsync(userId));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumAccessAsync_Bulk_WhenEmptyList_ReturnsEmptyDictionary(
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
var userIds = new List<Guid>();
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessByIdsAsync(userIds)
|
||||
.Returns(new List<UserPremiumAccess>());
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds);
|
||||
|
||||
// Assert
|
||||
Assert.Empty(result);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task HasPremiumAccessAsync_Bulk_ReturnsCorrectStatus(
|
||||
UserPremiumAccess user1,
|
||||
UserPremiumAccess user2,
|
||||
UserPremiumAccess user3,
|
||||
SutProvider<HasPremiumAccessQuery> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
user1.PersonalPremium = true;
|
||||
user1.OrganizationPremium = false;
|
||||
user2.PersonalPremium = false;
|
||||
user2.OrganizationPremium = false;
|
||||
user3.PersonalPremium = false;
|
||||
user3.OrganizationPremium = true;
|
||||
|
||||
var users = new List<UserPremiumAccess> { user1, user2, user3 };
|
||||
var userIds = users.Select(u => u.Id).ToList();
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetPremiumAccessByIdsAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
|
||||
.Returns(users);
|
||||
|
||||
// Act
|
||||
var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(3, result.Count);
|
||||
Assert.True(result[user1.Id]); // Personal premium
|
||||
Assert.False(result[user2.Id]); // No premium
|
||||
Assert.True(result[user3.Id]); // Organization premium
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.Billing;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Models.Sales;
|
||||
@@ -391,12 +390,13 @@ public class OrganizationBillingServiceTests
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenNameIsLong_TruncatesTo30Characters(
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenNameIsLong_UsesFullName(
|
||||
Organization organization,
|
||||
SutProvider<OrganizationBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
organization.Name = "This is a very long organization name that exceeds thirty characters";
|
||||
var longName = "This is a very long organization name that exceeds thirty characters";
|
||||
organization.Name = longName;
|
||||
|
||||
CustomerUpdateOptions capturedOptions = null;
|
||||
sutProvider.GetDependency<IStripeAdapter>()
|
||||
@@ -420,14 +420,11 @@ public class OrganizationBillingServiceTests
|
||||
Assert.NotNull(capturedOptions.InvoiceSettings.CustomFields);
|
||||
|
||||
var customField = capturedOptions.InvoiceSettings.CustomFields.First();
|
||||
Assert.Equal(30, customField.Value.Length);
|
||||
|
||||
var expectedCustomFieldDisplayName = "This is a very long organizati";
|
||||
Assert.Equal(expectedCustomFieldDisplayName, customField.Value);
|
||||
Assert.Equal(longName, customField.Value);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsNull_ThrowsBillingException(
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsNull_LogsWarningAndReturns(
|
||||
Organization organization,
|
||||
SutProvider<OrganizationBillingService> sutProvider)
|
||||
{
|
||||
@@ -435,15 +432,93 @@ public class OrganizationBillingServiceTests
|
||||
organization.GatewayCustomerId = null;
|
||||
organization.Name = "Test Organization";
|
||||
organization.BillingEmail = "billing@example.com";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act & Assert
|
||||
var exception = await Assert.ThrowsAsync<BillingException>(
|
||||
() => sutProvider.Sut.UpdateOrganizationNameAndEmail(organization));
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
|
||||
|
||||
Assert.Contains("Cannot update an organization in Stripe without a GatewayCustomerId.", exception.Response);
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
await sutProvider.GetDependency<IStripeAdapter>()
|
||||
.DidNotReceiveWithAnyArgs()
|
||||
.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>());
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsEmpty_LogsWarningAndReturns(
|
||||
Organization organization,
|
||||
SutProvider<OrganizationBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
organization.GatewayCustomerId = "";
|
||||
organization.Name = "Test Organization";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenNameIsNull_LogsWarningAndReturns(
|
||||
Organization organization,
|
||||
SutProvider<OrganizationBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
organization.Name = null;
|
||||
organization.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenNameIsEmpty_LogsWarningAndReturns(
|
||||
Organization organization,
|
||||
SutProvider<OrganizationBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
organization.Name = "";
|
||||
organization.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.DidNotReceive().UpdateCustomerAsync(
|
||||
Arg.Any<string>(),
|
||||
Arg.Any<CustomerUpdateOptions>());
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateOrganizationNameAndEmail_WhenBillingEmailIsNull_UpdatesWithNull(
|
||||
Organization organization,
|
||||
SutProvider<OrganizationBillingService> sutProvider)
|
||||
{
|
||||
// Arrange
|
||||
organization.Name = "Test Organization";
|
||||
organization.BillingEmail = null;
|
||||
organization.GatewayCustomerId = "cus_test123";
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
// Act
|
||||
await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization);
|
||||
|
||||
// Assert
|
||||
await stripeAdapter.Received(1).UpdateCustomerAsync(
|
||||
organization.GatewayCustomerId,
|
||||
Arg.Is<CustomerUpdateOptions>(options =>
|
||||
options.Email == null &&
|
||||
options.Description == organization.Name));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,6 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
var root = body.RootElement;
|
||||
AssertRefreshTokenExists(root);
|
||||
AssertHelper.AssertJsonProperty(root, "ForcePasswordReset", JsonValueKind.False);
|
||||
AssertHelper.AssertJsonProperty(root, "ResetMasterPassword", JsonValueKind.False);
|
||||
var kdf = AssertHelper.AssertJsonProperty(root, "Kdf", JsonValueKind.Number).GetInt32();
|
||||
Assert.Equal(0, kdf);
|
||||
var kdfIterations = AssertHelper.AssertJsonProperty(root, "KdfIterations", JsonValueKind.Number).GetInt32();
|
||||
|
||||
@@ -144,4 +144,69 @@ public class CollectionRepositoryReplaceTests
|
||||
await userRepository.DeleteAsync(user);
|
||||
await organizationRepository.DeleteAsync(organization);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task ReplaceAsync_WhenNotPassingGroupsOrUsers_DoesNotDeleteAccess(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository,
|
||||
IGroupRepository groupRepository,
|
||||
ICollectionRepository collectionRepository)
|
||||
{
|
||||
// Arrange
|
||||
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||
|
||||
var user1 = await userRepository.CreateTestUserAsync();
|
||||
var orgUser1 = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user1);
|
||||
|
||||
var user2 = await userRepository.CreateTestUserAsync();
|
||||
var orgUser2 = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user2);
|
||||
|
||||
var group1 = await groupRepository.CreateTestGroupAsync(organization);
|
||||
var group2 = await groupRepository.CreateTestGroupAsync(organization);
|
||||
|
||||
var collection = new Collection
|
||||
{
|
||||
Name = "Test Collection Name",
|
||||
OrganizationId = organization.Id,
|
||||
};
|
||||
|
||||
await collectionRepository.CreateAsync(collection,
|
||||
[
|
||||
new CollectionAccessSelection { Id = group1.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
|
||||
new CollectionAccessSelection { Id = group2.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
|
||||
],
|
||||
[
|
||||
new CollectionAccessSelection { Id = orgUser1.Id, Manage = true, HidePasswords = false, ReadOnly = true },
|
||||
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = true, ReadOnly = false },
|
||||
]
|
||||
);
|
||||
|
||||
// Act
|
||||
collection.Name = "Updated Collection Name";
|
||||
|
||||
await collectionRepository.ReplaceAsync(collection, null, null);
|
||||
|
||||
// Assert
|
||||
var (actualCollection, actualAccess) = await collectionRepository.GetByIdWithAccessAsync(collection.Id);
|
||||
|
||||
Assert.NotNull(actualCollection);
|
||||
Assert.Equal("Updated Collection Name", actualCollection.Name);
|
||||
|
||||
var groups = actualAccess.Groups.ToArray();
|
||||
Assert.Equal(2, groups.Length);
|
||||
Assert.Single(groups, g => g.Id == group1.Id && g.Manage && g.HidePasswords && !g.ReadOnly);
|
||||
Assert.Single(groups, g => g.Id == group2.Id && !g.Manage && !g.HidePasswords && g.ReadOnly);
|
||||
|
||||
var users = actualAccess.Users.ToArray();
|
||||
|
||||
Assert.Equal(2, users.Length);
|
||||
Assert.Single(users, u => u.Id == orgUser1.Id && u.Manage && !u.HidePasswords && u.ReadOnly);
|
||||
Assert.Single(users, u => u.Id == orgUser2.Id && !u.Manage && u.HidePasswords && !u.ReadOnly);
|
||||
|
||||
// Clean up data
|
||||
await userRepository.DeleteAsync(user1);
|
||||
await userRepository.DeleteAsync(user2);
|
||||
await organizationRepository.DeleteAsync(organization);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,4 +179,325 @@ public class UserRepositoryTests
|
||||
Assert.Equal(CollectionType.SharedCollection, updatedCollection2.Type);
|
||||
Assert.Equal(user2.Email, updatedCollection2.DefaultUserCollectionEmail);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessAsync_WithPersonalPremium_ReturnsCorrectAccess(
|
||||
IUserRepository userRepository)
|
||||
{
|
||||
// Arrange
|
||||
var user = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Premium User",
|
||||
Email = $"premium+{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = true
|
||||
});
|
||||
|
||||
// Act
|
||||
var result = await userRepository.GetPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.True(result.PersonalPremium);
|
||||
Assert.False(result.OrganizationPremium);
|
||||
Assert.True(result.HasPremiumAccess);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessAsync_WithOrganizationPremium_ReturnsCorrectAccess(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository)
|
||||
{
|
||||
// Arrange
|
||||
var user = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Org User",
|
||||
Email = $"org+{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
|
||||
|
||||
// Act
|
||||
var result = await userRepository.GetPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.False(result.PersonalPremium);
|
||||
Assert.True(result.OrganizationPremium);
|
||||
Assert.True(result.HasPremiumAccess);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessAsync_WithDisabledOrganization_ReturnsNoOrganizationPremium(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository)
|
||||
{
|
||||
// Arrange
|
||||
var user = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "User",
|
||||
Email = $"user+{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||
organization.Enabled = false;
|
||||
await organizationRepository.ReplaceAsync(organization);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
|
||||
|
||||
// Act
|
||||
var result = await userRepository.GetPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.False(result.OrganizationPremium);
|
||||
Assert.False(result.HasPremiumAccess);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessAsync_WithOrganizationUsersGetPremiumFalse_ReturnsNoOrganizationPremium(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository)
|
||||
{
|
||||
// Arrange
|
||||
var user = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "User",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||
organization.UsersGetPremium = false;
|
||||
await organizationRepository.ReplaceAsync(organization);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
|
||||
|
||||
// Act
|
||||
var result = await userRepository.GetPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.False(result.OrganizationPremium);
|
||||
Assert.False(result.HasPremiumAccess);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessAsync_WithMultipleOrganizations_OneProvidesPremium_ReturnsOrganizationPremium(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository)
|
||||
{
|
||||
// Arrange
|
||||
var user = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "User With Premium Org",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var orgWithPremium = await organizationRepository.CreateTestOrganizationAsync();
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(orgWithPremium, user);
|
||||
|
||||
var orgNoPremium = await organizationRepository.CreateTestOrganizationAsync();
|
||||
orgNoPremium.UsersGetPremium = false;
|
||||
await organizationRepository.ReplaceAsync(orgNoPremium);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(orgNoPremium, user);
|
||||
|
||||
// Act
|
||||
var result = await userRepository.GetPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.False(result.PersonalPremium);
|
||||
Assert.True(result.OrganizationPremium);
|
||||
Assert.True(result.HasPremiumAccess);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessAsync_WithMultipleOrganizations_NoneProvidePremium_ReturnsNoOrganizationPremium(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository)
|
||||
{
|
||||
// Arrange
|
||||
var user = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "User With No Premium Orgs",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var disabledOrg = await organizationRepository.CreateTestOrganizationAsync();
|
||||
disabledOrg.Enabled = false;
|
||||
await organizationRepository.ReplaceAsync(disabledOrg);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(disabledOrg, user);
|
||||
|
||||
var orgNoPremium = await organizationRepository.CreateTestOrganizationAsync();
|
||||
orgNoPremium.UsersGetPremium = false;
|
||||
await organizationRepository.ReplaceAsync(orgNoPremium);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(orgNoPremium, user);
|
||||
|
||||
// Act
|
||||
var result = await userRepository.GetPremiumAccessAsync(user.Id);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.False(result.PersonalPremium);
|
||||
Assert.False(result.OrganizationPremium);
|
||||
Assert.False(result.HasPremiumAccess);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessAsync_NonExistentUser_ReturnsNull(
|
||||
IUserRepository userRepository)
|
||||
{
|
||||
// Act
|
||||
var result = await userRepository.GetPremiumAccessAsync(Guid.NewGuid());
|
||||
|
||||
// Assert
|
||||
Assert.Null(result);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessByIdsAsync_MultipleUsers_ReturnsCorrectAccessForEach(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository)
|
||||
{
|
||||
// Arrange
|
||||
var personalPremiumUser = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Personal Premium",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = true
|
||||
});
|
||||
|
||||
var orgPremiumUser = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Org Premium",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var bothPremiumUser = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Both Premium",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = true
|
||||
});
|
||||
|
||||
var noPremiumUser = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "No Premium",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var multiOrgUser = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Multi Org User",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = false
|
||||
});
|
||||
|
||||
var personalPremiumWithDisabledOrg = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Personal Premium With Disabled Org",
|
||||
Email = $"{Guid.NewGuid()}@example.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
Premium = true
|
||||
});
|
||||
|
||||
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, orgPremiumUser);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, bothPremiumUser);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, multiOrgUser);
|
||||
|
||||
var orgWithoutPremium = await organizationRepository.CreateTestOrganizationAsync();
|
||||
orgWithoutPremium.UsersGetPremium = false;
|
||||
await organizationRepository.ReplaceAsync(orgWithoutPremium);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(orgWithoutPremium, multiOrgUser);
|
||||
|
||||
var disabledOrg = await organizationRepository.CreateTestOrganizationAsync();
|
||||
disabledOrg.Enabled = false;
|
||||
await organizationRepository.ReplaceAsync(disabledOrg);
|
||||
await organizationUserRepository.CreateTestOrganizationUserAsync(disabledOrg, personalPremiumWithDisabledOrg);
|
||||
|
||||
// Act
|
||||
var results = await userRepository.GetPremiumAccessByIdsAsync([
|
||||
personalPremiumUser.Id,
|
||||
orgPremiumUser.Id,
|
||||
bothPremiumUser.Id,
|
||||
noPremiumUser.Id,
|
||||
multiOrgUser.Id,
|
||||
personalPremiumWithDisabledOrg.Id
|
||||
]);
|
||||
|
||||
var resultsList = results.ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(6, resultsList.Count);
|
||||
|
||||
var personalResult = resultsList.First(r => r.Id == personalPremiumUser.Id);
|
||||
Assert.True(personalResult.PersonalPremium);
|
||||
Assert.False(personalResult.OrganizationPremium);
|
||||
|
||||
var orgResult = resultsList.First(r => r.Id == orgPremiumUser.Id);
|
||||
Assert.False(orgResult.PersonalPremium);
|
||||
Assert.True(orgResult.OrganizationPremium);
|
||||
|
||||
var bothResult = resultsList.First(r => r.Id == bothPremiumUser.Id);
|
||||
Assert.True(bothResult.PersonalPremium);
|
||||
Assert.True(bothResult.OrganizationPremium);
|
||||
|
||||
var noneResult = resultsList.First(r => r.Id == noPremiumUser.Id);
|
||||
Assert.False(noneResult.PersonalPremium);
|
||||
Assert.False(noneResult.OrganizationPremium);
|
||||
|
||||
var multiResult = resultsList.First(r => r.Id == multiOrgUser.Id);
|
||||
Assert.False(multiResult.PersonalPremium);
|
||||
Assert.True(multiResult.OrganizationPremium);
|
||||
|
||||
var personalWithDisabledOrgResult = resultsList.First(r => r.Id == personalPremiumWithDisabledOrg.Id);
|
||||
Assert.True(personalWithDisabledOrgResult.PersonalPremium);
|
||||
Assert.False(personalWithDisabledOrgResult.OrganizationPremium);
|
||||
}
|
||||
|
||||
[Theory, DatabaseData]
|
||||
public async Task GetPremiumAccessByIdsAsync_EmptyList_ReturnsEmptyResult(
|
||||
IUserRepository userRepository)
|
||||
{
|
||||
// Act
|
||||
var results = await userRepository.GetPremiumAccessByIdsAsync([]);
|
||||
|
||||
// Assert
|
||||
Assert.Empty(results);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
CREATE OR ALTER PROCEDURE [dbo].[Collection_UpdateWithUsers]
|
||||
@Id UNIQUEIDENTIFIER,
|
||||
@OrganizationId UNIQUEIDENTIFIER,
|
||||
@Name VARCHAR(MAX),
|
||||
@ExternalId NVARCHAR(300),
|
||||
@CreationDate DATETIME2(7),
|
||||
@RevisionDate DATETIME2(7),
|
||||
@Users AS [dbo].[CollectionAccessSelectionType] READONLY,
|
||||
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
|
||||
@Type TINYINT = 0
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
|
||||
|
||||
-- Users
|
||||
-- Delete users that are no longer in source
|
||||
DELETE
|
||||
cu
|
||||
FROM
|
||||
[dbo].[CollectionUser] cu
|
||||
LEFT JOIN
|
||||
@Users u ON cu.OrganizationUserId = u.Id
|
||||
WHERE
|
||||
cu.CollectionId = @Id
|
||||
AND u.Id IS NULL;
|
||||
|
||||
-- Update existing users
|
||||
UPDATE
|
||||
cu
|
||||
SET
|
||||
cu.ReadOnly = u.ReadOnly,
|
||||
cu.HidePasswords = u.HidePasswords,
|
||||
cu.Manage = u.Manage
|
||||
FROM
|
||||
[dbo].[CollectionUser] cu
|
||||
INNER JOIN
|
||||
@Users u ON cu.OrganizationUserId = u.Id
|
||||
WHERE
|
||||
cu.CollectionId = @Id
|
||||
AND (
|
||||
cu.ReadOnly != u.ReadOnly
|
||||
OR cu.HidePasswords != u.HidePasswords
|
||||
OR cu.Manage != u.Manage
|
||||
);
|
||||
|
||||
-- Insert new users
|
||||
INSERT INTO [dbo].[CollectionUser]
|
||||
(
|
||||
[CollectionId],
|
||||
[OrganizationUserId],
|
||||
[ReadOnly],
|
||||
[HidePasswords],
|
||||
[Manage]
|
||||
)
|
||||
SELECT
|
||||
@Id,
|
||||
u.Id,
|
||||
u.ReadOnly,
|
||||
u.HidePasswords,
|
||||
u.Manage
|
||||
FROM
|
||||
@Users u
|
||||
INNER JOIN
|
||||
[dbo].[OrganizationUser] ou ON ou.Id = u.Id
|
||||
LEFT JOIN
|
||||
[dbo].[CollectionUser] cu ON cu.CollectionId = @Id AND cu.OrganizationUserId = u.Id
|
||||
WHERE
|
||||
ou.OrganizationId = @OrganizationId
|
||||
AND cu.CollectionId IS NULL;
|
||||
|
||||
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
|
||||
END
|
||||
GO
|
||||
|
||||
CREATE OR ALTER PROCEDURE [dbo].[Collection_UpdateWithGroups]
|
||||
@Id UNIQUEIDENTIFIER,
|
||||
@OrganizationId UNIQUEIDENTIFIER,
|
||||
@Name VARCHAR(MAX),
|
||||
@ExternalId NVARCHAR(300),
|
||||
@CreationDate DATETIME2(7),
|
||||
@RevisionDate DATETIME2(7),
|
||||
@Groups AS [dbo].[CollectionAccessSelectionType] READONLY,
|
||||
@DefaultUserCollectionEmail NVARCHAR(256) = NULL,
|
||||
@Type TINYINT = 0
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type
|
||||
|
||||
-- Groups
|
||||
-- Delete groups that are no longer in source
|
||||
DELETE
|
||||
cg
|
||||
FROM
|
||||
[dbo].[CollectionGroup] cg
|
||||
LEFT JOIN
|
||||
@Groups g ON cg.GroupId = g.Id
|
||||
WHERE
|
||||
cg.CollectionId = @Id
|
||||
AND g.Id IS NULL;
|
||||
|
||||
-- Update existing groups
|
||||
UPDATE
|
||||
cg
|
||||
SET
|
||||
cg.ReadOnly = g.ReadOnly,
|
||||
cg.HidePasswords = g.HidePasswords,
|
||||
cg.Manage = g.Manage
|
||||
FROM
|
||||
[dbo].[CollectionGroup] cg
|
||||
INNER JOIN
|
||||
@Groups g ON cg.GroupId = g.Id
|
||||
WHERE
|
||||
cg.CollectionId = @Id
|
||||
AND (
|
||||
cg.ReadOnly != g.ReadOnly
|
||||
OR cg.HidePasswords != g.HidePasswords
|
||||
OR cg.Manage != g.Manage
|
||||
);
|
||||
|
||||
-- Insert new groups
|
||||
INSERT INTO [dbo].[CollectionGroup]
|
||||
(
|
||||
[CollectionId],
|
||||
[GroupId],
|
||||
[ReadOnly],
|
||||
[HidePasswords],
|
||||
[Manage]
|
||||
)
|
||||
SELECT
|
||||
@Id,
|
||||
g.Id,
|
||||
g.ReadOnly,
|
||||
g.HidePasswords,
|
||||
g.Manage
|
||||
FROM
|
||||
@Groups g
|
||||
INNER JOIN
|
||||
[dbo].[Group] grp ON grp.Id = g.Id
|
||||
LEFT JOIN
|
||||
[dbo].[CollectionGroup] cg ON cg.CollectionId = @Id AND cg.GroupId = g.Id
|
||||
WHERE
|
||||
grp.OrganizationId = @OrganizationId
|
||||
AND cg.CollectionId IS NULL;
|
||||
|
||||
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId
|
||||
END
|
||||
GO
|
||||
@@ -0,0 +1,60 @@
|
||||
-- Add UsersGetPremium to IX_Organization_Enabled index to support premium access queries
|
||||
|
||||
IF EXISTS (
|
||||
SELECT * FROM sys.indexes
|
||||
WHERE name = 'IX_Organization_Enabled'
|
||||
AND object_id = OBJECT_ID('[dbo].[Organization]')
|
||||
)
|
||||
BEGIN
|
||||
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
|
||||
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
|
||||
INCLUDE ([UseTotp], [UsersGetPremium])
|
||||
WITH (DROP_EXISTING = ON);
|
||||
END
|
||||
ELSE
|
||||
BEGIN
|
||||
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
|
||||
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
|
||||
INCLUDE ([UseTotp], [UsersGetPremium]);
|
||||
END
|
||||
GO
|
||||
|
||||
CREATE OR ALTER VIEW [dbo].[UserPremiumAccessView]
|
||||
AS
|
||||
SELECT
|
||||
U.[Id],
|
||||
U.[Premium] AS [PersonalPremium],
|
||||
CAST(
|
||||
MAX(CASE
|
||||
WHEN O.[Id] IS NOT NULL THEN 1
|
||||
ELSE 0
|
||||
END) AS BIT
|
||||
) AS [OrganizationPremium]
|
||||
FROM
|
||||
[dbo].[User] U
|
||||
LEFT JOIN
|
||||
[dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id]
|
||||
LEFT JOIN
|
||||
[dbo].[Organization] O ON O.[Id] = OU.[OrganizationId]
|
||||
AND O.[UsersGetPremium] = 1
|
||||
AND O.[Enabled] = 1
|
||||
GROUP BY
|
||||
U.[Id], U.[Premium];
|
||||
GO
|
||||
|
||||
CREATE OR ALTER PROCEDURE [dbo].[User_ReadPremiumAccessByIds]
|
||||
@Ids [dbo].[GuidIdArray] READONLY
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
SELECT
|
||||
UPA.[Id],
|
||||
UPA.[PersonalPremium],
|
||||
UPA.[OrganizationPremium]
|
||||
FROM
|
||||
[dbo].[UserPremiumAccessView] UPA
|
||||
WHERE
|
||||
UPA.[Id] IN (SELECT [Id] FROM @Ids)
|
||||
END
|
||||
GO
|
||||
3443
util/MySqlMigrations/Migrations/20251212171212_OrganizationUsersGetPremiumIndex.Designer.cs
generated
Normal file
3443
util/MySqlMigrations/Migrations/20251212171212_OrganizationUsersGetPremiumIndex.Designer.cs
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,21 @@
|
||||
using Microsoft.EntityFrameworkCore.Migrations;
|
||||
|
||||
#nullable disable
|
||||
|
||||
namespace Bit.MySqlMigrations.Migrations;
|
||||
|
||||
/// <inheritdoc />
|
||||
public partial class OrganizationUsersGetPremiumIndex : Migration
|
||||
{
|
||||
/// <inheritdoc />
|
||||
protected override void Up(MigrationBuilder migrationBuilder)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void Down(MigrationBuilder migrationBuilder)
|
||||
{
|
||||
|
||||
}
|
||||
}
|
||||
@@ -274,7 +274,7 @@ namespace Bit.MySqlMigrations.Migrations
|
||||
b.HasKey("Id");
|
||||
|
||||
b.HasIndex("Id", "Enabled")
|
||||
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp" });
|
||||
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
|
||||
|
||||
b.ToTable("Organization", (string)null);
|
||||
});
|
||||
|
||||
3449
util/PostgresMigrations/Migrations/20251212171204_OrganizationUsersGetPremiumIndex.Designer.cs
generated
Normal file
3449
util/PostgresMigrations/Migrations/20251212171204_OrganizationUsersGetPremiumIndex.Designer.cs
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,37 @@
|
||||
using Microsoft.EntityFrameworkCore.Migrations;
|
||||
|
||||
#nullable disable
|
||||
|
||||
namespace Bit.PostgresMigrations.Migrations;
|
||||
|
||||
/// <inheritdoc />
|
||||
public partial class OrganizationUsersGetPremiumIndex : Migration
|
||||
{
|
||||
/// <inheritdoc />
|
||||
protected override void Up(MigrationBuilder migrationBuilder)
|
||||
{
|
||||
migrationBuilder.DropIndex(
|
||||
name: "IX_Organization_Id_Enabled",
|
||||
table: "Organization");
|
||||
|
||||
migrationBuilder.CreateIndex(
|
||||
name: "IX_Organization_Id_Enabled",
|
||||
table: "Organization",
|
||||
columns: new[] { "Id", "Enabled" })
|
||||
.Annotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void Down(MigrationBuilder migrationBuilder)
|
||||
{
|
||||
migrationBuilder.DropIndex(
|
||||
name: "IX_Organization_Id_Enabled",
|
||||
table: "Organization");
|
||||
|
||||
migrationBuilder.CreateIndex(
|
||||
name: "IX_Organization_Id_Enabled",
|
||||
table: "Organization",
|
||||
columns: new[] { "Id", "Enabled" })
|
||||
.Annotation("Npgsql:IndexInclude", new[] { "UseTotp" });
|
||||
}
|
||||
}
|
||||
@@ -277,7 +277,7 @@ namespace Bit.PostgresMigrations.Migrations
|
||||
|
||||
b.HasIndex("Id", "Enabled");
|
||||
|
||||
NpgsqlIndexBuilderExtensions.IncludeProperties(b.HasIndex("Id", "Enabled"), new[] { "UseTotp" });
|
||||
NpgsqlIndexBuilderExtensions.IncludeProperties(b.HasIndex("Id", "Enabled"), new[] { "UseTotp", "UsersGetPremium" });
|
||||
|
||||
b.ToTable("Organization", (string)null);
|
||||
});
|
||||
|
||||
3432
util/SqliteMigrations/Migrations/20251212171156_OrganizationUsersGetPremiumIndex.Designer.cs
generated
Normal file
3432
util/SqliteMigrations/Migrations/20251212171156_OrganizationUsersGetPremiumIndex.Designer.cs
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,21 @@
|
||||
using Microsoft.EntityFrameworkCore.Migrations;
|
||||
|
||||
#nullable disable
|
||||
|
||||
namespace Bit.SqliteMigrations.Migrations;
|
||||
|
||||
/// <inheritdoc />
|
||||
public partial class OrganizationUsersGetPremiumIndex : Migration
|
||||
{
|
||||
/// <inheritdoc />
|
||||
protected override void Up(MigrationBuilder migrationBuilder)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void Down(MigrationBuilder migrationBuilder)
|
||||
{
|
||||
|
||||
}
|
||||
}
|
||||
@@ -269,7 +269,7 @@ namespace Bit.SqliteMigrations.Migrations
|
||||
b.HasKey("Id");
|
||||
|
||||
b.HasIndex("Id", "Enabled")
|
||||
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp" });
|
||||
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
|
||||
|
||||
b.ToTable("Organization", (string)null);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user