1
0
mirror of https://github.com/bitwarden/server synced 2026-01-14 22:43:19 +00:00

Merge branch 'main' into tools/pm-21918/send-authentication-commands

This commit is contained in:
✨ Audrey ✨
2025-08-25 14:55:45 -04:00
176 changed files with 14667 additions and 1455 deletions

View File

@@ -28,8 +28,6 @@ public class ImportOrganizationUsersAndGroupsCommandTests : IClassFixture<ApiApp
_factory.SubstituteService((IFeatureService featureService)
=>
{
featureService.IsEnabled(FeatureFlagKeys.ImportAsyncRefactor)
.Returns(true);
featureService.IsEnabled(FeatureFlagKeys.DirectoryConnectorPreventUserRemoval)
.Returns(true);
});

View File

@@ -2,7 +2,6 @@
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses;
using Bit.Commercial.Core.Billing.Providers.Services;
using Bit.Core;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
@@ -346,9 +345,6 @@ public class ProviderBillingControllerTests
}
};
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PM21383_GetProviderPriceFromStripe)
.Returns(true);
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
foreach (var providerPlan in providerPlans)

View File

@@ -22,7 +22,7 @@ namespace Bit.Api.Test.Controllers;
public class CollectionsControllerTests
{
[Theory, BitAutoData]
public async Task Post_Success(Organization organization, CollectionRequestModel collectionRequest,
public async Task Post_Success(Organization organization, CreateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
Collection ExpectedCollection() => Arg.Is<Collection>(c =>
@@ -46,9 +46,10 @@ public class CollectionsControllerTests
}
[Theory, BitAutoData]
public async Task Put_Success(Collection collection, CollectionRequestModel collectionRequest,
public async Task Put_Success(Collection collection, UpdateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
collection.DefaultUserCollectionEmail = null;
Collection ExpectedCollection() => Arg.Is<Collection>(c => c.Id == collection.Id &&
c.Name == collectionRequest.Name && c.ExternalId == collectionRequest.ExternalId &&
c.OrganizationId == collection.OrganizationId);
@@ -72,7 +73,7 @@ public class CollectionsControllerTests
}
[Theory, BitAutoData]
public async Task Put_WithNoCollectionPermission_ThrowsNotFound(Collection collection, CollectionRequestModel collectionRequest,
public async Task Put_WithNoCollectionPermission_ThrowsNotFound(Collection collection, UpdateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
sutProvider.GetDependency<IAuthorizationService>()
@@ -484,4 +485,176 @@ public class CollectionsControllerTests
await sutProvider.GetDependency<IBulkAddCollectionAccessCommand>().DidNotReceiveWithAnyArgs()
.AddAccessAsync(default, default, default);
}
[Theory, BitAutoData]
public async Task Put_With_NonNullName_DoesNotPreserveExistingName(Collection existingCollection, UpdateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
// Arrange
var newName = "new name";
var originalName = "original name";
existingCollection.Name = originalName;
existingCollection.DefaultUserCollectionEmail = null;
collectionRequest.Name = newName;
sutProvider.GetDependency<ICollectionRepository>()
.GetByIdAsync(existingCollection.Id)
.Returns(existingCollection);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(),
existingCollection,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(r => r.Contains(BulkCollectionOperations.Update)))
.Returns(AuthorizationResult.Success());
// Act
await sutProvider.Sut.Put(existingCollection.OrganizationId, existingCollection.Id, collectionRequest);
// Assert
await sutProvider.GetDependency<IUpdateCollectionCommand>()
.Received(1)
.UpdateAsync(
Arg.Is<Collection>(c => c.Id == existingCollection.Id && c.Name == newName),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
}
[Theory, BitAutoData]
public async Task Put_WithNullName_DoesPreserveExistingName(Collection existingCollection, UpdateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
// Arrange
var originalName = "original name";
existingCollection.Name = originalName;
existingCollection.DefaultUserCollectionEmail = null;
collectionRequest.Name = null;
sutProvider.GetDependency<ICollectionRepository>()
.GetByIdAsync(existingCollection.Id)
.Returns(existingCollection);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(),
existingCollection,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(r => r.Contains(BulkCollectionOperations.Update)))
.Returns(AuthorizationResult.Success());
// Act
await sutProvider.Sut.Put(existingCollection.OrganizationId, existingCollection.Id, collectionRequest);
// Assert
await sutProvider.GetDependency<IUpdateCollectionCommand>()
.Received(1)
.UpdateAsync(
Arg.Is<Collection>(c => c.Id == existingCollection.Id && c.Name == originalName),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
}
[Theory, BitAutoData]
public async Task Put_WithDefaultUserCollectionEmail_DoesPreserveExistingName(Collection existingCollection, UpdateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
// Arrange
var originalName = "original name";
var defaultUserCollectionEmail = "user@email.com";
existingCollection.Name = originalName;
existingCollection.DefaultUserCollectionEmail = defaultUserCollectionEmail;
collectionRequest.Name = "new name";
sutProvider.GetDependency<ICollectionRepository>()
.GetByIdAsync(existingCollection.Id)
.Returns(existingCollection);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(),
existingCollection,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(r => r.Contains(BulkCollectionOperations.Update)))
.Returns(AuthorizationResult.Success());
// Act
await sutProvider.Sut.Put(existingCollection.OrganizationId, existingCollection.Id, collectionRequest);
// Assert
await sutProvider.GetDependency<IUpdateCollectionCommand>()
.Received(1)
.UpdateAsync(
Arg.Is<Collection>(c => c.Id == existingCollection.Id && c.Name == originalName && c.DefaultUserCollectionEmail == defaultUserCollectionEmail),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
}
[Theory, BitAutoData]
public async Task Put_WithEmptyName_DoesPreserveExistingName(Collection existingCollection, UpdateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
// Arrange
var originalName = "original name";
existingCollection.Name = originalName;
existingCollection.DefaultUserCollectionEmail = null;
collectionRequest.Name = ""; // Empty string
sutProvider.GetDependency<ICollectionRepository>()
.GetByIdAsync(existingCollection.Id)
.Returns(existingCollection);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(),
existingCollection,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(r => r.Contains(BulkCollectionOperations.Update)))
.Returns(AuthorizationResult.Success());
// Act
await sutProvider.Sut.Put(existingCollection.OrganizationId, existingCollection.Id, collectionRequest);
// Assert
await sutProvider.GetDependency<IUpdateCollectionCommand>()
.Received(1)
.UpdateAsync(
Arg.Is<Collection>(c => c.Id == existingCollection.Id && c.Name == originalName),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
}
[Theory, BitAutoData]
public async Task Put_WithWhitespaceOnlyName_DoesPreserveExistingName(Collection existingCollection, UpdateCollectionRequestModel collectionRequest,
SutProvider<CollectionsController> sutProvider)
{
// Arrange
var originalName = "original name";
existingCollection.Name = originalName;
existingCollection.DefaultUserCollectionEmail = null;
collectionRequest.Name = " "; // Whitespace only
sutProvider.GetDependency<ICollectionRepository>()
.GetByIdAsync(existingCollection.Id)
.Returns(existingCollection);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(),
existingCollection,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(r => r.Contains(BulkCollectionOperations.Update)))
.Returns(AuthorizationResult.Success());
// Act
await sutProvider.Sut.Put(existingCollection.OrganizationId, existingCollection.Id, collectionRequest);
// Assert
await sutProvider.GetDependency<IUpdateCollectionCommand>()
.Received(1)
.UpdateAsync(
Arg.Is<Collection>(c => c.Id == existingCollection.Id && c.Name == originalName),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
}
}

View File

@@ -317,7 +317,7 @@ public class ProjectsControllerTests
[Theory]
[BitAutoData]
public async Task BulkDeleteProjects_ReturnsAccessDeniedForProjectsWithoutAccess_Success(
SutProvider<ProjectsController> sutProvider, List<Project> data)
SutProvider<ProjectsController> sutProvider, Guid userId, List<Project> data)
{
var ids = data.Select(project => project.Id).ToList();
@@ -333,6 +333,7 @@ public class ProjectsControllerTests
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(), data.First(),
Arg.Any<IEnumerable<IAuthorizationRequirement>>()).Returns(AuthorizationResult.Failed());
sutProvider.GetDependency<IUserService>().GetProperUserId(default).ReturnsForAnyArgs(userId);
sutProvider.GetDependency<ICurrentContext>().AccessSecretsManager(Arg.Is(organizationId)).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IProjectRepository>().GetManyWithSecretsByIds(Arg.Is(ids)).ReturnsForAnyArgs(data);
var results = await sutProvider.Sut.BulkDeleteAsync(ids);
@@ -346,7 +347,7 @@ public class ProjectsControllerTests
[Theory]
[BitAutoData]
public async Task BulkDeleteProjects_Success(SutProvider<ProjectsController> sutProvider, List<Project> data)
public async Task BulkDeleteProjects_Success(SutProvider<ProjectsController> sutProvider, Guid userId, List<Project> data)
{
var ids = data.Select(project => project.Id).ToList();
var organizationId = data.First().OrganizationId;
@@ -357,7 +358,7 @@ public class ProjectsControllerTests
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(), project,
Arg.Any<IEnumerable<IAuthorizationRequirement>>()).ReturnsForAnyArgs(AuthorizationResult.Success());
}
sutProvider.GetDependency<IUserService>().GetProperUserId(default).ReturnsForAnyArgs(userId);
sutProvider.GetDependency<IProjectRepository>().GetManyWithSecretsByIds(Arg.Is(ids)).ReturnsForAnyArgs(data);
sutProvider.GetDependency<ICurrentContext>().AccessSecretsManager(Arg.Is(organizationId)).ReturnsForAnyArgs(true);

View File

@@ -27,7 +27,9 @@ using Bit.Test.Common.AutoFixture.Attributes;
using Bit.Test.Common.Fakes;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using NSubstitute.ReceivedExtensions;
using NSubstitute.ReturnsExtensions;
using Stripe;
using Xunit;
using Organization = Bit.Core.AdminConsole.Entities.Organization;
using OrganizationUser = Bit.Core.Entities.OrganizationUser;
@@ -40,139 +42,7 @@ public class OrganizationServiceTests
{
private readonly IDataProtectorTokenFactory<OrgUserInviteTokenable> _orgUserInviteTokenDataFactory = new FakeDataProtectorTokenFactory<OrgUserInviteTokenable>();
[Theory, PaidOrganizationCustomize, BitAutoData]
public async Task OrgImportCreateNewUsers(SutProvider<OrganizationService> sutProvider, Organization org, List<OrganizationUserUserDetails> existingUsers, List<ImportedOrganizationUser> newUsers)
{
// Setup FakeDataProtectorTokenFactory for creating new tokens - this must come first in order to avoid resetting mocks
sutProvider.SetDependency(_orgUserInviteTokenDataFactory, "orgUserInviteTokenDataFactory");
sutProvider.Create();
org.UseDirectory = true;
org.Seats = 10;
newUsers.Add(new ImportedOrganizationUser
{
Email = existingUsers.First().Email,
ExternalId = existingUsers.First().ExternalId
});
var expectedNewUsersCount = newUsers.Count - 1;
existingUsers.First().Type = OrganizationUserType.Owner;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(org.Id).Returns(org);
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(org.Id).Returns(new OrganizationSeatCounts
{
Sponsored = 0,
Users = 1
});
var organizationUserRepository = sutProvider.GetDependency<IOrganizationUserRepository>();
SetupOrgUserRepositoryCreateManyAsyncMock(organizationUserRepository);
organizationUserRepository.GetManyDetailsByOrganizationAsync(org.Id)
.Returns(existingUsers);
organizationUserRepository.GetCountByOrganizationIdAsync(org.Id)
.Returns(existingUsers.Count);
sutProvider.GetDependency<IHasConfirmedOwnersExceptQuery>()
.HasConfirmedOwnersExceptAsync(org.Id, Arg.Any<IEnumerable<Guid>>())
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ManageUsers(org.Id).Returns(true);
await sutProvider.Sut.ImportAsync(org.Id, null, newUsers, null, false, EventSystemUser.PublicApi);
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs()
.UpsertAsync(default);
await sutProvider.GetDependency<IOrganizationUserRepository>().Received(1)
.UpsertManyAsync(Arg.Is<IEnumerable<OrganizationUser>>(users => !users.Any()));
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs()
.CreateAsync(default);
// Create new users
await sutProvider.GetDependency<IOrganizationUserRepository>().Received(1)
.CreateManyAsync(Arg.Is<IEnumerable<OrganizationUser>>(users => users.Count() == expectedNewUsersCount));
await sutProvider.GetDependency<ISendOrganizationInvitesCommand>().Received(1)
.SendInvitesAsync(
Arg.Is<SendInvitesRequest>(
info => info.Users.Length == expectedNewUsersCount &&
info.Organization == org));
// Send events
await sutProvider.GetDependency<IEventService>().Received(1)
.LogOrganizationUserEventsAsync(Arg.Is<IEnumerable<(OrganizationUser, EventType, EventSystemUser, DateTime?)>>(events =>
events.Count() == expectedNewUsersCount));
}
[Theory, PaidOrganizationCustomize, BitAutoData]
public async Task OrgImportCreateNewUsersAndMarryExistingUser(SutProvider<OrganizationService> sutProvider, Organization org, List<OrganizationUserUserDetails> existingUsers,
List<ImportedOrganizationUser> newUsers)
{
// Setup FakeDataProtectorTokenFactory for creating new tokens - this must come first in order to avoid resetting mocks
sutProvider.SetDependency(_orgUserInviteTokenDataFactory, "orgUserInviteTokenDataFactory");
sutProvider.Create();
org.UseDirectory = true;
org.Seats = newUsers.Count + existingUsers.Count + 1;
var reInvitedUser = existingUsers.First();
reInvitedUser.ExternalId = null;
newUsers.Add(new ImportedOrganizationUser
{
Email = reInvitedUser.Email,
ExternalId = reInvitedUser.Email,
});
var expectedNewUsersCount = newUsers.Count - 1;
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(org.Id).Returns(new OrganizationSeatCounts
{
Sponsored = 0,
Users = 1
});
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(org.Id).Returns(org);
sutProvider.GetDependency<IOrganizationUserRepository>().GetManyDetailsByOrganizationAsync(org.Id)
.Returns(existingUsers);
sutProvider.GetDependency<IOrganizationUserRepository>().GetCountByOrganizationIdAsync(org.Id)
.Returns(existingUsers.Count);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(reInvitedUser.Id)
.Returns(new OrganizationUser { Id = reInvitedUser.Id });
var organizationUserRepository = sutProvider.GetDependency<IOrganizationUserRepository>();
sutProvider.GetDependency<IHasConfirmedOwnersExceptQuery>()
.HasConfirmedOwnersExceptAsync(org.Id, Arg.Any<IEnumerable<Guid>>())
.Returns(true);
SetupOrgUserRepositoryCreateManyAsyncMock(organizationUserRepository);
var currentContext = sutProvider.GetDependency<ICurrentContext>();
currentContext.ManageUsers(org.Id).Returns(true);
await sutProvider.Sut.ImportAsync(org.Id, null, newUsers, null, false, EventSystemUser.PublicApi);
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs()
.UpsertAsync(default);
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs()
.CreateAsync(default);
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs()
.CreateAsync(default, default);
// Upserted existing user
await sutProvider.GetDependency<IOrganizationUserRepository>().Received(1)
.UpsertManyAsync(Arg.Is<IEnumerable<OrganizationUser>>(users => users.Count() == 1));
// Created and invited new users
await sutProvider.GetDependency<IOrganizationUserRepository>().Received(1)
.CreateManyAsync(Arg.Is<IEnumerable<OrganizationUser>>(users => users.Count() == expectedNewUsersCount));
await sutProvider.GetDependency<ISendOrganizationInvitesCommand>().Received(1)
.SendInvitesAsync(Arg.Is<SendInvitesRequest>(request =>
request.Users.Length == expectedNewUsersCount &&
request.Organization == org));
// Sent events
await sutProvider.GetDependency<IEventService>().Received(1)
.LogOrganizationUserEventsAsync(Arg.Is<IEnumerable<(OrganizationUser, EventType, EventSystemUser, DateTime?)>>(events =>
events.Count(e => e.Item2 == EventType.OrganizationUser_Invited) == expectedNewUsersCount));
}
[Theory]
[OrganizationInviteCustomize(InviteeUserType = OrganizationUserType.User,
@@ -1235,6 +1105,130 @@ public class OrganizationServiceTests
await sutProvider.Sut.ValidateOrganizationCustomPermissionsEnabledAsync(organization.Id, OrganizationUserType.Custom);
}
[Theory, BitAutoData]
public async Task UpdateAsync_WhenValidOrganization_AndUpdateBillingIsTrue_UpdateStripeCustomerAndOrganization(Organization organization, SutProvider<OrganizationService> sutProvider)
{
// Arrange
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
var applicationCacheService = sutProvider.GetDependency<IApplicationCacheService>();
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var eventService = sutProvider.GetDependency<IEventService>();
var requestOptionsReturned = new CustomerUpdateOptions
{
Email = organization.BillingEmail,
Description = organization.DisplayBusinessName(),
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
// This overwrites the existing custom fields for this organization
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = organization.SubscriberType(),
Value = organization.DisplayName()[..30]
}
]
},
};
organizationRepository
.GetByIdentifierAsync(organization.Identifier!)
.Returns(organization);
// Act
await sutProvider.Sut.UpdateAsync(organization, updateBilling: true);
// Assert
await organizationRepository
.Received(1)
.GetByIdentifierAsync(Arg.Is<string>(id => id == organization.Identifier));
await stripeAdapter
.Received(1)
.CustomerUpdateAsync(
Arg.Is<string>(id => id == organization.GatewayCustomerId),
Arg.Is<CustomerUpdateOptions>(options => options.Email == requestOptionsReturned.Email
&& options.Description == requestOptionsReturned.Description
&& options.InvoiceSettings.CustomFields.First().Name == requestOptionsReturned.InvoiceSettings.CustomFields.First().Name
&& options.InvoiceSettings.CustomFields.First().Value == requestOptionsReturned.InvoiceSettings.CustomFields.First().Value)); ;
await organizationRepository
.Received(1)
.ReplaceAsync(Arg.Is<Organization>(org => org == organization));
await applicationCacheService
.Received(1)
.UpsertOrganizationAbilityAsync(Arg.Is<Organization>(org => org == organization));
await eventService
.Received(1)
.LogOrganizationEventAsync(Arg.Is<Organization>(org => org == organization),
Arg.Is<EventType>(e => e == EventType.Organization_Updated));
}
[Theory, BitAutoData]
public async Task UpdateAsync_WhenValidOrganization_AndUpdateBillingIsFalse_UpdateOrganization(Organization organization, SutProvider<OrganizationService> sutProvider)
{
// Arrange
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
var applicationCacheService = sutProvider.GetDependency<IApplicationCacheService>();
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var eventService = sutProvider.GetDependency<IEventService>();
organizationRepository
.GetByIdentifierAsync(organization.Identifier!)
.Returns(organization);
// Act
await sutProvider.Sut.UpdateAsync(organization, updateBilling: false);
// Assert
await organizationRepository
.Received(1)
.GetByIdentifierAsync(Arg.Is<string>(id => id == organization.Identifier));
await stripeAdapter
.DidNotReceiveWithAnyArgs()
.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>());
await organizationRepository
.Received(1)
.ReplaceAsync(Arg.Is<Organization>(org => org == organization));
await applicationCacheService
.Received(1)
.UpsertOrganizationAbilityAsync(Arg.Is<Organization>(org => org == organization));
await eventService
.Received(1)
.LogOrganizationEventAsync(Arg.Is<Organization>(org => org == organization),
Arg.Is<EventType>(e => e == EventType.Organization_Updated));
}
[Theory, BitAutoData]
public async Task UpdateAsync_WhenOrganizationHasNoId_ThrowsApplicationException(Organization organization, SutProvider<OrganizationService> sutProvider)
{
// Arrange
organization.Id = Guid.Empty;
// Act/Assert
var exception = await Assert.ThrowsAnyAsync<ApplicationException>(() => sutProvider.Sut.UpdateAsync(organization));
Assert.Equal("Cannot create org this way. Call SignUpAsync.", exception.Message);
}
[Theory, BitAutoData]
public async Task UpdateAsync_WhenIdentifierAlreadyExistsForADifferentOrganization_ThrowsBadRequestException(Organization organization, SutProvider<OrganizationService> sutProvider)
{
// Arrange
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
var differentOrganization = new Organization { Id = Guid.NewGuid() };
organizationRepository
.GetByIdentifierAsync(organization.Identifier!)
.Returns(differentOrganization);
// Act/Assert
var exception = await Assert.ThrowsAnyAsync<BadRequestException>(() => sutProvider.Sut.UpdateAsync(organization));
Assert.Equal("Identifier already in use by another organization.", exception.Message);
await organizationRepository
.Received(1)
.GetByIdentifierAsync(Arg.Is<string>(id => id == organization.Identifier));
}
// Must set real guids in order for dictionary of guids to not throw aggregate exceptions
private void SetupOrgUserRepositoryCreateManyAsyncMock(IOrganizationUserRepository organizationUserRepository)
{

View File

@@ -21,7 +21,7 @@ namespace Bit.Core.Test.Billing.Organizations.Queries;
[SutProviderCustomize]
public class GetOrganizationWarningsQueryTests
{
private static readonly string[] _requiredExpansions = ["customer", "latest_invoice", "test_clock"];
private static readonly string[] _requiredExpansions = ["customer.tax_ids", "latest_invoice", "test_clock"];
[Theory, BitAutoData]
public async Task Run_NoSubscription_NoWarnings(
@@ -130,7 +130,7 @@ public class GetOrganizationWarningsQueryTests
}
[Theory, BitAutoData]
public async Task Run_Has_InactiveSubscriptionWarning_AddPaymentMethodOptionalTrial(
public async Task Run_OrganizationEnabled_NoInactiveSubscriptionWarning(
Organization organization,
SutProvider<GetOrganizationWarningsQuery> sutProvider)
{
@@ -142,7 +142,7 @@ public class GetOrganizationWarningsQueryTests
))
.Returns(new Subscription
{
Status = StripeConstants.SubscriptionStatus.Trialing,
Status = StripeConstants.SubscriptionStatus.Unpaid,
Customer = new Customer
{
InvoiceSettings = new CustomerInvoiceSettings(),
@@ -151,14 +151,10 @@ public class GetOrganizationWarningsQueryTests
});
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(organization.Id).Returns(true);
sutProvider.GetDependency<ISetupIntentCache>().Get(organization.Id).Returns((string?)null);
var response = await sutProvider.Sut.Run(organization);
Assert.True(response is
{
InactiveSubscription.Resolution: "add_payment_method_optional_trial"
});
Assert.Null(response.InactiveSubscription);
}
[Theory, BitAutoData]

View File

@@ -1695,9 +1695,6 @@ public class SubscriberServiceTests
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
.Returns(subscription);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(
@@ -1765,4 +1762,142 @@ public class SubscriberServiceTests
}
#endregion
#region IsValidGatewayCustomerIdAsync
[Theory, BitAutoData]
public async Task IsValidGatewayCustomerIdAsync_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberService> sutProvider)
{
await Assert.ThrowsAsync<ArgumentNullException>(() =>
sutProvider.Sut.IsValidGatewayCustomerIdAsync(null));
}
[Theory, BitAutoData]
public async Task IsValidGatewayCustomerIdAsync_NullGatewayCustomerId_ReturnsTrue(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
organization.GatewayCustomerId = null;
var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
Assert.True(result);
await sutProvider.GetDependency<IStripeAdapter>().DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task IsValidGatewayCustomerIdAsync_EmptyGatewayCustomerId_ReturnsTrue(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
organization.GatewayCustomerId = "";
var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
Assert.True(result);
await sutProvider.GetDependency<IStripeAdapter>().DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task IsValidGatewayCustomerIdAsync_ValidCustomerId_ReturnsTrue(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Returns(new Customer());
var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
Assert.True(result);
await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId);
}
[Theory, BitAutoData]
public async Task IsValidGatewayCustomerIdAsync_InvalidCustomerId_ReturnsFalse(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } };
stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Throws(stripeException);
var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
Assert.False(result);
await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId);
}
#endregion
#region IsValidGatewaySubscriptionIdAsync
[Theory, BitAutoData]
public async Task IsValidGatewaySubscriptionIdAsync_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberService> sutProvider)
{
await Assert.ThrowsAsync<ArgumentNullException>(() =>
sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(null));
}
[Theory, BitAutoData]
public async Task IsValidGatewaySubscriptionIdAsync_NullGatewaySubscriptionId_ReturnsTrue(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
organization.GatewaySubscriptionId = null;
var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
Assert.True(result);
await sutProvider.GetDependency<IStripeAdapter>().DidNotReceiveWithAnyArgs()
.SubscriptionGetAsync(Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task IsValidGatewaySubscriptionIdAsync_EmptyGatewaySubscriptionId_ReturnsTrue(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
organization.GatewaySubscriptionId = "";
var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
Assert.True(result);
await sutProvider.GetDependency<IStripeAdapter>().DidNotReceiveWithAnyArgs()
.SubscriptionGetAsync(Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task IsValidGatewaySubscriptionIdAsync_ValidSubscriptionId_ReturnsTrue(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Returns(new Subscription());
var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
Assert.True(result);
await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId);
}
[Theory, BitAutoData]
public async Task IsValidGatewaySubscriptionIdAsync_InvalidSubscriptionId_ReturnsFalse(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } };
stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Throws(stripeException);
var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
Assert.False(result);
await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId);
}
#endregion
}

