1
0
mirror of https://github.com/bitwarden/server synced 2026-02-18 18:33:29 +00:00

Merge remote-tracking branch 'origin' into auth/pm-27084/register-accepts-new-data-types-repush

This commit is contained in:
Patrick Pimentel
2026-01-22 13:20:00 -05:00
67 changed files with 2383 additions and 1249 deletions

View File

@@ -11,6 +11,7 @@ using Bit.Core.Billing.Pricing;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Test.Billing.Mocks;
using Bit.Core.Test.Billing.Mocks.Plans;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json.Linq;
@@ -654,6 +655,8 @@ public class SubscriptionUpdatedHandlerTests
var plan = new Enterprise2023Plan(true);
_pricingClient.GetPlanOrThrow(organization.PlanType)
.Returns(plan);
_pricingClient.ListPlans()
.Returns(MockPlans.Plans);
var parsedEvent = new Event
{
@@ -693,6 +696,92 @@ public class SubscriptionUpdatedHandlerTests
await _stripeFacade.Received(1).DeleteCustomerDiscount(subscription.CustomerId);
await _stripeFacade.Received(1).DeleteSubscriptionDiscount(subscription.Id);
}
[Fact]
public async Task
HandleAsync_WhenUpgradingPlan_AndPreviousPlanHasSecretsManagerTrial_AndCurrentPlanHasSecretsManagerTrial_DoesNotRemovePasswordManagerCoupon()
{
// Arrange
var organizationId = Guid.NewGuid();
var subscription = new Subscription
{
Id = "sub_123",
Status = StripeSubscriptionStatus.Active,
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem
{
CurrentPeriodEnd = DateTime.UtcNow.AddDays(10),
Plan = new Plan { Id = "2023-enterprise-org-seat-annually" }
},
new SubscriptionItem
{
CurrentPeriodEnd = DateTime.UtcNow.AddDays(10),
Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" }
}
]
},
Customer = new Customer
{
Balance = 0,
Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } }
},
Discounts = [new Discount { Coupon = new Coupon { Id = "sm-standalone" } }],
Metadata = new Dictionary<string, string> { { "organizationId", organizationId.ToString() } }
};
// Note: The organization plan is still the previous plan because the subscription is updated before the organization is updated
var organization = new Organization { Id = organizationId, PlanType = PlanType.TeamsAnnually2023 };
var plan = new Teams2023Plan(true);
_pricingClient.GetPlanOrThrow(organization.PlanType)
.Returns(plan);
_pricingClient.ListPlans()
.Returns(MockPlans.Plans);
var parsedEvent = new Event
{
Data = new EventData
{
Object = subscription,
PreviousAttributes = JObject.FromObject(new
{
items = new
{
data = new[]
{
new { plan = new { id = "secrets-manager-teams-seat-annually" } },
}
},
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-teams-seat-annually" } },
]
}
})
}
};
_stripeEventService.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(organizationId, null, null));
_organizationRepository.GetByIdAsync(organizationId)
.Returns(organization);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _stripeFacade.DidNotReceive().DeleteCustomerDiscount(subscription.CustomerId);
await _stripeFacade.DidNotReceive().DeleteSubscriptionDiscount(subscription.Id);
}
[Theory]
[MemberData(nameof(GetNonActiveSubscriptions))]

View File

@@ -280,7 +280,7 @@ public class UpcomingInvoiceHandlerTests
email.ToEmails.Contains("user@example.com") &&
email.Subject == "Your Bitwarden Premium renewal is updating" &&
email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountedMonthlyRenewalPrice == (discountedPrice / 12).ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountedAnnualRenewalPrice == discountedPrice.ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountAmount == $"{coupon.PercentOff}%"
));
}
@@ -2436,7 +2436,7 @@ public class UpcomingInvoiceHandlerTests
email.Subject == "Your Bitwarden Premium renewal is updating" &&
email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountAmount == "30%" &&
email.View.DiscountedMonthlyRenewalPrice == (expectedDiscountedPrice / 12).ToString("C", new CultureInfo("en-US"))
email.View.DiscountedAnnualRenewalPrice == expectedDiscountedPrice.ToString("C", new CultureInfo("en-US"))
));
await _mailService.DidNotReceive().SendInvoiceUpcoming(

View File

@@ -10,7 +10,6 @@ using Bit.Core.AdminConsole.Utilities.v2;
using Bit.Core.AdminConsole.Utilities.v2.Validation;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
@@ -204,14 +203,10 @@ public class AutomaticallyConfirmUsersCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateAsync(
Arg.Is<Collection>(c =>
c.OrganizationId == organization.Id &&
c.Name == defaultCollectionName &&
c.Type == CollectionType.DefaultUserCollection),
Arg.Is<IEnumerable<CollectionAccessSelection>>(groups => groups == null),
Arg.Is<IEnumerable<CollectionAccessSelection>>(access =>
access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null));
.CreateDefaultCollectionsAsync(
organization.Id,
Arg.Is<IEnumerable<Guid>>(ids => ids.Single() == organizationUser.Id),
defaultCollectionName);
}
[Theory]
@@ -253,9 +248,7 @@ public class AutomaticallyConfirmUsersCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateAsync(Arg.Any<Collection>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory]
@@ -291,9 +284,7 @@ public class AutomaticallyConfirmUsersCommandTests
var collectionException = new Exception("Collection creation failed");
sutProvider.GetDependency<ICollectionRepository>()
.CreateAsync(Arg.Any<Collection>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>())
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>())
.ThrowsAsync(collectionException);
// Act

View File

@@ -13,7 +13,6 @@ using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
@@ -493,15 +492,10 @@ public class ConfirmOrganizationUserCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateAsync(
Arg.Is<Collection>(c =>
c.Name == collectionName &&
c.OrganizationId == organization.Id &&
c.Type == CollectionType.DefaultUserCollection),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Is<IEnumerable<CollectionAccessSelection>>(cu =>
cu.Single().Id == orgUser.Id &&
cu.Single().Manage));
.CreateDefaultCollectionsAsync(
organization.Id,
Arg.Is<IEnumerable<Guid>>(ids => ids.Single() == orgUser.Id),
collectionName);
}
[Theory, BitAutoData]
@@ -522,7 +516,7 @@ public class ConfirmOrganizationUserCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -539,24 +533,15 @@ public class ConfirmOrganizationUserCommandTests
sutProvider.GetDependency<IOrganizationUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser });
sutProvider.GetDependency<IUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { user });
var policyDetails = new PolicyDetails
{
OrganizationId = org.Id,
OrganizationUserId = orgUser.Id,
IsProvider = false,
OrganizationUserStatus = orgUser.Status,
OrganizationUserType = orgUser.Type,
PolicyType = PolicyType.OrganizationDataOwnership
};
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(orgUser.UserId!.Value)
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails]));
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, []));
await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]