View File

@@ -2,10 +2,13 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Business;
using Bit.Core.Entities;
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Settings;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
@@ -19,17 +22,93 @@ public class HandlebarsMailServiceTests
private readonly GlobalSettings _globalSettings;
private readonly IMailDeliveryService _mailDeliveryService;
private readonly IMailEnqueuingService _mailEnqueuingService;
private readonly IDistributedCache _distributedCache;
public HandlebarsMailServiceTests()
{
_globalSettings = new GlobalSettings();
_mailDeliveryService = Substitute.For<IMailDeliveryService>();
_mailEnqueuingService = Substitute.For<IMailEnqueuingService>();
_distributedCache = Substitute.For<IDistributedCache>();
_sut = new HandlebarsMailService(
_globalSettings,
_mailDeliveryService,
_mailEnqueuingService
_mailEnqueuingService,
_distributedCache
);
}
[Fact]
public async Task SendFailedTwoFactorAttemptEmailAsync_FirstCall_SendsEmail()
{
// Arrange
var email = "test@example.com";
var failedType = TwoFactorProviderType.Email;
var utcNow = DateTime.UtcNow;
var ip = "192.168.1.1";
_distributedCache.GetAsync(Arg.Any<string>()).Returns((byte[])null);
// Act
await _sut.SendFailedTwoFactorAttemptEmailAsync(email, failedType, utcNow, ip);
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Any<MailMessage>());
await _distributedCache.Received(1).SetAsync(
Arg.Is<string>(key => key == $"FailedTwoFactorAttemptEmail_{email}"),
Arg.Any<byte[]>(),
Arg.Any<DistributedCacheEntryOptions>()
);
}
[Fact]
public async Task SendFailedTwoFactorAttemptEmailAsync_SecondCallWithinHour_DoesNotSendEmail()
{
// Arrange
var email = "test@example.com";
var failedType = TwoFactorProviderType.Email;
var utcNow = DateTime.UtcNow;
var ip = "192.168.1.1";
// Simulate cache hit (email was already sent)
_distributedCache.GetAsync(Arg.Any<string>()).Returns([1]);
// Act
await _sut.SendFailedTwoFactorAttemptEmailAsync(email, failedType, utcNow, ip);
// Assert
await _mailDeliveryService.DidNotReceive().SendEmailAsync(Arg.Any<MailMessage>());
await _distributedCache.DidNotReceive().SetAsync(Arg.Any<string>(), Arg.Any<byte[]>(), Arg.Any<DistributedCacheEntryOptions>());
}
[Fact]
public async Task SendFailedTwoFactorAttemptEmailAsync_DifferentEmails_SendsBothEmails()
{
// Arrange
var email1 = "test1@example.com";
var email2 = "test2@example.com";
var failedType = TwoFactorProviderType.Email;
var utcNow = DateTime.UtcNow;
var ip = "192.168.1.1";
_distributedCache.GetAsync(Arg.Any<string>()).Returns((byte[])null);
// Act
await _sut.SendFailedTwoFactorAttemptEmailAsync(email1, failedType, utcNow, ip);
await _sut.SendFailedTwoFactorAttemptEmailAsync(email2, failedType, utcNow, ip);
// Assert
await _mailDeliveryService.Received(2).SendEmailAsync(Arg.Any<MailMessage>());
await _distributedCache.Received(1).SetAsync(
Arg.Is<string>(key => key == $"FailedTwoFactorAttemptEmail_{email1}"),
Arg.Any<byte[]>(),
Arg.Any<DistributedCacheEntryOptions>()
);
await _distributedCache.Received(1).SetAsync(
Arg.Is<string>(key => key == $"FailedTwoFactorAttemptEmail_{email2}"),
Arg.Any<byte[]>(),
Arg.Any<DistributedCacheEntryOptions>()
);
}
@@ -137,8 +216,9 @@ public class HandlebarsMailServiceTests
};
var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, Substitute.For<ILogger<MailKitSmtpMailDeliveryService>>());
var distributedCache = Substitute.For<IDistributedCache>();
var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService());
var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService(), distributedCache);
var sendMethods = typeof(IMailService).GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Where(m => m.Name.StartsWith("Send") && m.Name != "SendEnqueuedMailMessageAsync");

View File

@@ -9,7 +9,7 @@ using Bit.Core.Test.AutoFixture.UserFixtures;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using IdentityModel;
using Duende.IdentityModel;
using Microsoft.AspNetCore.DataProtection;
using Xunit;

View File

@@ -17,9 +17,9 @@ using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Bit.IntegrationTestCommon.Factories;
using Bit.Test.Common.Helpers;
using Duende.IdentityModel;
using Duende.IdentityServer.Models;
using Duende.IdentityServer.Stores;
using IdentityModel;
using Microsoft.EntityFrameworkCore;
using NSubstitute;
using Xunit;

View File

@@ -17,9 +17,9 @@ using Bit.Core.Utilities;
using Bit.IntegrationTestCommon.Factories;
using Bit.Test.Common.AutoFixture.Attributes;
using Bit.Test.Common.Helpers;
using Duende.IdentityModel;
using Duende.IdentityServer.Models;
using Duende.IdentityServer.Stores;
using IdentityModel;
using LinqToDB;
using Microsoft.Extensions.Caching.Distributed;
using NSubstitute;

View File

@@ -8,6 +8,7 @@ using Bit.Core.Utilities;
using Bit.Identity.IdentityServer.Enums;
using Bit.Identity.IdentityServer.RequestValidators.SendAccess;
using Bit.IntegrationTestCommon.Factories;
using Duende.IdentityModel;
using Duende.IdentityServer.Validation;
using NSubstitute;
using Xunit;
@@ -96,8 +97,8 @@ public class SendAccessGrantValidatorIntegrationTests(IdentityApplicationFactory
}).CreateClient();
var requestBody = new FormUrlEncodedContent([
new KeyValuePair<string, string>("grant_type", CustomGrantTypes.SendAccess),
new KeyValuePair<string, string>("client_id", BitwardenClient.Send)
new KeyValuePair<string, string>(OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess),
new KeyValuePair<string, string>(OidcConstants.TokenRequest.ClientId, BitwardenClient.Send)
]);
// Act
@@ -105,8 +106,8 @@ public class SendAccessGrantValidatorIntegrationTests(IdentityApplicationFactory
// Assert
var content = await response.Content.ReadAsStringAsync();
Assert.Contains("invalid_request", content);
Assert.Contains("send_id is required", content);
Assert.Contains(OidcConstants.TokenErrors.InvalidRequest, content);
Assert.Contains($"{SendAccessConstants.TokenRequest.SendId} is required", content);
}
[Fact]
@@ -245,16 +246,16 @@ public class SendAccessGrantValidatorIntegrationTests(IdentityApplicationFactory
var sendIdBase64 = CoreHelpers.Base64UrlEncode(sendId.ToByteArray());
var parameters = new List<KeyValuePair<string, string>>
{
new("grant_type", CustomGrantTypes.SendAccess),
new("client_id", BitwardenClient.Send ),
new("scope", ApiScopes.ApiSendAccess),
new(OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess),
new(OidcConstants.TokenRequest.ClientId, BitwardenClient.Send ),
new(OidcConstants.TokenRequest.Scope, ApiScopes.ApiSendAccess),
new("deviceType", ((int)DeviceType.FirefoxBrowser).ToString()),
new("send_id", sendIdBase64)
new(SendAccessConstants.TokenRequest.SendId, sendIdBase64)
};
if (!string.IsNullOrEmpty(password))
{
parameters.Add(new("password_hash", password));
parameters.Add(new(SendAccessConstants.TokenRequest.ClientB64HashedPassword, password));
}
if (!string.IsNullOrEmpty(emailOtp) && !string.IsNullOrEmpty(sendEmail))

View File

@@ -0,0 +1,209 @@
using Bit.Core.Enums;
using Bit.Core.IdentityServer;
using Bit.Core.KeyManagement.Sends;
using Bit.Core.Services;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.SendFeatures.Queries.Interfaces;
using Bit.Core.Utilities;
using Bit.Identity.IdentityServer.Enums;
using Bit.Identity.IdentityServer.RequestValidators.SendAccess;
using Bit.IntegrationTestCommon.Factories;
using Duende.IdentityModel;
using NSubstitute;
using Xunit;
namespace Bit.Identity.IntegrationTest.RequestValidation;
public class SendPasswordRequestValidatorIntegrationTests : IClassFixture<IdentityApplicationFactory>
{
private readonly IdentityApplicationFactory _factory;
public SendPasswordRequestValidatorIntegrationTests(IdentityApplicationFactory factory)
{
_factory = factory;
}
[Fact]
public async Task SendAccess_PasswordProtectedSend_ValidPassword_ReturnsAccessToken()
{
// Arrange
var sendId = Guid.NewGuid();
var passwordHash = "stored-password-hash";
var clientPasswordHash = "client-password-hash";
var client = _factory.WithWebHostBuilder(builder =>
{
builder.ConfigureServices(services =>
{
// Enable feature flag
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(Arg.Any<string>()).Returns(true);
services.AddSingleton(featureService);
// Mock send authentication query
var sendAuthQuery = Substitute.For<ISendAuthenticationQuery>();
sendAuthQuery.GetAuthenticationMethod(sendId)
.Returns(new ResourcePassword(passwordHash));
services.AddSingleton(sendAuthQuery);
// Mock password hasher to return true for matching passwords
var passwordHasher = Substitute.For<ISendPasswordHasher>();
passwordHasher.PasswordHashMatches(passwordHash, clientPasswordHash)
.Returns(true);
services.AddSingleton(passwordHasher);
});
}).CreateClient();
var requestBody = CreateTokenRequestBody(sendId, clientPasswordHash);
// Act
var response = await client.PostAsync("/connect/token", requestBody);
// Assert
Assert.True(response.IsSuccessStatusCode);
var content = await response.Content.ReadAsStringAsync();
Assert.Contains(OidcConstants.TokenResponse.AccessToken, content);
Assert.Contains("bearer", content.ToLower());
}
[Fact]
public async Task SendAccess_PasswordProtectedSend_InvalidPassword_ReturnsInvalidGrant()
{
// Arrange
var sendId = Guid.NewGuid();
var passwordHash = "stored-password-hash";
var wrongClientPasswordHash = "wrong-client-password-hash";
var client = _factory.WithWebHostBuilder(builder =>
{
builder.ConfigureServices(services =>
{
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(Arg.Any<string>()).Returns(true);
services.AddSingleton(featureService);
var sendAuthQuery = Substitute.For<ISendAuthenticationQuery>();
sendAuthQuery.GetAuthenticationMethod(sendId)
.Returns(new ResourcePassword(passwordHash));
services.AddSingleton(sendAuthQuery);
// Mock password hasher to return false for wrong passwords
var passwordHasher = Substitute.For<ISendPasswordHasher>();
passwordHasher.PasswordHashMatches(passwordHash, wrongClientPasswordHash)
.Returns(false);
services.AddSingleton(passwordHasher);
});
}).CreateClient();
var requestBody = CreateTokenRequestBody(sendId, wrongClientPasswordHash);
// Act
var response = await client.PostAsync("/connect/token", requestBody);
// Assert
var content = await response.Content.ReadAsStringAsync();
Assert.Contains(OidcConstants.TokenErrors.InvalidGrant, content);
Assert.Contains($"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is invalid", content);
}
[Fact]
public async Task SendAccess_PasswordProtectedSend_MissingPassword_ReturnsInvalidRequest()
{
// Arrange
var sendId = Guid.NewGuid();
var passwordHash = "stored-password-hash";
var client = _factory.WithWebHostBuilder(builder =>
{
builder.ConfigureServices(services =>
{
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(Arg.Any<string>()).Returns(true);
services.AddSingleton(featureService);
var sendAuthQuery = Substitute.For<ISendAuthenticationQuery>();
sendAuthQuery.GetAuthenticationMethod(sendId)
.Returns(new ResourcePassword(passwordHash));
services.AddSingleton(sendAuthQuery);
var passwordHasher = Substitute.For<ISendPasswordHasher>();
services.AddSingleton(passwordHasher);
});
}).CreateClient();
var requestBody = CreateTokenRequestBody(sendId); // No password
// Act
var response = await client.PostAsync("/connect/token", requestBody);
// Assert
var content = await response.Content.ReadAsStringAsync();
Assert.Contains(OidcConstants.TokenErrors.InvalidRequest, content);
Assert.Contains($"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is required", content);
}
/// <summary>
/// When the password has is empty or whitespace it doesn't get passed to the server when the request is made.
/// This leads to an invalid request error since the absence of the password hash is considered a malformed request.
/// In the case that the passwordB64Hash _is_ empty or whitespace it would be an invalid grant since the request
/// has the correct shape.
/// </summary>
[Fact]
public async Task SendAccess_PasswordProtectedSend_EmptyPassword_ReturnsInvalidRequest()
{
// Arrange
var sendId = Guid.NewGuid();
var passwordHash = "stored-password-hash";
var client = _factory.WithWebHostBuilder(builder =>
{
builder.ConfigureServices(services =>
{
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(Arg.Any<string>()).Returns(true);
services.AddSingleton(featureService);
var sendAuthQuery = Substitute.For<ISendAuthenticationQuery>();
sendAuthQuery.GetAuthenticationMethod(sendId)
.Returns(new ResourcePassword(passwordHash));
services.AddSingleton(sendAuthQuery);
// Mock password hasher to return false for empty passwords
var passwordHasher = Substitute.For<ISendPasswordHasher>();
passwordHasher.PasswordHashMatches(passwordHash, string.Empty)
.Returns(false);
services.AddSingleton(passwordHasher);
});
}).CreateClient();
var requestBody = CreateTokenRequestBody(sendId, string.Empty);
// Act
var response = await client.PostAsync("/connect/token", requestBody);
// Assert
var content = await response.Content.ReadAsStringAsync();
Assert.Contains(OidcConstants.TokenErrors.InvalidRequest, content);
Assert.Contains($"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is required", content);
}
private static FormUrlEncodedContent CreateTokenRequestBody(Guid sendId, string passwordHash = null)
{
var sendIdBase64 = CoreHelpers.Base64UrlEncode(sendId.ToByteArray());
var parameters = new List<KeyValuePair<string, string>>
{
new(OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess),
new(OidcConstants.TokenRequest.ClientId, BitwardenClient.Send),
new(SendAccessConstants.TokenRequest.SendId, sendIdBase64),
new(OidcConstants.TokenRequest.Scope, ApiScopes.ApiSendAccess),
new("deviceType", "10")
};
if (passwordHash != null)
{
parameters.Add(new KeyValuePair<string, string>(SendAccessConstants.TokenRequest.ClientB64HashedPassword, passwordHash));
}
return new FormUrlEncodedContent(parameters);
}
}

View File

@@ -1,7 +1,7 @@
using Bit.Core.IdentityServer;
using Bit.Core.Platform.Installations;
using Bit.Identity.IdentityServer.ClientProviders;
using IdentityModel;
using Duende.IdentityModel;
using NSubstitute;
using Xunit;

View File

@@ -1,7 +1,7 @@
using Bit.Core.IdentityServer;
using Bit.Core.Settings;
using Bit.Identity.IdentityServer.ClientProviders;
using IdentityModel;
using Duende.IdentityModel;
using Xunit;
namespace Bit.Identity.Test.IdentityServer.ClientProviders;

View File

@@ -11,9 +11,9 @@ using Bit.Identity.IdentityServer.Enums;
using Bit.Identity.IdentityServer.RequestValidators.SendAccess;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Duende.IdentityModel;
using Duende.IdentityServer.Extensions;
using Duende.IdentityServer.Validation;
using IdentityModel;
using NSubstitute;
using Xunit;
@@ -65,7 +65,7 @@ public class SendAccessGrantValidatorTests
// Assert
Assert.Equal(OidcConstants.TokenErrors.InvalidRequest, context.Result.Error);
Assert.Equal("send_id is required.", context.Result.ErrorDescription);
Assert.Equal($"{SendAccessConstants.TokenRequest.SendId} is required.", context.Result.ErrorDescription);
}
[Theory, BitAutoData]
@@ -84,7 +84,7 @@ public class SendAccessGrantValidatorTests
tokenRequest.Raw = CreateTokenRequestBody(Guid.Empty);
// To preserve the CreateTokenRequestBody method for more general usage we over write the sendId
tokenRequest.Raw.Set("send_id", "invalid-guid-format");
tokenRequest.Raw.Set(SendAccessConstants.TokenRequest.SendId, "invalid-guid-format");
context.Request = tokenRequest;
// Act
@@ -92,7 +92,7 @@ public class SendAccessGrantValidatorTests
// Assert
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.Result.Error);
Assert.Equal("send_id is invalid.", context.Result.ErrorDescription);
Assert.Equal($"{SendAccessConstants.TokenRequest.SendId} is invalid.", context.Result.ErrorDescription);
}
[Theory, BitAutoData]
@@ -111,7 +111,7 @@ public class SendAccessGrantValidatorTests
// Assert
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.Result.Error);
Assert.Equal("send_id is invalid.", context.Result.ErrorDescription);
Assert.Equal($"{SendAccessConstants.TokenRequest.SendId} is invalid.", context.Result.ErrorDescription);
}
[Theory, BitAutoData]
@@ -135,7 +135,7 @@ public class SendAccessGrantValidatorTests
// Assert
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.Result.Error);
Assert.Equal("send_id is invalid.", context.Result.ErrorDescription);
Assert.Equal($"{SendAccessConstants.TokenRequest.SendId} is invalid.", context.Result.ErrorDescription);
}
[Theory, BitAutoData]
@@ -297,37 +297,28 @@ public class SendAccessGrantValidatorTests
var rawRequestParameters = new NameValueCollection
{
{ "grant_type", CustomGrantTypes.SendAccess },
{ "client_id", BitwardenClient.Send },
{ "scope", ApiScopes.ApiSendAccess },
{ OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess },
{ OidcConstants.TokenRequest.ClientId, BitwardenClient.Send },
{ OidcConstants.TokenRequest.Scope, ApiScopes.ApiSendAccess },
{ "deviceType", ((int)DeviceType.FirefoxBrowser).ToString() },
{ "send_id", sendIdBase64 }
{ SendAccessConstants.TokenRequest.SendId, sendIdBase64 }
};
if (passwordHash != null)
{
rawRequestParameters.Add("password_hash", passwordHash);
rawRequestParameters.Add(SendAccessConstants.TokenRequest.ClientB64HashedPassword, passwordHash);
}
if (sendEmail != null)
{
rawRequestParameters.Add("send_email", sendEmail);
rawRequestParameters.Add(SendAccessConstants.TokenRequest.Email, sendEmail);
}
if (otpCode != null && sendEmail != null)
{
rawRequestParameters.Add("otp_code", otpCode);
rawRequestParameters.Add(SendAccessConstants.TokenRequest.Otp, otpCode);
}
return rawRequestParameters;
}
// we need a list of sendAuthentication methods to test against since we cannot create new objects in the BitAutoData
public static Dictionary<string, SendAuthenticationMethod> SendAuthenticationMethods => new()
{
{ "NeverAuthenticate", new NeverAuthenticate() }, // Send doesn't exist or is deleted
{ "NotAuthenticated", new NotAuthenticated() }, // Public send, no auth needed
// TODO: PM-22675 - {"ResourcePassword", new ResourcePassword("clientHashedPassword")}; // Password protected send
// TODO: PM-22678 - {"EmailOtp", new EmailOtp(["emailOtp@test.dev"]}; // Email + OTP protected send
};
}