View File

@@ -715,6 +715,39 @@ public class RestoreOrganizationUserCommandTests
Arg.Is<OrganizationUserStatusType>(x => x != OrganizationUserStatusType.Revoked));
}
[Theory, BitAutoData]
public async Task RestoreUser_InvitedUserInFreeOrganization_Success(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
organization.PlanType = PlanType.Free;
organizationUser.UserId = null;
organizationUser.Key = null;
organizationUser.Status = OrganizationUserStatusType.Revoked;
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts
{
Sponsored = 0,
Users = 1
});
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.RestoreAsync(organizationUser.Id, OrganizationUserStatusType.Invited);
await sutProvider.GetDependency<IEventService>()
.Received(1)
.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored);
await sutProvider.GetDependency<IPushNotificationService>()
.DidNotReceiveWithAnyArgs()
.PushSyncOrgKeysAsync(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task RestoreUsers_Success(Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,

View File

@@ -38,7 +38,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -60,7 +60,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -86,7 +86,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.DidNotReceive()
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
Arg.Any<Guid>(),
Arg.Any<IEnumerable<Guid>>(),
Arg.Any<string>());
@@ -172,10 +172,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
// Assert
// Assert - Should call with all user IDs (repository does internal filtering)
await collectionRepository
.Received(1)
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3),
_defaultUserCollectionName);
@@ -210,7 +210,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
private static IPolicyRepository ArrangePolicyRepository(IEnumerable<OrganizationPolicyDetails> policyDetails)
@@ -251,7 +251,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
[Theory, BitAutoData]
@@ -273,7 +273,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
[Theory, BitAutoData]
@@ -299,7 +299,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
default,
default,
default);
@@ -336,10 +336,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Act
await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
// Assert
// Assert - Should call with all user IDs (repository does internal filtering)
await collectionRepository
.Received(1)
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3),
_defaultUserCollectionName);
@@ -367,6 +367,6 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
}

View File

@@ -1,6 +1,7 @@
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Entities;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Constants.StripeConstants;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable;
@@ -15,6 +17,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands;
public class UpdatePremiumStorageCommandTests
{
private readonly IBraintreeService _braintreeService = Substitute.For<IBraintreeService>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
@@ -33,13 +36,14 @@ public class UpdatePremiumStorageCommandTests
_pricingClient.ListPremiumPlans().Returns([premiumPlan]);
_command = new UpdatePremiumStorageCommand(
_braintreeService,
_stripeAdapter,
_userService,
_pricingClient,
Substitute.For<ILogger<UpdatePremiumStorageCommand>>());
}
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null)
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null, bool isPayPal = false)
{
var items = new List<SubscriptionItem>
{
@@ -63,9 +67,17 @@ public class UpdatePremiumStorageCommandTests
});
}
var customer = new Customer
{
Id = "cus_123",
Metadata = isPayPal ? new Dictionary<string, string> { { MetadataKeys.BraintreeCustomerId, "braintree_123" } } : new Dictionary<string, string>()
};
return new Subscription
{
Id = subscriptionId,
CustomerId = "cus_123",
Customer = customer,
Items = new StripeList<SubscriptionItem>
{
Data = items
@@ -97,7 +109,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, -5);
@@ -117,7 +129,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 100);
@@ -154,7 +166,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 0);
@@ -176,7 +188,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 4);
@@ -185,7 +197,7 @@ public class UpdatePremiumStorageCommandTests
Assert.True(result.IsT0);
// Verify subscription was fetched but NOT updated
await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123");
await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>());
await _stripeAdapter.DidNotReceive().UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
await _userService.DidNotReceive().SaveUserAsync(Arg.Any<User>());
}
@@ -200,7 +212,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 9);
@@ -233,7 +245,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 9);
@@ -262,7 +274,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 2);
@@ -291,7 +303,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 0);
@@ -320,7 +332,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 99);
@@ -335,4 +347,200 @@ public class UpdatePremiumStorageCommandTests
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 100));
}
[Theory, BitAutoData]
public async Task Run_IncreaseStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 5;
user.Storage = 2L * 1024 * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 9);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated with CreateProrations
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
// Verify draft invoice was created
await _stripeAdapter.Received(1).CreateInvoiceAsync(
Arg.Is<InvoiceCreateOptions>(opts =>
opts.Customer == "cus_123" &&
opts.Subscription == "sub_123" &&
opts.AutoAdvance == false &&
opts.CollectionMethod == "charge_automatically"));
// Verify invoice was finalized
await _stripeAdapter.Received(1).FinalizeInvoiceAsync(
"in_draft",
Arg.Is<InvoiceFinalizeOptions>(opts =>
opts.AutoAdvance == false &&
opts.Expand.Contains("customer")));
// Verify Braintree payment was processed
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
// Verify user was saved
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u =>
u.Id == user.Id &&
u.MaxStorageGb == 10));
}
[Theory, BitAutoData]
public async Task Run_AddStorageFromZero_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 1;
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 9);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated with new storage item
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Price == "price_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 10));
}
[Theory, BitAutoData]
public async Task Run_DecreaseStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 10;
user.Storage = 2L * 1024 * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 2);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 2 &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 3));
}
[Theory, BitAutoData]
public async Task Run_RemoveAllAdditionalStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 10;
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 0);
// Assert
Assert.True(result.IsT0);
// Verify subscription item was deleted
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Deleted == true &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 1));
}
}

View File