View File

@@ -0,0 +1,297 @@
using System.Collections.Specialized;
using Bit.Core.Auth.UserFeatures.SendAccess;
using Bit.Core.Enums;
using Bit.Core.Identity;
using Bit.Core.IdentityServer;
using Bit.Core.KeyManagement.Sends;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Utilities;
using Bit.Identity.IdentityServer.Enums;
using Bit.Identity.IdentityServer.RequestValidators.SendAccess;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Duende.IdentityModel;
using Duende.IdentityServer.Validation;
using NSubstitute;
using Xunit;
namespace Bit.Identity.Test.IdentityServer;
[SutProviderCustomize]
public class SendPasswordRequestValidatorTests
{
[Theory, BitAutoData]
public void ValidateSendPassword_MissingPasswordHash_ReturnsInvalidRequest(
SutProvider<SendPasswordRequestValidator> sutProvider,
[AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
ResourcePassword resourcePassword,
Guid sendId)
{
// Arrange
tokenRequest.Raw = CreateValidatedTokenRequest(sendId);
var context = new ExtensionGrantValidationContext
{
Request = tokenRequest
};
// Act
var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId);
// Assert
Assert.True(result.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidRequest, result.Error);
Assert.Equal($"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is required.", result.ErrorDescription);
// Verify password hasher was not called
sutProvider.GetDependency<ISendPasswordHasher>()
.DidNotReceive()
.PasswordHashMatches(Arg.Any<string>(), Arg.Any<string>());
}
[Theory, BitAutoData]
public void ValidateSendPassword_PasswordHashMismatch_ReturnsInvalidGrant(
SutProvider<SendPasswordRequestValidator> sutProvider,
[AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
ResourcePassword resourcePassword,
Guid sendId,
string clientPasswordHash)
{
// Arrange
tokenRequest.Raw = CreateValidatedTokenRequest(sendId, clientPasswordHash);
var context = new ExtensionGrantValidationContext
{
Request = tokenRequest
};
sutProvider.GetDependency<ISendPasswordHasher>()
.PasswordHashMatches(resourcePassword.Hash, clientPasswordHash)
.Returns(false);
// Act
var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId);
// Assert
Assert.True(result.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error);
Assert.Equal($"{SendAccessConstants.TokenRequest.ClientB64HashedPassword} is invalid.", result.ErrorDescription);
// Verify password hasher was called with correct parameters
sutProvider.GetDependency<ISendPasswordHasher>()
.Received(1)
.PasswordHashMatches(resourcePassword.Hash, clientPasswordHash);
}
[Theory, BitAutoData]
public void ValidateSendPassword_PasswordHashMatches_ReturnsSuccess(
SutProvider<SendPasswordRequestValidator> sutProvider,
[AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
ResourcePassword resourcePassword,
Guid sendId,
string clientPasswordHash)
{
// Arrange
tokenRequest.Raw = CreateValidatedTokenRequest(sendId, clientPasswordHash);
var context = new ExtensionGrantValidationContext
{
Request = tokenRequest
};
sutProvider.GetDependency<ISendPasswordHasher>()
.PasswordHashMatches(resourcePassword.Hash, clientPasswordHash)
.Returns(true);
// Act
var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId);
// Assert
Assert.False(result.IsError);
var sub = result.Subject;
Assert.Equal(sendId, sub.GetSendId());
// Verify claims
Assert.Contains(sub.Claims, c => c.Type == Claims.SendId && c.Value == sendId.ToString());
Assert.Contains(sub.Claims, c => c.Type == Claims.Type && c.Value == IdentityClientType.Send.ToString());
// Verify password hasher was called
sutProvider.GetDependency<ISendPasswordHasher>()
.Received(1)
.PasswordHashMatches(resourcePassword.Hash, clientPasswordHash);
}
[Theory, BitAutoData]
public void ValidateSendPassword_EmptyPasswordHash_CallsPasswordHasher(
SutProvider<SendPasswordRequestValidator> sutProvider,
[AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
ResourcePassword resourcePassword,
Guid sendId)
{
// Arrange
tokenRequest.Raw = CreateValidatedTokenRequest(sendId, string.Empty);
var context = new ExtensionGrantValidationContext
{
Request = tokenRequest
};
sutProvider.GetDependency<ISendPasswordHasher>()
.PasswordHashMatches(resourcePassword.Hash, string.Empty)
.Returns(false);
// Act
var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId);
// Assert
Assert.True(result.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error);
// Verify password hasher was called with empty string
sutProvider.GetDependency<ISendPasswordHasher>()
.Received(1)
.PasswordHashMatches(resourcePassword.Hash, string.Empty);
}
[Theory, BitAutoData]
public void ValidateSendPassword_WhitespacePasswordHash_CallsPasswordHasher(
SutProvider<SendPasswordRequestValidator> sutProvider,
[AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
ResourcePassword resourcePassword,
Guid sendId)
{
// Arrange
var whitespacePassword = " ";
tokenRequest.Raw = CreateValidatedTokenRequest(sendId, whitespacePassword);
var context = new ExtensionGrantValidationContext
{
Request = tokenRequest
};
sutProvider.GetDependency<ISendPasswordHasher>()
.PasswordHashMatches(resourcePassword.Hash, whitespacePassword)
.Returns(false);
// Act
var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId);
// Assert
Assert.True(result.IsError);
// Verify password hasher was called with whitespace string
sutProvider.GetDependency<ISendPasswordHasher>()
.Received(1)
.PasswordHashMatches(resourcePassword.Hash, whitespacePassword);
}
[Theory, BitAutoData]
public void ValidateSendPassword_MultiplePasswordHashParameters_ReturnsInvalidGrant(
SutProvider<SendPasswordRequestValidator> sutProvider,
[AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
ResourcePassword resourcePassword,
Guid sendId)
{
// Arrange
var firstPassword = "first-password";
var secondPassword = "second-password";
tokenRequest.Raw = CreateValidatedTokenRequest(sendId, firstPassword, secondPassword);
var context = new ExtensionGrantValidationContext
{
Request = tokenRequest
};
sutProvider.GetDependency<ISendPasswordHasher>()
.PasswordHashMatches(resourcePassword.Hash, firstPassword)
.Returns(true);
// Act
var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId);
// Assert
Assert.True(result.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, result.Error);
// Verify password hasher was called with first value
sutProvider.GetDependency<ISendPasswordHasher>()
.Received(1)
.PasswordHashMatches(resourcePassword.Hash, $"{firstPassword},{secondPassword}");
}
[Theory, BitAutoData]
public void ValidateSendPassword_SuccessResult_ContainsCorrectClaims(
SutProvider<SendPasswordRequestValidator> sutProvider,
[AutoFixture.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
ResourcePassword resourcePassword,
Guid sendId,
string clientPasswordHash)
{
// Arrange
tokenRequest.Raw = CreateValidatedTokenRequest(sendId, clientPasswordHash);
var context = new ExtensionGrantValidationContext
{
Request = tokenRequest
};
sutProvider.GetDependency<ISendPasswordHasher>()
.PasswordHashMatches(Arg.Any<string>(), Arg.Any<string>())
.Returns(true);
// Act
var result = sutProvider.Sut.ValidateSendPassword(context, resourcePassword, sendId);
// Assert
Assert.False(result.IsError);
var sub = result.Subject;
var sendIdClaim = sub.Claims.FirstOrDefault(c => c.Type == Claims.SendId);
Assert.NotNull(sendIdClaim);
Assert.Equal(sendId.ToString(), sendIdClaim.Value);
var typeClaim = sub.Claims.FirstOrDefault(c => c.Type == Claims.Type);
Assert.NotNull(typeClaim);
Assert.Equal(IdentityClientType.Send.ToString(), typeClaim.Value);
}
[Fact]
public void Constructor_WithValidParameters_CreatesInstance()
{
// Arrange
var sendPasswordHasher = Substitute.For<ISendPasswordHasher>();
// Act
var validator = new SendPasswordRequestValidator(sendPasswordHasher);
// Assert
Assert.NotNull(validator);
}
private static NameValueCollection CreateValidatedTokenRequest(
Guid sendId,
params string[] passwordHash)
{
var sendIdBase64 = CoreHelpers.Base64UrlEncode(sendId.ToByteArray());
var rawRequestParameters = new NameValueCollection
{
{ OidcConstants.TokenRequest.GrantType, CustomGrantTypes.SendAccess },
{ OidcConstants.TokenRequest.ClientId, BitwardenClient.Send },
{ OidcConstants.TokenRequest.Scope, ApiScopes.ApiSendAccess },
{ "device_type", ((int)DeviceType.FirefoxBrowser).ToString() },
{ SendAccessConstants.TokenRequest.SendId, sendIdBase64 }
};
if (passwordHash != null && passwordHash.Length > 0)
{
foreach (var hash in passwordHash)
{
rawRequestParameters.Add(SendAccessConstants.TokenRequest.ClientB64HashedPassword, hash);
}
}
return rawRequestParameters;
}
}

View File

@@ -57,8 +57,8 @@ public class OrganizationRepositoryTests
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULl
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
});

View File

@@ -28,8 +28,8 @@ public class OrganizationUserRepositoryTests
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user.Email, // TODO: EF does not enfore this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
BillingEmail = user.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
});
var orgUser = await organizationUserRepository.CreateAsync(new OrganizationUser
@@ -37,6 +37,7 @@ public class OrganizationUserRepositoryTests
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = user.Email
});
await organizationUserRepository.DeleteAsync(orgUser);
@@ -46,6 +47,171 @@ public class OrganizationUserRepositoryTests
Assert.NotEqual(newUser.AccountRevisionDate, user.AccountRevisionDate);
}
[DatabaseTheory, DatabaseData]
public async Task DeleteManyAsync_Migrates_UserDefaultCollection(IUserRepository userRepository,
ICollectionRepository collectionRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository
)
{
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
});
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user1.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = user1.Email
});
var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user2.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = user2.Email
});
var defaultUserCollection1 = await collectionRepository.CreateAsync(new Collection
{
Name = "Test Collection 1",
Id = user1.Id,
Type = CollectionType.DefaultUserCollection,
OrganizationId = organization.Id
});
var defaultUserCollection2 = await collectionRepository.CreateAsync(new Collection
{
Name = "Test Collection 2",
Id = user2.Id,
Type = CollectionType.DefaultUserCollection,
OrganizationId = organization.Id
});
// Create the CollectionUser entry for the defaultUserCollection
await collectionRepository.UpdateUsersAsync(defaultUserCollection1.Id, new List<CollectionAccessSelection>()
{
new CollectionAccessSelection
{
Id = orgUser1.Id,
HidePasswords = false,
ReadOnly = false,
Manage = true
},
});
await collectionRepository.UpdateUsersAsync(defaultUserCollection2.Id, new List<CollectionAccessSelection>()
{
new CollectionAccessSelection
{
Id = orgUser2.Id,
HidePasswords = false,
ReadOnly = false,
Manage = true
},
});
await organizationUserRepository.DeleteManyAsync(new List<Guid> { orgUser1.Id, orgUser2.Id });
var newUser = await userRepository.GetByIdAsync(user1.Id);
Assert.NotNull(newUser);
Assert.NotEqual(newUser.AccountRevisionDate, user1.AccountRevisionDate);
var updatedCollection1 = await collectionRepository.GetByIdAsync(defaultUserCollection1.Id);
Assert.NotNull(updatedCollection1);
Assert.Equal(CollectionType.SharedCollection, updatedCollection1.Type);
Assert.Equal(user1.Email, updatedCollection1.DefaultUserCollectionEmail);
var updatedCollection2 = await collectionRepository.GetByIdAsync(defaultUserCollection2.Id);
Assert.NotNull(updatedCollection2);
Assert.Equal(CollectionType.SharedCollection, updatedCollection2.Type);
Assert.Equal(user2.Email, updatedCollection2.DefaultUserCollectionEmail);
}
[DatabaseTheory, DatabaseData]
public async Task DeleteAsync_Migrates_UserDefaultCollection(IUserRepository userRepository,
ICollectionRepository collectionRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository
)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
});
var orgUser = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = user.Email
});
var defaultUserCollection = await collectionRepository.CreateAsync(new Collection
{
Name = "Test Collection",
Id = user.Id,
Type = CollectionType.DefaultUserCollection,
OrganizationId = organization.Id
});
// Create the CollectionUser entry for the defaultUserCollection
await collectionRepository.UpdateUsersAsync(defaultUserCollection.Id, new List<CollectionAccessSelection>()
{
new CollectionAccessSelection
{
Id = orgUser.Id,
HidePasswords = false,
ReadOnly = false,
Manage = true
},
});
await organizationUserRepository.DeleteAsync(orgUser);
var newUser = await userRepository.GetByIdAsync(user.Id);
Assert.NotNull(newUser);
Assert.NotEqual(newUser.AccountRevisionDate, user.AccountRevisionDate);
var updatedCollection = await collectionRepository.GetByIdAsync(defaultUserCollection.Id);
Assert.NotNull(updatedCollection);
Assert.Equal(CollectionType.SharedCollection, updatedCollection.Type);
Assert.Equal(user.Email, updatedCollection.DefaultUserCollectionEmail);
}
[DatabaseTheory, DatabaseData]
public async Task DeleteManyAsync_Works(IUserRepository userRepository,
IOrganizationRepository organizationRepository,
@@ -70,8 +236,8 @@ public class OrganizationUserRepositoryTests
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULl
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
});
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
@@ -79,6 +245,7 @@ public class OrganizationUserRepositoryTests
OrganizationId = organization.Id,
UserId = user1.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = user1.Email
});
var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser
@@ -86,6 +253,7 @@ public class OrganizationUserRepositoryTests
OrganizationId = organization.Id,
UserId = user2.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = user2.Email
});
await organizationUserRepository.DeleteManyAsync(new List<Guid>
@@ -135,8 +303,8 @@ public class OrganizationUserRepositoryTests
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULl
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
});
@@ -291,8 +459,8 @@ public class OrganizationUserRepositoryTests
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULl
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
});
@@ -354,6 +522,134 @@ public class OrganizationUserRepositoryTests
Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User 1",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@x-{domainName}", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var user3 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{id}@{domainName}.example.com", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 1,
KdfMemory = 2,
KdfParallelism = 3
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
UsePolicies = false,
UseSso = false,
UseKeyConnector = false,
UseScim = false,
UseGroups = false,
UseDirectory = false,
UseEvents = false,
UseTotp = false,
Use2fa = false,
UseApi = false,
UseResetPassword = false,
UseSecretsManager = false,
SelfHost = false,
UsersGetPremium = false,
UseCustomPermissions = false,
Enabled = true,
UsePasswordManager = false,
LimitCollectionCreation = false,
LimitCollectionDeletion = false,
LimitItemDeletion = false,
AllowAdminAccessToAllCollectionItems = false,
UseRiskInsights = false,
UseAdminSponsoredFamilies = false
});
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
};
organizationDomain.SetVerifiedDate();
organizationDomain.SetNextRunDate(12);
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user1.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
ResetPasswordKey = "resetpasswordkey1",
AccessSecretsManager = false
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user2.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
ResetPasswordKey = "resetpasswordkey1",
AccessSecretsManager = false
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user3.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
ResetPasswordKey = "resetpasswordkey1",
AccessSecretsManager = false
});
var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id);
Assert.NotNull(responseModel);
Assert.Single(responseModel);
Assert.Equal(orgUser1.Id, responseModel.Single().Id);
}
[DatabaseTheory, DatabaseData]
public async Task CreateManyAsync_NoId_Works(IOrganizationRepository organizationRepository,
IUserRepository userRepository,
@@ -369,7 +665,7 @@ public class OrganizationUserRepositoryTests
{
Name = $"test-{Guid.NewGuid()}",
BillingEmail = "billing@example.com", // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
});
var orgUsers = users.Select(u => new OrganizationUser
@@ -403,7 +699,7 @@ public class OrganizationUserRepositoryTests
{
Name = $"test-{Guid.NewGuid()}",
BillingEmail = "billing@example.com", // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
});
var orgUsers = users.Select(u => new OrganizationUser
@@ -435,8 +731,8 @@ public class OrganizationUserRepositoryTests
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@test.com", // TODO: EF does not enfore this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULl,
BillingEmail = "billing@test.com", // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL,
CreationDate = requestTime
});
@@ -862,119 +1158,6 @@ public class OrganizationUserRepositoryTests
Assert.DoesNotContain(user1Result.Collections, c => c.Id == defaultUserCollection.Id);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
var id = Guid.NewGuid();
var domainName = $"{id}.example.com";
var requestTime = DateTime.UtcNow;
var user1 = await userRepository.CreateAsync(new User
{
Id = CoreHelpers.GenerateComb(),
Name = "Test User 1",
Email = $"test+{id}@{domainName}",
ApiKey = "TEST",
SecurityStamp = "stamp",
CreationDate = requestTime,
RevisionDate = requestTime,
AccountRevisionDate = requestTime
});
var user2 = await userRepository.CreateAsync(new User
{
Id = CoreHelpers.GenerateComb(),
Name = "Test User 2",
Email = $"test+{id}@x-{domainName}", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
CreationDate = requestTime,
RevisionDate = requestTime,
AccountRevisionDate = requestTime
});
var user3 = await userRepository.CreateAsync(new User
{
Id = CoreHelpers.GenerateComb(),
Name = "Test User 3",
Email = $"test+{id}@{domainName}.example.com", // Different domain
ApiKey = "TEST",
SecurityStamp = "stamp",
CreationDate = requestTime,
RevisionDate = requestTime,
AccountRevisionDate = requestTime
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Id = CoreHelpers.GenerateComb(),
Name = $"Test Org {id}",
BillingEmail = user1.Email,
Plan = "Test",
Enabled = true,
CreationDate = requestTime,
RevisionDate = requestTime
});
var organizationDomain = new OrganizationDomain
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
DomainName = domainName,
Txt = "btw+12345",
CreationDate = requestTime
};
organizationDomain.SetNextRunDate(12);
organizationDomain.SetVerifiedDate();
organizationDomain.SetJobRunCount();
await organizationDomainRepository.CreateAsync(organizationDomain);
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user1.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
CreationDate = requestTime,
RevisionDate = requestTime
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user2.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
CreationDate = requestTime,
RevisionDate = requestTime
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
Id = CoreHelpers.GenerateComb(),
OrganizationId = organization.Id,
UserId = user3.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
CreationDate = requestTime,
RevisionDate = requestTime
});
var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id);
Assert.NotNull(responseModel);
Assert.Single(responseModel);
Assert.Equal(orgUser1.Id, responseModel.Single().Id);
Assert.Equal(user1.Id, responseModel.Single().UserId);
Assert.Equal(organization.Id, responseModel.Single().OrganizationId);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithNoVerifiedDomain_ReturnsEmpty(
IUserRepository userRepository,
@@ -1039,6 +1222,120 @@ public class OrganizationUserRepositoryTests
Assert.Empty(responseModel);
}
[DatabaseTheory, DatabaseData]
public async Task DeleteAsync_WithNullEmail_DoesNotSetDefaultUserCollectionEmail(IUserRepository userRepository,
ICollectionRepository collectionRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository
)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user.Email,
Plan = "Test",
});
var orgUser = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = null
});
var defaultUserCollection = await collectionRepository.CreateAsync(new Collection
{
Name = "Test Collection",
Id = user.Id,
Type = CollectionType.DefaultUserCollection,
OrganizationId = organization.Id
});
await collectionRepository.UpdateUsersAsync(defaultUserCollection.Id, new List<CollectionAccessSelection>()
{
new CollectionAccessSelection
{
Id = orgUser.Id,
HidePasswords = false,
ReadOnly = false,
Manage = true
},
});
await organizationUserRepository.DeleteAsync(orgUser);
var updatedCollection = await collectionRepository.GetByIdAsync(defaultUserCollection.Id);
Assert.NotNull(updatedCollection);
Assert.Equal(CollectionType.SharedCollection, updatedCollection.Type);
Assert.Equal(user.Email, updatedCollection.DefaultUserCollectionEmail);
}
[DatabaseTheory, DatabaseData]
public async Task DeleteAsync_WithEmptyEmail_DoesNotSetDefaultUserCollectionEmail(IUserRepository userRepository,
ICollectionRepository collectionRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository
)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user.Email,
Plan = "Test",
});
var orgUser = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Email = "" // Empty string email
});
var defaultUserCollection = await collectionRepository.CreateAsync(new Collection
{
Name = "Test Collection",
Id = user.Id,
Type = CollectionType.DefaultUserCollection,
OrganizationId = organization.Id
});
await collectionRepository.UpdateUsersAsync(defaultUserCollection.Id, new List<CollectionAccessSelection>()
{
new CollectionAccessSelection
{
Id = orgUser.Id,
HidePasswords = false,
ReadOnly = false,
Manage = true
},
});
await organizationUserRepository.DeleteAsync(orgUser);
var updatedCollection = await collectionRepository.GetByIdAsync(defaultUserCollection.Id);
Assert.NotNull(updatedCollection);
Assert.Equal(CollectionType.SharedCollection, updatedCollection.Type);
Assert.Equal(user.Email, updatedCollection.DefaultUserCollectionEmail);
}
[DatabaseTheory, DatabaseData]
public async Task ReplaceAsync_PreservesDefaultCollections_WhenUpdatingCollectionAccess(
IUserRepository userRepository,

View File

@@ -40,6 +40,10 @@ public class GetPolicyDetailsByOrganizationIdAsyncTests
Assert.True(results.Single().IsProvider);
// Annul
await organizationRepository.DeleteAsync(new Organization { Id = userOrgConnectedDirectly.OrganizationId });
await userRepository.DeleteAsync(user);
async Task ArrangeProvider()
{
var provider = await providerRepository.CreateAsync(new Provider
@@ -86,6 +90,11 @@ public class GetPolicyDetailsByOrganizationIdAsyncTests
Assert.Contains(results, result => result.OrganizationUserId == userOrgConnectedDirectly.Id
&& result.OrganizationId == userOrgConnectedDirectly.OrganizationId);
Assert.DoesNotContain(results, result => result.OrganizationId == notConnectedOrg.Id);
// Annul
await organizationRepository.DeleteAsync(new Organization { Id = userOrgConnectedDirectly.OrganizationId });
await organizationRepository.DeleteAsync(notConnectedOrg);
await userRepository.DeleteAsync(user);
}
[DatabaseTheory, DatabaseData]
@@ -115,6 +124,10 @@ public class GetPolicyDetailsByOrganizationIdAsyncTests
&& result.PolicyType == inputPolicyType);
Assert.DoesNotContain(results, result => result.PolicyType == notInputPolicyType);
// Annul
await organizationRepository.DeleteAsync(new Organization { Id = orgUser.OrganizationId });
await userRepository.DeleteAsync(user);
}
@@ -143,6 +156,12 @@ public class GetPolicyDetailsByOrganizationIdAsyncTests
Assert.Equal(expectedCount, results.Count);
AssertPolicyDetailUserConnections(results, userOrgConnectedDirectly, userOrgConnectedByEmail, userOrgConnectedByUserId);
// Annul
await organizationRepository.DeleteAsync(new Organization() { Id = userOrgConnectedDirectly.OrganizationId });
await organizationRepository.DeleteAsync(new Organization() { Id = userOrgConnectedByEmail.OrganizationId });
await organizationRepository.DeleteAsync(new Organization() { Id = userOrgConnectedByUserId.OrganizationId });
await userRepository.DeleteAsync(user);
}
[DatabaseTheory, DatabaseData]
@@ -167,8 +186,52 @@ public class GetPolicyDetailsByOrganizationIdAsyncTests
// Assert
AssertPolicyDetailUserConnections(results, userOrgConnectedDirectly, userOrgConnectedByEmail, userOrgConnectedByUserId);
// Annul
await organizationRepository.DeleteAsync(new Organization() { Id = userOrgConnectedDirectly.OrganizationId });
await organizationRepository.DeleteAsync(new Organization() { Id = userOrgConnectedByEmail.OrganizationId });
await organizationRepository.DeleteAsync(new Organization() { Id = userOrgConnectedByUserId.OrganizationId });
await userRepository.DeleteAsync(user);
}
[DatabaseTheory, DatabaseData]
public async Task ShouldReturnUserIds(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user1 = await userRepository.CreateTestUserAsync();
var user2 = await userRepository.CreateTestUserAsync();
const PolicyType policyType = PolicyType.SingleOrg;
var organization = await CreateEnterpriseOrg(organizationRepository);
await policyRepository.CreateAsync(new Policy { OrganizationId = organization.Id, Enabled = true, Type = policyType });
var orgUser1 = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user1);
var orgUser2 = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user2);
// Act
var results = (await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organization.Id, policyType)).ToList();
// Assert
Assert.Equal(2, results.Count);
Assert.Contains(results, result => result.OrganizationUserId == orgUser1.Id
&& result.UserId == orgUser1.UserId
&& result.OrganizationId == orgUser1.OrganizationId);
Assert.Contains(results, result => result.OrganizationUserId == orgUser2.Id
&& result.UserId == orgUser2.UserId
&& result.OrganizationId == orgUser2.OrganizationId);
// Annul
await organizationRepository.DeleteAsync(organization);
await userRepository.DeleteManyAsync([user1, user2]);
}
private async Task<OrganizationUser> ArrangeOtherOrgConnectedByUserIdAsync(IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository, IPolicyRepository policyRepository, User user,
PolicyType policyType)