@@ -0,0 +1,120 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Pricing.Premium;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Tools.Services;
[SutProviderCustomize]
public class SendValidationServiceTests
{
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_OrgGrantedPremiumUser_UsesPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
User user)
{
// Arrange
send.UserId = user.Id;
send.OrganizationId = null;
send.Type = SendType.File;
user.Premium = false;
user.Storage = 1024L * 1024L * 1024L; // 1 GB used
user.EmailVerified = true;
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = false;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(user.Id).Returns(user);
sutProvider.GetDependency<IUserService>().CanAccessPremium(user).Returns(true);
var premiumPlan = new Plan
{
Storage = new Purchasable { Provided = 5 }
};
sutProvider.GetDependency<IPricingClient>().GetAvailablePremiumPlan().Returns(premiumPlan);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert
await sutProvider.GetDependency<IPricingClient>().Received(1).GetAvailablePremiumPlan();
Assert.True(result > 0);
}
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_IndividualPremium_DoesNotCallPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
User user)
{
// Arrange
send.UserId = user.Id;
send.OrganizationId = null;
send.Type = SendType.File;
user.Premium = true;
user.MaxStorageGb = 10;
user.EmailVerified = true;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(user.Id).Returns(user);
sutProvider.GetDependency<IUserService>().CanAccessPremium(user).Returns(true);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert - should NOT call pricing service for individual premium users
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
}
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_SelfHosted_DoesNotCallPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
User user)
{
// Arrange
send.UserId = user.Id;
send.OrganizationId = null;
send.Type = SendType.File;
user.Premium = false;
user.EmailVerified = true;
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = true;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(user.Id).Returns(user);
sutProvider.GetDependency<IUserService>().CanAccessPremium(user).Returns(true);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert - should NOT call pricing service for self-hosted
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
}
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_OrgSend_DoesNotCallPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
Organization org)
{
// Arrange
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
org.MaxStorageGb = 100;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(org.Id).Returns(org);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert - should NOT call pricing service for org sends
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
}
}

View File

@@ -74,8 +74,7 @@ public class LoggerFactoryExtensionsTests
logger.LogWarning("This is a test");
// Writing to the file is buffered, give it a little time to flush
await Task.Delay(5);
await provider.DisposeAsync();
var logFile = Assert.Single(tempDir.EnumerateFiles("Logs/*.log"));
@@ -90,13 +89,67 @@ public class LoggerFactoryExtensionsTests
logFileContents
);
}
[Fact]
public async Task AddSerilogFileLogging_LegacyConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile()
{
await AssertSmallFileAsync((tempDir, config) =>
{
config["GlobalSettings:LogDirectory"] = tempDir;
config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning";
});
}
[Fact]
public async Task AddSerilogFileLogging_NewConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile()
{
await AssertSmallFileAsync((tempDir, config) =>
{
config["Logging:PathFormat"] = Path.Combine(tempDir, "log.txt");
config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning";
});
}
private static async Task AssertSmallFileAsync(Action<string, Dictionary<string, string?>> configure)
{
using var tempDir = new TempDirectory();
var config = new Dictionary<string, string?>();
configure(tempDir.Directory, config);
var provider = GetServiceProvider(config, "Production");
var loggerFactory = provider.GetRequiredService<ILoggerFactory>();
var microsoftLogger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Testing");
for (var i = 0; i < 100; i++)
{
microsoftLogger.LogInformation("Tons of useless information");
}
var otherLogger = loggerFactory.CreateLogger("Bitwarden");
for (var i = 0; i < 5; i++)
{
otherLogger.LogInformation("Mildly more useful information but not as frequent.");
}
await provider.DisposeAsync();
var logFiles = Directory.EnumerateFiles(tempDir.Directory, "*.txt", SearchOption.AllDirectories);
var logFile = Assert.Single(logFiles);
using var fr = File.OpenRead(logFile);
Assert.InRange(fr.Length, 0, 1024);
}
private static IEnumerable<ILoggerProvider> GetProviders(Dictionary<string, string?> initialData, string environment = "Production")
{
var provider = GetServiceProvider(initialData, environment);
return provider.GetServices<ILoggerProvider>();
}
private static IServiceProvider GetServiceProvider(Dictionary<string, string?> initialData, string environment)
private static ServiceProvider GetServiceProvider(Dictionary<string, string?> initialData, string environment)
{
var config = new ConfigurationBuilder()
.AddInMemoryCollection(initialData)

View File

@@ -6,6 +6,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Pricing.Premium;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@@ -2228,10 +2230,6 @@ public class CipherServiceTests
.PushSyncCiphersAsync(deletingUserId);
}
[Theory]
[OrganizationCipherCustomize]
[BitAutoData]
@@ -2387,6 +2385,186 @@ public class CipherServiceTests
ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id))));
}
[Theory, BitAutoData]
public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_UsesStorageFromPricingClient(
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Setup user WITHOUT personal premium (Premium = false), but with org-granted premium access
var user = new User
{
Id = savingUserId,
Premium = false, // User does not have personal premium
MaxStorageGb = null, // No personal storage allocation
Storage = 0 // No storage used yet
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
// User has premium access through their organization
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
// Mock GlobalSettings to indicate cloud (not self-hosted)
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = false;
// Mock the PricingClient to return a premium plan with 1 GB of storage
var premiumPlan = new Plan
{
Name = "Premium",
Available = true,
Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 },
Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 }
};
sutProvider.GetDependency<IPricingClient>()
.GetAvailablePremiumPlan()
.Returns(premiumPlan);
sutProvider.GetDependency<IAttachmentStorageService>()
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<IAttachmentStorageService>()
.ValidateFileAsync(cipher, Arg.Any<CipherAttachment.MetaData>(), Arg.Any<long>())
.Returns((true, 100L));
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<ICipherRepository>()
.ReplaceAsync(Arg.Any<CipherDetails>())
.Returns(Task.CompletedTask);
// Act
await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate);
// Assert - PricingClient was called to get the premium plan storage
await sutProvider.GetDependency<IPricingClient>().Received(1).GetAvailablePremiumPlan();
// Assert - Attachment was uploaded successfully
await sutProvider.GetDependency<IAttachmentStorageService>().Received(1)
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>());
}
[Theory, BitAutoData]
public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_ExceedsStorage_ThrowsBadRequest(
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Setup user WITHOUT personal premium, with org-granted access, but storage is full
var user = new User
{
Id = savingUserId,
Premium = false,
MaxStorageGb = null,
Storage = 1073741824 // 1 GB already used (equals the provided storage)
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = false;
// Premium plan provides 1 GB of storage
var premiumPlan = new Plan
{
Name = "Premium",
Available = true,
Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 },
Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 }
};
sutProvider.GetDependency<IPricingClient>()
.GetAvailablePremiumPlan()
.Returns(premiumPlan);
// Act & Assert - Should throw because storage is full
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate));
Assert.Contains("Not enough storage available", exception.Message);
}
[Theory, BitAutoData]
public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_SelfHosted_UsesConstantStorage(
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Setup user WITHOUT personal premium, but with org-granted premium access
var user = new User
{
Id = savingUserId,
Premium = false,
MaxStorageGb = null,
Storage = 0
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
// Mock GlobalSettings to indicate self-hosted
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = true;
sutProvider.GetDependency<IAttachmentStorageService>()
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<IAttachmentStorageService>()
.ValidateFileAsync(cipher, Arg.Any<CipherAttachment.MetaData>(), Arg.Any<long>())
.Returns((true, 100L));
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<ICipherRepository>()
.ReplaceAsync(Arg.Any<CipherDetails>())
.Returns(Task.CompletedTask);
// Act
await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate);
// Assert - PricingClient should NOT be called for self-hosted
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
// Assert - Attachment was uploaded successfully
await sutProvider.GetDependency<IAttachmentStorageService>().Received(1)
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>());
}
private async Task AssertNoActionsAsync(SutProvider<CipherService> sutProvider)
{
await sutProvider.GetDependency<ICipherRepository>().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default);