View File

@@ -55,7 +55,7 @@ public class UserRepositoryTests
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user3.Email, // TODO: EF does not enfore this being NOT NULL
BillingEmail = user3.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
});

View File

@@ -10,129 +10,29 @@ using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Time.Testing;
using Xunit;
using Xunit.Sdk;
using Xunit.v3;
namespace Bit.Infrastructure.IntegrationTest;
public class DatabaseDataAttribute : DataAttribute
{
private static IConfiguration? _cachedConfiguration;
private static IConfiguration GetConfiguration()
{
return _cachedConfiguration ??= new ConfigurationBuilder()
.AddUserSecrets<DatabaseDataAttribute>(optional: true, reloadOnChange: false)
.AddEnvironmentVariables("BW_TEST_")
.AddCommandLine(Environment.GetCommandLineArgs())
.Build();
}
public bool SelfHosted { get; set; }
public bool UseFakeTimeProvider { get; set; }
public string? MigrationName { get; set; }
public override IEnumerable<object[]> GetData(MethodInfo testMethod)
{
var parameters = testMethod.GetParameters();
var config = DatabaseTheoryAttribute.GetConfiguration();
var serviceProviders = GetDatabaseProviders(config);
foreach (var provider in serviceProviders)
{
var objects = new object[parameters.Length];
for (var i = 0; i < parameters.Length; i++)
{
objects[i] = provider.GetRequiredService(parameters[i].ParameterType);
}
yield return objects;
}
}
protected virtual IEnumerable<IServiceProvider> GetDatabaseProviders(IConfiguration config)
{
// This is for the device repository integration testing.
var userRequestExpiration = 15;
var configureLogging = (ILoggingBuilder builder) =>
{
if (!config.GetValue<bool>("Quiet"))
{
builder.AddConfiguration(config);
builder.AddConsole();
builder.AddDebug();
}
};
var databases = config.GetDatabases();
foreach (var database in databases)
{
if (database.Type == SupportedDatabaseProviders.SqlServer && !database.UseEf)
{
var dapperSqlServerCollection = new ServiceCollection();
AddCommonServices(dapperSqlServerCollection, configureLogging);
dapperSqlServerCollection.AddDapperRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
DatabaseProvider = "sqlServer",
SqlServer = new GlobalSettings.SqlSettings
{
ConnectionString = database.ConnectionString,
},
PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings
{
UserRequestExpiration = TimeSpan.FromMinutes(userRequestExpiration),
}
};
dapperSqlServerCollection.AddSingleton(globalSettings);
dapperSqlServerCollection.AddSingleton<IGlobalSettings>(globalSettings);
dapperSqlServerCollection.AddSingleton(database);
dapperSqlServerCollection.AddDistributedSqlServerCache(o =>
{
o.ConnectionString = database.ConnectionString;
o.SchemaName = "dbo";
o.TableName = "Cache";
});
if (!string.IsNullOrEmpty(MigrationName))
{
AddSqlMigrationTester(dapperSqlServerCollection, database.ConnectionString, MigrationName);
}
yield return dapperSqlServerCollection.BuildServiceProvider();
}
else
{
var efCollection = new ServiceCollection();
AddCommonServices(efCollection, configureLogging);
efCollection.SetupEntityFramework(database.ConnectionString, database.Type);
efCollection.AddPasswordManagerEFRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings
{
UserRequestExpiration = TimeSpan.FromMinutes(userRequestExpiration),
}
};
efCollection.AddSingleton(globalSettings);
efCollection.AddSingleton<IGlobalSettings>(globalSettings);
efCollection.AddSingleton(database);
efCollection.AddSingleton<IDistributedCache, EntityFrameworkCache>();
if (!string.IsNullOrEmpty(MigrationName))
{
AddEfMigrationTester(efCollection, database.Type, MigrationName);
}
yield return efCollection.BuildServiceProvider();
}
}
}
private void AddCommonServices(IServiceCollection services, Action<ILoggingBuilder> configureLogging)
{
services.AddLogging(configureLogging);
services.AddDataProtection();
if (UseFakeTimeProvider)
{
services.AddSingleton<TimeProvider, FakeTimeProvider>();
}
}
private void AddSqlMigrationTester(IServiceCollection services, string connectionString, string migrationName)
{
services.AddSingleton<IMigrationTesterService, SqlMigrationTesterService>(_ => new SqlMigrationTesterService(connectionString, migrationName));
@@ -146,4 +46,171 @@ public class DatabaseDataAttribute : DataAttribute
return new EfMigrationTesterService(dbContext, databaseType, migrationName);
});
}
public override ValueTask<IReadOnlyCollection<ITheoryDataRow>> GetData(MethodInfo testMethod, DisposalTracker disposalTracker)
{
var config = GetConfiguration();
HashSet<SupportedDatabaseProviders> unconfiguredDatabases =
[
SupportedDatabaseProviders.MySql,
SupportedDatabaseProviders.Postgres,
SupportedDatabaseProviders.Sqlite,
SupportedDatabaseProviders.SqlServer
];
var theories = new List<ITheoryDataRow>();
foreach (var database in config.GetDatabases())
{
unconfiguredDatabases.Remove(database.Type);
if (!database.Enabled)
{
var theory = new TheoryDataRow()
.WithSkip("Not-Enabled")
.WithTrait("Database", database.Type.ToString());
theory.Label = database.Type.ToString();
theories.Add(theory);
continue;
}
var services = new ServiceCollection();
AddCommonServices(services);
if (database.Type == SupportedDatabaseProviders.SqlServer && !database.UseEf)
{
// Dapper services
AddDapperServices(services, database);
}
else
{
// Ef services
AddEfServices(services, database);
}
var serviceProvider = services.BuildServiceProvider();
disposalTracker.Add(serviceProvider);
var serviceTheory = new ServiceBasedTheoryDataRow(serviceProvider, testMethod)
.WithTrait("Database", database.Type.ToString())
.WithTrait("ConnectionString", database.ConnectionString);
serviceTheory.Label = database.Type.ToString();
theories.Add(serviceTheory);
}
foreach (var unconfiguredDatabase in unconfiguredDatabases)
{
var theory = new TheoryDataRow()
.WithSkip("Unconfigured")
.WithTrait("Database", unconfiguredDatabase.ToString());
theory.Label = unconfiguredDatabase.ToString();
theories.Add(theory);
}
return new(theories);
}
private void AddCommonServices(IServiceCollection services)
{
// Common services
services.AddDataProtection();
services.AddLogging(logging =>
{
logging.AddProvider(new XUnitLoggerProvider());
});
if (UseFakeTimeProvider)
{
services.AddSingleton<TimeProvider, FakeTimeProvider>();
}
}
private void AddDapperServices(IServiceCollection services, Database database)
{
services.AddDapperRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
DatabaseProvider = "sqlServer",
SqlServer = new GlobalSettings.SqlSettings
{
ConnectionString = database.ConnectionString,
},
PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings
{
UserRequestExpiration = TimeSpan.FromMinutes(15),
}
};
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);
services.AddSingleton(database);
services.AddDistributedSqlServerCache(o =>
{
o.ConnectionString = database.ConnectionString;
o.SchemaName = "dbo";
o.TableName = "Cache";
});
if (!string.IsNullOrEmpty(MigrationName))
{
AddSqlMigrationTester(services, database.ConnectionString, MigrationName);
}
}
private void AddEfServices(IServiceCollection services, Database database)
{
services.SetupEntityFramework(database.ConnectionString, database.Type);
services.AddPasswordManagerEFRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings
{
UserRequestExpiration = TimeSpan.FromMinutes(15),
},
};
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);
services.AddSingleton(database);
services.AddSingleton<IDistributedCache, EntityFrameworkCache>();
if (!string.IsNullOrEmpty(MigrationName))
{
AddEfMigrationTester(services, database.Type, MigrationName);
}
}
public override bool SupportsDiscoveryEnumeration()
{
return true;
}
private class ServiceBasedTheoryDataRow : TheoryDataRowBase
{
private readonly IServiceProvider _serviceProvider;
private readonly MethodInfo _testMethod;
public ServiceBasedTheoryDataRow(IServiceProvider serviceProvider, MethodInfo testMethod)
{
_serviceProvider = serviceProvider;
_testMethod = testMethod;
}
protected override object?[] GetData()
{
var parameters = _testMethod.GetParameters();
var services = new object?[parameters.Length];
for (var i = 0; i < parameters.Length; i++)
{
var parameter = parameters[i];
// TODO: Could support keyed services/optional/nullable
services[i] = _serviceProvider.GetRequiredService(parameter.ParameterType);
}
return services;
}
}
}