View File

@@ -18,6 +18,7 @@ using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Identity.IdentityServer;
using Bit.Identity.IdentityServer.RequestValidationConstants;
using Bit.Identity.IdentityServer.RequestValidators;
using Bit.Identity.Test.Wrappers;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -130,7 +131,7 @@ public class BaseRequestValidatorTests
var logs = _logger.Collector.GetSnapshot(true);
Assert.Contains(logs,
l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: ");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message);
}
@@ -161,7 +162,11 @@ public class BaseRequestValidatorTests
.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(false));
// 5 -> not legacy user
// 5 -> SSO not required
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 6 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
@@ -203,6 +208,11 @@ public class BaseRequestValidatorTests
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 6 -> SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 7 -> setup user account keys
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -262,6 +272,11 @@ public class BaseRequestValidatorTests
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 6 -> SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 7 -> setup user account keys
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -326,6 +341,9 @@ public class BaseRequestValidatorTests
{ "TwoFactorProviders2", new Dictionary<string, object> { { "Email", null } } }
}));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -368,6 +386,10 @@ public class BaseRequestValidatorTests
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Email, "invalid_token")
.Returns(Task.FromResult(false));
// 5 -> set up SSO required verification to succeed
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -396,21 +418,25 @@ public class BaseRequestValidatorTests
// 1 -> initial validation passes
_sut.isValid = true;
// 2 -> set up 2FA as required
// 2 -> set up SSO required verification to succeed
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 3 -> set up 2FA as required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 3 -> provide invalid remember token (remember token expired)
// 4 -> provide invalid remember token (remember token expired)
tokenRequest.Raw["TwoFactorToken"] = "expired_remember_token";
tokenRequest.Raw["TwoFactorProvider"] = "5"; // Remember provider
// 4 -> set up remember token verification to fail
// 5 -> set up remember token verification to fail
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Remember, "expired_remember_token")
.Returns(Task.FromResult(false));
// 5 -> set up dummy BuildTwoFactorResultAsync
// 6 -> set up dummy BuildTwoFactorResultAsync
var twoFactorResultDict = new Dictionary<string, object>
{
{ "TwoFactorProviders", new[] { "0", "1" } },
@@ -446,6 +472,19 @@ public class BaseRequestValidatorTests
GrantValidationResult grantResult)
{
// Arrange
// SsoRequestValidator sets custom response
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -454,13 +493,17 @@ public class BaseRequestValidatorTests
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(false));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
Assert.NotNull(context.GrantResult.CustomResponse);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message);
}
// Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled
@@ -477,6 +520,20 @@ public class BaseRequestValidatorTests
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
// SsoRequestValidator sets custom response with organization identifier
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
{ CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" }
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -485,6 +542,10 @@ public class BaseRequestValidatorTests
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
_policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(Arg.Any<Guid>()).Returns(requirement);
// Mock the SSO validator to return false
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(false));
// Act
await _sut.ValidateAsync(context);
@@ -492,8 +553,9 @@ public class BaseRequestValidatorTests
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
Assert.NotNull(context.GrantResult.CustomResponse);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message);
}
[Theory]
@@ -519,6 +581,10 @@ public class BaseRequestValidatorTests
var requirement = new RequireSsoPolicyRequirement { SsoRequired = false };
_policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(Arg.Any<Guid>()).Returns(requirement);
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -561,6 +627,11 @@ public class BaseRequestValidatorTests
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -603,6 +674,10 @@ public class BaseRequestValidatorTests
context.ValidatedTokenRequest.GrantType = grantType;
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -652,13 +727,15 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
var expectedMessage =
"Legacy encryption without a userkey is no longer supported. To recover your account, please contact support";
Assert.Equal(expectedMessage, errorResponse.Message);
@@ -694,6 +771,10 @@ public class BaseRequestValidatorTests
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
// SSO validation passes
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -760,6 +841,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -833,6 +916,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -877,6 +962,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -921,6 +1008,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -950,6 +1039,19 @@ public class BaseRequestValidatorTests
GrantValidationResult grantResult)
{
// Arrange
// SsoRequestValidator sets custom response
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
@@ -984,12 +1086,12 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - SSO required after recovery");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.NotNull(context.GrantResult.CustomResponse);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
// Recovery succeeds, then SSO blocks with descriptive message
Assert.Equal(
"Two-factor recovery has been performed. SSO authentication is required.",
SsoConstants.RequestErrors.SsoRequiredDescription,
errorResponse.Message);
// Verify recovery was marked
@@ -1050,7 +1152,7 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - invalid recovery code");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
// 2FA is checked first (due to recovery code request), fails with 2FA error
Assert.Equal(
@@ -1132,7 +1234,11 @@ public class BaseRequestValidatorTests
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 8. Setup user account keys for successful login response
// 8. SSO is not required
_ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 9. Setup user account keys for successful login response
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -1161,179 +1267,18 @@ public class BaseRequestValidatorTests
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is DISABLED, the legacy SSO validation path is used.
/// This validates the deprecated RequireSsoLoginAsync method is called and SSO requirement
/// is checked using the old PolicyService.AnyPoliciesApplicableToUserAsync approach.
/// Tests that when SSO validation returns a custom response, (e.g., with organization identifier),
/// that custom response is properly propagated to the result.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_UsesLegacySsoValidation(
public async Task ValidateAsync_SsoRequired_PropagatesCustomResponse(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
// SSO is required via legacy path
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
// Verify legacy path was used
await _policyService.Received(1).AnyPoliciesApplicableToUserAsync(
requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
// Verify new SsoRequestValidator was NOT called
await _ssoRequestValidator.DidNotReceive().ValidateAsync(
Arg.Any<User>(), Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>());
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED, the new ISsoRequestValidator is used
/// instead of the legacy RequireSsoLoginAsync method.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_UsesNewSsoRequestValidator(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
// Configure SsoRequestValidator to indicate SSO is required
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false)); // false = SSO required
// Set up the ValidationErrorResult that SsoRequestValidator would set
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "SSO authentication is required."
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("SSO authentication is required.") }
};
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
// Verify new SsoRequestValidator was called
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
requestContext);
// Verify legacy path was NOT used
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and SSO is NOT required,
/// authentication continues successfully through the new validation path.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_SsoNotRequired_SuccessfulLogin(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
tokenRequest.ClientId = "web";
// SsoRequestValidator returns true (SSO not required)
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
// No 2FA required
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// Device validation passes
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// User is not legacy
_userService.IsLegacyUser(Arg.Any<string>()).Returns(false);
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
"test-private-key",
"test-public-key"
)
});
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.False(context.GrantResult.IsError);
await _eventService.Received(1).LogUserEventAsync(requestContext.User.Id, EventType.User_LoggedIn);
// Verify new validator was used
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
requestContext);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and SSO validation returns a custom response
/// (e.g., with organization identifier), that custom response is properly propagated to the result.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_PropagatesCustomResponse(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
@@ -1342,13 +1287,13 @@ public class BaseRequestValidatorTests
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "SSO authentication is required."
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("SSO authentication is required.") },
{ "SsoOrganizationIdentifier", "test-org-identifier" }
{ CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) },
{ CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" }
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
@@ -1365,77 +1310,24 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError);
Assert.NotNull(context.GrantResult.CustomResponse);
Assert.Contains("SsoOrganizationIdentifier", context.CustomValidatorRequestContext.CustomResponse);
Assert.Contains(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, context.CustomValidatorRequestContext.CustomResponse);
Assert.Equal("test-org-identifier",
context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]);
context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier]);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is DISABLED and a user with 2FA recovery completes recovery,
/// but SSO is required, the legacy error message is returned (without the recovery-specific message).
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_RecoveryWithSso_LegacyMessage(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
// Recovery code scenario
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code";
// 2FA with recovery
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code")
.Returns(Task.FromResult(true));
// SSO is required (legacy check)
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
// Legacy behavior: recovery-specific message IS shown even without RedirectOnSsoRequired
Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message);
// But legacy validation path was used
await _policyService.Received(1).AnyPoliciesApplicableToUserAsync(
requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and recovery code is used for SSO-required user,
/// Tests that when a recovery code is used for SSO-required user,
/// the SsoRequestValidator provides the recovery-specific error message.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_RecoveryWithSso_NewValidatorMessage(
public async Task ValidateAsync_RecoveryWithSso_CorrectValidatorMessage(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext]
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -1457,14 +1349,14 @@ public class BaseRequestValidatorTests
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "Two-factor recovery has been performed. SSO authentication is required."
Error = SsoConstants.RequestErrors.SsoRequired,
ErrorDescription = SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{
"ErrorModel",
new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.")
CustomResponseConstants.ResponseKeys.ErrorModel,
new ErrorResponseModel(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription)
}
};
@@ -1479,18 +1371,8 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse["ErrorModel"];
Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message);
// Verify new validator was used
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
Arg.Is<CustomValidatorRequestContext>(ctx => ctx.TwoFactorRecoveryRequested));
// Verify legacy path was NOT used
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel];
Assert.Equal(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription, errorResponse.Message);
}
private BaseRequestValidationContextFake CreateContext(

View File

@@ -111,15 +111,6 @@ IBaseRequestValidatorTestWrapper
context.GrantResult = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse);
}
[Obsolete]
protected override void SetSsoResult(
BaseRequestValidationContextFake context,
Dictionary<string, object> customResponse)
{
context.GrantResult = new GrantValidationResult(
TokenRequestErrors.InvalidGrant, "Sso authentication required.", customResponse);
}
protected override Task SetSuccessResult(
BaseRequestValidationContextFake context,
User user,

View File

@@ -0,0 +1,53 @@
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository;
public class CreateDefaultCollectionsBulkAsyncTests
{
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsBulkAsync_CreatesDefaultCollections_Success(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository)
{
await CreateDefaultCollectionsSharedTests.CreatesDefaultCollections_Success(
collectionRepository.CreateDefaultCollectionsBulkAsync,
organizationRepository,
userRepository,
organizationUserRepository,
collectionRepository);
}
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsBulkAsync_CreatesForNewUsersOnly_AndIgnoresExistingUsers(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository)
{
await CreateDefaultCollectionsSharedTests.CreatesForNewUsersOnly_AndIgnoresExistingUsers(
collectionRepository.CreateDefaultCollectionsBulkAsync,
organizationRepository,
userRepository,
organizationUserRepository,
collectionRepository);
}
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsBulkAsync_IgnoresAllExistingUsers(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository)
{
await CreateDefaultCollectionsSharedTests.IgnoresAllExistingUsers(
collectionRepository.CreateDefaultCollectionsBulkAsync,
organizationRepository,
userRepository,
organizationUserRepository,
collectionRepository);
}
}

View File

@@ -6,10 +6,14 @@ using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository;
public class UpsertDefaultCollectionsTests
/// <summary>
/// Shared tests for CreateDefaultCollections methods - both bulk and non-bulk implementations,
/// as they share the same behavior. Both test suites call the tests in this class.
/// </summary>
public static class CreateDefaultCollectionsSharedTests
{
[Theory, DatabaseData]
public async Task UpsertDefaultCollectionsAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection(
public static async Task CreatesDefaultCollections_Success(
Func<Guid, IEnumerable<Guid>, string, Task> createDefaultCollectionsFunc,
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
@@ -21,14 +25,13 @@ public class UpsertDefaultCollectionsTests
var resultOrganizationUsers = await Task.WhenAll(
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization),
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
);
);
var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id);
var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList();
var defaultCollectionName = $"default-name-{organization.Id}";
// Act
await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName);
// Assert
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
@@ -36,8 +39,8 @@ public class UpsertDefaultCollectionsTests
await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers);
}
[Theory, DatabaseData]
public async Task UpsertDefaultCollectionsAsync_ShouldUpsertCreateDefaultCollection_ForUsersWithAndWithoutDefaultCollectionsExist(
public static async Task CreatesForNewUsersOnly_AndIgnoresExistingUsers(
Func<Guid, IEnumerable<Guid>, string, Task> createDefaultCollectionsFunc,
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
@@ -51,31 +54,30 @@ public class UpsertDefaultCollectionsTests
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
);
var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id);
var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList();
var defaultCollectionName = $"default-name-{organization.Id}";
await CreateUsersWithExistingDefaultCollectionsAsync(createDefaultCollectionsFunc, collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers);
await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers);
var newOrganizationUsers = new List<OrganizationUser>()
var newOrganizationUsers = new List<OrganizationUser>
{
await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
};
var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers);
var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id);
var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers).ToList();
var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id).ToList();
// Act
await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName);
// Assert
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id);
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, affectedOrgUsers, organization.Id);
await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers);
}
[Theory, DatabaseData]
public async Task UpsertDefaultCollectionsAsync_ShouldNotCreateDefaultCollection_WhenUsersAlreadyHaveOne(
public static async Task IgnoresAllExistingUsers(
Func<Guid, IEnumerable<Guid>, string, Task> createDefaultCollectionsFunc,
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
@@ -89,26 +91,29 @@ public class UpsertDefaultCollectionsTests
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
);
var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id);
var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList();
var defaultCollectionName = $"default-name-{organization.Id}";
await CreateUsersWithExistingDefaultCollectionsAsync(createDefaultCollectionsFunc, collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers);
await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers);
// Act - Try to create again, should silently filter and not create duplicates
await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName);
// Act
await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
// Assert
// Assert - Original collections should remain unchanged, still only one per user
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers);
}
private static async Task CreateUsersWithExistingDefaultCollectionsAsync(ICollectionRepository collectionRepository,
Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName,
private static async Task CreateUsersWithExistingDefaultCollectionsAsync(
Func<Guid, IEnumerable<Guid>, string, Task> createDefaultCollectionsFunc,
ICollectionRepository collectionRepository,
Guid organizationId,
IEnumerable<Guid> affectedOrgUserIds,
string defaultCollectionName,
OrganizationUser[] resultOrganizationUsers)
{
await collectionRepository.UpsertDefaultCollectionsAsync(organizationId, affectedOrgUserIds, defaultCollectionName);
await createDefaultCollectionsFunc(organizationId, affectedOrgUserIds, defaultCollectionName);
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organizationId);
}
@@ -131,7 +136,6 @@ public class UpsertDefaultCollectionsTests
private static async Task<OrganizationUser> CreateUserForOrgAsync(IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository, Organization organization)
{
var user = await userRepository.CreateTestUserAsync();
var orgUser = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);