View File

@@ -1,32 +1,17 @@
using Microsoft.Extensions.Configuration;
using System.Runtime.CompilerServices;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest;
[Obsolete("This attribute is no longer needed and can be replaced with a [Theory]")]
public class DatabaseTheoryAttribute : TheoryAttribute
{
private static IConfiguration? _cachedConfiguration;
public DatabaseTheoryAttribute()
{
if (!HasAnyDatabaseSetup())
{
Skip = "No databases setup.";
}
}
private static bool HasAnyDatabaseSetup()
public DatabaseTheoryAttribute([CallerFilePath] string? sourceFilePath = null, [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber)
{
var config = GetConfiguration();
return config.GetDatabases().Length > 0;
}
public static IConfiguration GetConfiguration()
{
return _cachedConfiguration ??= new ConfigurationBuilder()
.AddUserSecrets<DatabaseDataAttribute>(optional: true, reloadOnChange: false)
.AddEnvironmentVariables("BW_TEST_")
.AddCommandLine(Environment.GetCommandLineArgs())
.Build();
}
}

View File

@@ -65,7 +65,7 @@ public class DistributedCacheTests
[DatabaseTheory, DatabaseData]
public async Task MultipleWritesOnSameKey_ShouldNotThrow(IDistributedCache cache)
{
await cache.SetAsync("test-duplicate", "some-value"u8.ToArray());
await cache.SetAsync("test-duplicate", "some-value"u8.ToArray());
await cache.SetAsync("test-duplicate", "some-value"u8.ToArray(), TestContext.Current.CancellationToken);
await cache.SetAsync("test-duplicate", "some-value"u8.ToArray(), TestContext.Current.CancellationToken);
}
}