View File

@@ -0,0 +1,52 @@
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository;
public class CreateDefaultCollectionsAsyncTests
{
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsAsync_CreatesDefaultCollections_Success(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository)
{
await CreateDefaultCollectionsSharedTests.CreatesDefaultCollections_Success(
collectionRepository.CreateDefaultCollectionsAsync,
organizationRepository,
userRepository,
organizationUserRepository,
collectionRepository);
}
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsAsync_CreatesForNewUsersOnly_AndIgnoresExistingUsers(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository)
{
await CreateDefaultCollectionsSharedTests.CreatesForNewUsersOnly_AndIgnoresExistingUsers(
collectionRepository.CreateDefaultCollectionsAsync,
organizationRepository,
userRepository,
organizationUserRepository,
collectionRepository);
}
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsAsync_IgnoresAllExistingUsers(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository)
{
await CreateDefaultCollectionsSharedTests.IgnoresAllExistingUsers(
collectionRepository.CreateDefaultCollectionsAsync,
organizationRepository,
userRepository,
organizationUserRepository,
collectionRepository);
}
}

View File

@@ -0,0 +1,335 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.OrganizationRepository;
public class GetByVerifiedUserEmailDomainAsyncTests
{
[Theory, DatabaseData]
public async Task GetByClaimedUserDomainAsync_WithVerifiedDomain_Success(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User 1",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@x-{domainName}", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user3 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@{domainName}.example.com", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user1);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user2);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user3);
var user1Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user1.Id);
var user2Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user2.Id);
var user3Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user3.Id);
Assert.NotEmpty(user1Response);
Assert.Equal(organization.Id, user1Response.First().Id);
Assert.Empty(user2Response);
Assert.Empty(user3Response);
}
[Theory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithUnverifiedDomains_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user);
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id);
Assert.Empty(result);
}
[Theory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithMultipleVerifiedDomains_ReturnsAllMatchingOrganizations(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization1 = await organizationRepository.CreateTestOrganizationAsync();
var organization2 = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain1 = new OrganizationDomain
{
OrganizationId = organization1.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain1.SetNextRunDate(12);
organizationDomain1.SetJobRunCount();
organizationDomain1.SetVerifiedDate();
await organizationDomainRepository.CreateAsync(organizationDomain1);
var organizationDomain2 = new OrganizationDomain
{
OrganizationId = organization2.Id,
DomainName = domainName,
Txt = "btw+67890",
};
organizationDomain2.SetNextRunDate(12);
organizationDomain2.SetJobRunCount();
organizationDomain2.SetVerifiedDate();
await organizationDomainRepository.CreateAsync(organizationDomain2);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization1, user);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization2, user);
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id);
Assert.Equal(2, result.Count);
Assert.Contains(result, org => org.Id == organization1.Id);
Assert.Contains(result, org => org.Id == organization2.Id);
}
[Theory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithNonExistentUser_ReturnsEmpty(
IOrganizationRepository organizationRepository)
{
var nonExistentUserId = Guid.NewGuid();
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(nonExistentUserId);
Assert.Empty(result);
}
/// <summary>
/// Tests an edge case where some invited users are created linked to a UserId.
/// This is defective behavior, but will take longer to fix - for now, we are defensive and expressly
/// exclude such users from the results without relying on the inner join only.
/// Invited-revoked users linked to a UserId remain intentionally unhandled for now as they have not caused
/// any issues to date and we want to minimize edge cases.
/// We will fix the underlying issue going forward: https://bitwarden.atlassian.net/browse/PM-22405
/// </summary>
[Theory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithInvitedUserWithUserId_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
// Create invited user with matching email domain but UserId set (edge case)
await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Email = user.Email,
Status = OrganizationUserStatusType.Invited,
});
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id);
// Invited users should be excluded even if they have UserId set
Assert.Empty(result);
}
[Theory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithAcceptedUser_ReturnsOrganization(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateAcceptedTestOrganizationUserAsync(organization, user);
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id);
Assert.NotEmpty(result);
Assert.Equal(organization.Id, result.First().Id);
}
[Theory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithRevokedUser_ReturnsOrganization(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateRevokedTestOrganizationUserAsync(organization, user);
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id);
Assert.NotEmpty(result);
Assert.Equal(organization.Id, result.First().Id);
}
}

View File

@@ -8,254 +8,7 @@ namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories;
public class OrganizationRepositoryTests
{
[DatabaseTheory, DatabaseData]
public async Task GetByClaimedUserDomainAsync_WithVerifiedDomain_Success(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User 1",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@x-{domainName}", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user3 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@{domainName}.example.com", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
});
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user1.Id,
Status = OrganizationUserStatusType.Confirmed,
ResetPasswordKey = "resetpasswordkey1",
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user2.Id,
Status = OrganizationUserStatusType.Confirmed,
ResetPasswordKey = "resetpasswordkey1",
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user3.Id,
Status = OrganizationUserStatusType.Confirmed,
ResetPasswordKey = "resetpasswordkey1",
});
var user1Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user1.Id);
var user2Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user2.Id);
var user3Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user3.Id);
Assert.NotEmpty(user1Response);
Assert.Equal(organization.Id, user1Response.First().Id);
Assert.Empty(user2Response);
Assert.Empty(user3Response);
}
[DatabaseTheory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithUnverifiedDomains_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = user.Email,
Plan = "Test",
PrivateKey = "privatekey",
});
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
ResetPasswordKey = "resetpasswordkey",
});
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id);
Assert.Empty(result);
}
[DatabaseTheory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithMultipleVerifiedDomains_ReturnsAllMatchingOrganizations(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization1 = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org 1 {id}",
BillingEmail = user.Email,
Plan = "Test",
PrivateKey = "privatekey1",
});
var organization2 = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org 2 {id}",
BillingEmail = user.Email,
Plan = "Test",
PrivateKey = "privatekey2",
});
var organizationDomain1 = new OrganizationDomain
{
OrganizationId = organization1.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain1.SetNextRunDate(12);
organizationDomain1.SetJobRunCount();
organizationDomain1.SetVerifiedDate();
await organizationDomainRepository.CreateAsync(organizationDomain1);
var organizationDomain2 = new OrganizationDomain
{
OrganizationId = organization2.Id,
DomainName = domainName,
Txt = "btw+67890",
};
organizationDomain2.SetNextRunDate(12);
organizationDomain2.SetJobRunCount();
organizationDomain2.SetVerifiedDate();
await organizationDomainRepository.CreateAsync(organizationDomain2);
await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization1.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
ResetPasswordKey = "resetpasswordkey1",
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization2.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
ResetPasswordKey = "resetpasswordkey2",
});
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id);
Assert.Equal(2, result.Count);
Assert.Contains(result, org => org.Id == organization1.Id);
Assert.Contains(result, org => org.Id == organization2.Id);
}
[DatabaseTheory, DatabaseData]
public async Task GetByVerifiedUserEmailDomainAsync_WithNonExistentUser_ReturnsEmpty(
IOrganizationRepository organizationRepository)
{
var nonExistentUserId = Guid.NewGuid();
var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(nonExistentUserId);
Assert.Empty(result);
}
[DatabaseTheory, DatabaseData]
[Theory, DatabaseData]
public async Task GetManyByIdsAsync_ExistingOrganizations_ReturnsOrganizations(IOrganizationRepository organizationRepository)
{
var email = "test@email.com";
@@ -287,7 +40,7 @@ public class OrganizationRepositoryTests
await organizationRepository.DeleteAsync(organization2);
}
[DatabaseTheory, DatabaseData]
[Theory, DatabaseData]
public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithUsersAndSponsorships_ReturnsCorrectCounts(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
@@ -356,7 +109,7 @@ public class OrganizationRepositoryTests
Assert.Equal(4, result.Total); // Total occupied seats
}
[DatabaseTheory, DatabaseData]
[Theory, DatabaseData]
public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithNoUsersOrSponsorships_ReturnsZero(
IOrganizationRepository organizationRepository)
{
@@ -372,7 +125,7 @@ public class OrganizationRepositoryTests
Assert.Equal(0, result.Total);
}
[DatabaseTheory, DatabaseData]
[Theory, DatabaseData]
public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithOnlyRevokedUsers_ReturnsZero(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
@@ -399,7 +152,7 @@ public class OrganizationRepositoryTests
Assert.Equal(0, result.Total);
}
[DatabaseTheory, DatabaseData]
[Theory, DatabaseData]
public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithOnlyExpiredSponsorships_ReturnsZero(
IOrganizationRepository organizationRepository,
IOrganizationSponsorshipRepository organizationSponsorshipRepository)
@@ -424,7 +177,7 @@ public class OrganizationRepositoryTests
Assert.Equal(0, result.Total);
}
[DatabaseTheory, DatabaseData]
[Theory, DatabaseData]
public async Task IncrementSeatCountAsync_IncrementsSeatCount(IOrganizationRepository organizationRepository)
{
var organization = await organizationRepository.CreateTestOrganizationAsync();
@@ -438,7 +191,7 @@ public class OrganizationRepositoryTests
Assert.Equal(8, result.Seats);
}
[DatabaseData, DatabaseTheory]
[DatabaseData, Theory]
public async Task IncrementSeatCountAsync_GivenOrganizationHasNotChangedSeatCountBefore_WhenUpdatingOrgSeats_ThenSubscriptionUpdateIsSaved(
IOrganizationRepository sutRepository)
{
@@ -462,7 +215,7 @@ public class OrganizationRepositoryTests
await sutRepository.DeleteAsync(organization);
}
[DatabaseData, DatabaseTheory]
[DatabaseData, Theory]
public async Task IncrementSeatCountAsync_GivenOrganizationHasChangedSeatCountBeforeAndRecordExists_WhenUpdatingOrgSeats_ThenSubscriptionUpdateIsSaved(
IOrganizationRepository sutRepository)
{
@@ -487,7 +240,7 @@ public class OrganizationRepositoryTests
await sutRepository.DeleteAsync(organization);
}
[DatabaseData, DatabaseTheory]
[DatabaseData, Theory]
public async Task GetOrganizationsForSubscriptionSyncAsync_GivenOrganizationHasChangedSeatCount_WhenGettingOrgsToUpdate_ThenReturnsOrgSubscriptionUpdate(
IOrganizationRepository sutRepository)
{
@@ -510,7 +263,7 @@ public class OrganizationRepositoryTests
await sutRepository.DeleteAsync(organization);
}
[DatabaseData, DatabaseTheory]
[DatabaseData, Theory]
public async Task UpdateSuccessfulOrganizationSyncStatusAsync_GivenOrganizationHasChangedSeatCount_WhenUpdatingStatus_ThenSuccessfullyUpdatesOrgSoItDoesntSync(
IOrganizationRepository sutRepository)
{

View File

@@ -0,0 +1,197 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.OrganizationUserRepository;
public class GetManyByOrganizationWithClaimedDomainsAsyncTests
{
[Theory, DatabaseData]
public async Task WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User 1",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@x-{domainName}", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user3 = await userRepository.CreateAsync(new User
{
Name = "Test User 3",
Email = $"test+{id}@{domainName}.example.com", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
var orgUser1 = await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user1);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user2);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user3);
var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id);
Assert.NotNull(result);
Assert.Single(result);
Assert.Equal(orgUser1.Id, result.Single().Id);
}
[Theory, DatabaseData]
public async Task WithNoVerifiedDomain_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
// Create domain but do NOT verify it
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetNextRunDate(12);
// Note: NOT calling SetVerifiedDate()
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user);
var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id);
Assert.NotNull(result);
Assert.Empty(result);
}
/// <summary>
/// Tests an edge case where some invited users are created linked to a UserId.
/// This is defective behavior, but will take longer to fix - for now, we are defensive and expressly
/// exclude such users from the results without relying on the inner join only.
/// Invited-revoked users linked to a UserId remain intentionally unhandled for now as they have not caused
/// any issues to date and we want to minimize edge cases.
/// We will fix the underlying issue going forward: https://bitwarden.atlassian.net/browse/PM-22405
/// </summary>
[Theory, DatabaseData]
public async Task WithVerifiedDomain_ExcludesInvitedUsers(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var invitedUser = await userRepository.CreateAsync(new User
{
Name = "Invited User",
Email = $"invited+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var confirmedUser = await userRepository.CreateAsync(new User
{
Name = "Confirmed User",
Email = $"confirmed+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
// Create invited user with UserId set (edge case - should be excluded even with UserId linked)
var invitedOrgUser = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = invitedUser.Id, // Edge case: invited user with UserId set
Email = invitedUser.Email,
Status = OrganizationUserStatusType.Invited,
Type = OrganizationUserType.User
});
// Create confirmed user linked by UserId only (no Email field set)
var confirmedOrgUser = await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, confirmedUser);
var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id);
Assert.NotNull(result);
var claimedUser = Assert.Single(result);
Assert.Equal(confirmedOrgUser.Id, claimedUser.Id);
}
}