View File

@@ -12,8 +12,8 @@
<PackageReference Include="Microsoft.Extensions.Logging" Version="8.0.1" />
<PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" Version="8.10.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitRunnerVisualStudioVersion)">
<PackageReference Include="xunit.v3" Version="3.0.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="3.1.4">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>

View File

@@ -0,0 +1,47 @@
using Microsoft.Extensions.Logging;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest;
public sealed class XUnitLoggerProvider : ILoggerProvider
{
public ILogger CreateLogger(string categoryName)
{
return new XUnitLogger(categoryName);
}
public void Dispose()
{
}
private class XUnitLogger : ILogger
{
private readonly string _categoryName;
public XUnitLogger(string categoryName)
{
_categoryName = categoryName;
}
public IDisposable? BeginScope<TState>(TState state) where TState : notnull
{
return null;
}
public bool IsEnabled(LogLevel logLevel)
{
return true;
}
public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
{
if (TestContext.Current?.TestOutputHelper is not ITestOutputHelper testOutputHelper)
{
return;
}
testOutputHelper.WriteLine($"[{_categoryName}] {formatter(state, exception)}");
}
}
}

View File

@@ -0,0 +1,60 @@
using Bit.Core.Utilities;
using Bit.SharedWeb.Swagger;
using Microsoft.OpenApi.Models;
using Swashbuckle.AspNetCore.SwaggerGen;
namespace SharedWeb.Test;
public class EncryptedStringSchemaFilterTest
{
private class TestClass
{
[EncryptedString]
public string SecretKey { get; set; }
public string Username { get; set; }
[EncryptedString]
public int Wrong { get; set; }
}
[Fact]
public void AnnotatedStringSetsFormat()
{
var schema = new OpenApiSchema
{
Properties = new Dictionary<string, OpenApiSchema> { { "secretKey", new() } }
};
var context = new SchemaFilterContext(typeof(TestClass), null, null, null);
var filter = new EncryptedStringSchemaFilter();
filter.Apply(schema, context);
Assert.Equal("x-enc-string", schema.Properties["secretKey"].Format);
}
[Fact]
public void NonAnnotatedStringIsIgnored()
{
var schema = new OpenApiSchema
{
Properties = new Dictionary<string, OpenApiSchema> { { "username", new() } }
};
var context = new SchemaFilterContext(typeof(TestClass), null, null, null);
var filter = new EncryptedStringSchemaFilter();
filter.Apply(schema, context);
Assert.Null(schema.Properties["username"].Format);
}
[Fact]
public void AnnotatedWrongTypeIsIgnored()
{
var schema = new OpenApiSchema
{
Properties = new Dictionary<string, OpenApiSchema> { { "wrong", new() } }
};
var context = new SchemaFilterContext(typeof(TestClass), null, null, null);
var filter = new EncryptedStringSchemaFilter();
filter.Apply(schema, context);
Assert.Null(schema.Properties["wrong"].Format);
}
}

View File

@@ -0,0 +1,41 @@
using Bit.SharedWeb.Swagger;
using Microsoft.OpenApi.Any;
using Microsoft.OpenApi.Models;
using Swashbuckle.AspNetCore.SwaggerGen;
namespace SharedWeb.Test;
public class EnumSchemaFilterTest
{
private enum TestEnum
{
First,
Second,
Third
}
[Fact]
public void SetsEnumVarNamesExtension()
{
var schema = new OpenApiSchema();
var context = new SchemaFilterContext(typeof(TestEnum), null, null, null);
var filter = new EnumSchemaFilter();
filter.Apply(schema, context);
Assert.True(schema.Extensions.ContainsKey("x-enum-varnames"));
var extensions = schema.Extensions["x-enum-varnames"] as OpenApiArray;
Assert.NotNull(extensions);
Assert.Equal(["First", "Second", "Third"], extensions.Select(x => ((OpenApiString)x).Value));
}
[Fact]
public void DoesNotSetExtensionForNonEnum()
{
var schema = new OpenApiSchema();
var context = new SchemaFilterContext(typeof(string), null, null, null);
var filter = new EnumSchemaFilter();
filter.Apply(schema, context);
Assert.False(schema.Extensions.ContainsKey("x-enum-varnames"));
}
}

View File

@@ -0,0 +1,23 @@
using Bit.SharedWeb.Swagger;
using Microsoft.OpenApi.Models;
using Swashbuckle.AspNetCore.SwaggerGen;
namespace SharedWeb.Test;
public class GitCommitDocumentFilterTest
{
[Fact]
public void AddsGitCommitExtensionIfAvailable()
{
var doc = new OpenApiDocument();
var context = new DocumentFilterContext(null, null, null);
var filter = new GitCommitDocumentFilter();
filter.Apply(doc, context);
Assert.True(doc.Extensions.ContainsKey("x-git-commit"));
var ext = doc.Extensions["x-git-commit"] as Microsoft.OpenApi.Any.OpenApiString;
Assert.NotNull(ext);
Assert.False(string.IsNullOrEmpty(ext.Value));
}
}

View File

@@ -0,0 +1 @@
global using Xunit;

View File

@@ -0,0 +1,22 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<IsPackable>false</IsPackable>
<RootNamespace>SharedWeb.Test</RootNamespace>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio"
Version="$(XUnitRunnerVisualStudioVersion)">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
</PackageReference>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\SharedWeb\SharedWeb.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,33 @@
using Bit.SharedWeb.Swagger;
using Microsoft.OpenApi.Models;
using Swashbuckle.AspNetCore.SwaggerGen;
namespace SharedWeb.Test;
public class SourceFileLineOperationFilterTest
{
private class DummyController
{
public void DummyMethod() { }
}
[Fact]
public void AddsSourceFileAndLineExtensionsIfAvailable()
{
var methodInfo = typeof(DummyController).GetMethod(nameof(DummyController.DummyMethod));
var operation = new OpenApiOperation();
var context = new OperationFilterContext(null, null, null, methodInfo);
var filter = new SourceFileLineOperationFilter();
filter.Apply(operation, context);
Assert.True(operation.Extensions.ContainsKey("x-source-file"));
Assert.True(operation.Extensions.ContainsKey("x-source-line"));
var fileExt = operation.Extensions["x-source-file"] as Microsoft.OpenApi.Any.OpenApiString;
var lineExt = operation.Extensions["x-source-line"] as Microsoft.OpenApi.Any.OpenApiInteger;
Assert.NotNull(fileExt);
Assert.NotNull(lineExt);
Assert.Equal(11, lineExt.Value);
Assert.StartsWith("test/SharedWeb.Test/", fileExt.Value);
}
}