View File

@@ -599,136 +599,6 @@ public class OrganizationUserRepositoryTests
Assert.Null(orgWithoutSsoDetails.SsoConfig);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User 1",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@x-{domainName}", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user3 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@{domainName}.example.com", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
UsePolicies = false,
UseSso = false,
UseKeyConnector = false,
UseScim = false,
UseGroups = false,
UseDirectory = false,
UseEvents = false,
UseTotp = false,
Use2fa = false,
UseApi = false,
UseResetPassword = false,
UseSecretsManager = false,
SelfHost = false,
UsersGetPremium = false,
UseCustomPermissions = false,
Enabled = true,
UsePasswordManager = false,
LimitCollectionCreation = false,
LimitCollectionDeletion = false,
LimitItemDeletion = false,
AllowAdminAccessToAllCollectionItems = false,
UseRiskInsights = false,
UseAdminSponsoredFamilies = false,
UsePhishingBlocker = false,
UseDisableSmAdsForUsers = false,
});
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user1.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
ResetPasswordKey = "resetpasswordkey1",
AccessSecretsManager = false
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user2.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
ResetPasswordKey = "resetpasswordkey1",
AccessSecretsManager = false
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user3.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
ResetPasswordKey = "resetpasswordkey1",
AccessSecretsManager = false
});
var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id);
Assert.NotNull(responseModel);
Assert.Single(responseModel);
Assert.Equal(orgUser1.Id, responseModel.Single().Id);
}
[DatabaseTheory, DatabaseData]
public async Task CreateManyAsync_NoId_Works(IOrganizationRepository organizationRepository,
IUserRepository userRepository,
@@ -1237,70 +1107,6 @@ public class OrganizationUserRepositoryTests
Assert.DoesNotContain(user1Result.Collections, c => c.Id == defaultUserCollection.Id);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithNoVerifiedDomain_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var requestTime = DateTime.UtcNow;
var user1 = await userRepository.CreateAsync(new User
{
Id = CoreHelpers.GenerateComb(),
Name = "Test User 1",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
CreationDate = requestTime,
RevisionDate = requestTime,
AccountRevisionDate = requestTime
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Id = CoreHelpers.GenerateComb(),
Name = $"Test Org {id}",
BillingEmail = user1.Email,
Plan = "Test",
Enabled = true,
CreationDate = requestTime,
RevisionDate = requestTime
});
// Create domain but do NOT verify it
var organizationDomain = new OrganizationDomain
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
CreationDate = requestTime
};
organizationDomain.SetNextRunDate(12);
// Note: NOT calling SetVerifiedDate()
await organizationDomainRepository.CreateAsync(organizationDomain);
await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user1.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
CreationDate = requestTime,
RevisionDate = requestTime
});
var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id);
Assert.NotNull(responseModel);
Assert.Empty(responseModel);
}
[DatabaseTheory, DatabaseData]
public async Task DeleteAsync_WithNullEmail_DoesNotSetDefaultUserCollectionEmail(IUserRepository userRepository,
ICollectionRepository collectionRepository,