diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 3365e754ca..bf49f144ce 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -340,7 +340,7 @@ public class OrganizationUsersController : Controller [FromBody] OrganizationUserBulkConfirmRequestModel model) { var userId = _userService.GetProperUserId(User); - var results = await _confirmOrganizationUserCommand.ConfirmUsersAsync(orgId, model.ToDictionary(), userId.Value); + var results = await _confirmOrganizationUserCommand.ConfirmUsersAsync(orgId, model.ToDictionary(), userId.Value, model.DefaultUserCollectionName); return new ListResponseModel(results.Select(r => new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs index b4d3326013..4e0accb9e8 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs @@ -82,6 +82,10 @@ public class OrganizationUserBulkConfirmRequestModel [Required] public IEnumerable Keys { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string DefaultUserCollectionName { get; set; } + public Dictionary ToDictionary() { return Keys.ToDictionary(e => e.Id, e => e.Key); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 6ec69312ad..0baa9c9e3a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -11,7 +11,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -67,7 +66,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand public async Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, Guid confirmingUserId, string defaultUserCollectionName = null) { - var result = await ConfirmUsersAsync( + var result = await SaveChangesToDatabaseAsync( organizationId, new Dictionary() { { organizationUserId, key } }, confirmingUserId); @@ -83,12 +82,30 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand throw new BadRequestException(error); } - await HandleConfirmationSideEffectsAsync(organizationId, orgUser, defaultUserCollectionName); + await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers: [orgUser], defaultUserCollectionName); return orgUser; } public async Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, + Guid confirmingUserId, string defaultUserCollectionName = null) + { + var result = await SaveChangesToDatabaseAsync(organizationId, keys, confirmingUserId); + + var confirmedOrganizationUsers = result + .Where(r => string.IsNullOrEmpty(r.Item2)) + .Select(r => r.Item1) + .ToList(); + + if (confirmedOrganizationUsers.Count > 0) + { + await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers, defaultUserCollectionName); + } + + return result; + } + + private async Task>> SaveChangesToDatabaseAsync(Guid organizationId, Dictionary keys, Guid confirmingUserId) { var selectedOrganizationUsers = await _organizationUserRepository.GetManyAsync(keys.Keys); @@ -227,17 +244,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand .Select(d => d.Id.ToString()); } - private async Task HandleConfirmationSideEffectsAsync(Guid organizationId, OrganizationUser organizationUser, string defaultUserCollectionName) - { - // Create DefaultUserCollection type collection for the user if the OrganizationDataOwnership policy is enabled for the organization - var requiresDefaultCollection = await OrganizationRequiresDefaultCollectionAsync(organizationId, organizationUser.UserId.Value, defaultUserCollectionName); - if (requiresDefaultCollection) - { - await CreateDefaultCollectionAsync(organizationId, organizationUser.Id, defaultUserCollectionName); - } - } - - private async Task OrganizationRequiresDefaultCollectionAsync(Guid organizationId, Guid userId, string defaultUserCollectionName) + private async Task OrganizationRequiresDefaultCollectionAsync(Guid organizationId, string defaultUserCollectionName) { if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)) { @@ -250,30 +257,29 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return false; } - var organizationDataOwnershipRequirement = await _policyRequirementQuery.GetAsync(userId); - return organizationDataOwnershipRequirement.RequiresDefaultCollection(organizationId); + var organizationPolicyRequirement = await _policyRequirementQuery.GetByOrganizationAsync(organizationId); + + // Check if the organization requires default collections + return organizationPolicyRequirement.RequiresDefaultCollection(organizationId); } - private async Task CreateDefaultCollectionAsync(Guid organizationId, Guid organizationUserId, string defaultCollectionName) + /// + /// Handles the side effects of confirming an organization user. + /// Creates a default collection for the user if the organization + /// has the OrganizationDataOwnership policy enabled. + /// + /// The organization ID. + /// The confirmed organization users. + /// The encrypted default user collection name. + private async Task HandleConfirmationSideEffectsAsync(Guid organizationId, IEnumerable confirmedOrganizationUsers, string defaultUserCollectionName) { - var collection = new Collection + var requiresDefaultCollections = await OrganizationRequiresDefaultCollectionAsync(organizationId, defaultUserCollectionName); + if (!requiresDefaultCollections) { - OrganizationId = organizationId, - Name = defaultCollectionName, - Type = CollectionType.DefaultUserCollection - }; + return; + } - var userAccess = new List - { - new CollectionAccessSelection - { - Id = organizationUserId, - ReadOnly = false, - HidePasswords = false, - Manage = true - } - }; - - await _collectionRepository.CreateAsync(collection, groups: null, users: userAccess); + var organizationUserIds = confirmedOrganizationUsers.Select(u => u.Id).ToList(); + await _collectionRepository.CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultUserCollectionName); } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IConfirmOrganizationUserCommand.cs index cf5999f892..aca4853b66 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IConfirmOrganizationUserCommand.cs @@ -29,7 +29,8 @@ public interface IConfirmOrganizationUserCommand /// The ID of the organization. /// A dictionary mapping organization user IDs to their encrypted organization keys. /// The ID of the user performing the confirmation. + /// Optional encrypted collection name for creating default collections. /// A list of tuples containing the organization user and an error message (if any). Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, - Guid confirmingUserId); + Guid confirmingUserId, string defaultUserCollectionName = null); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs index 5736078f22..226347fe29 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs @@ -15,4 +15,14 @@ public interface IPolicyRequirementQuery /// The user that you need to enforce the policy against. /// The IPolicyRequirement that corresponds to the policy you want to enforce. Task GetAsync(Guid userId) where T : IPolicyRequirement; + + /// + /// Get a policy requirement for a specific organization. + /// This returns the policy requirement that represents the policy state for the entire organization. + /// It will always return a value even if there are no policies that should be enforced. + /// This should be used for organization-level policy checks. + /// + /// The organization to check policies for. + /// The IPolicyRequirement that corresponds to the policy you want to enforce. + Task GetByOrganizationAsync(Guid organizationId) where T : IPolicyRequirement; } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs index de4796d4b5..ba4495224c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs @@ -1,5 +1,6 @@ #nullable enable +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; @@ -27,6 +28,27 @@ public class PolicyRequirementQuery( return requirement; } + public async Task GetByOrganizationAsync(Guid organizationId) where T : IPolicyRequirement + { + var factory = factories.OfType>().SingleOrDefault(); + if (factory is null) + { + throw new NotImplementedException("No Requirement Factory found for " + typeof(T)); + } + + var organizationPolicyDetails = await GetOrganizationPolicyDetails(organizationId, factory.PolicyType); + var filteredPolicies = organizationPolicyDetails + .Cast() + .Where(policyDetails => policyDetails.PolicyType == factory.PolicyType) + .Where(factory.Enforce) + .ToList(); + var requirement = factory.Create(filteredPolicies); + return requirement; + } + private Task> GetPolicyDetails(Guid userId) => policyRepository.GetPolicyDetailsByUserId(userId); + + private async Task> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType) + => await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType); } diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs index ca13585017..c60c2049a1 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs @@ -2,19 +2,34 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.Services; +using NSubstitute; using Xunit; namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; public class OrganizationUserControllerTests : IClassFixture, IAsyncLifetime { + private static readonly string _mockEncryptedString = + "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk="; + + public OrganizationUserControllerTests(ApiApplicationFactory apiFactory) { _factory = apiFactory; + _factory.SubstituteService(featureService => + { + featureService + .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) + .Returns(true); + }); _client = _factory.CreateClient(); _loginHelper = new LoginHelper(_factory, _client); } @@ -93,9 +108,113 @@ public class OrganizationUserControllerTests : IClassFixture new OrganizationUserBulkConfirmRequestModelEntry + { + Id = organizationUser.Id, + Key = string.Format(testKeyFormat, index) + }), + DefaultUserCollectionName = _mockEncryptedString + }; + + var bulkConfirmResponse = await _client.PostAsJsonAsync($"organizations/{_organization.Id}/users/confirm", bulkConfirmModel); + + Assert.Equal(HttpStatusCode.OK, bulkConfirmResponse.StatusCode); + + await VerifyMultipleUsersConfirmedAsync(acceptedUsers.Select((organizationUser, index) => + (organizationUser, string.Format(testKeyFormat, index))).ToList()); + await VerifyMultipleUsersHaveDefaultCollectionsAsync(acceptedUsers); + } + public Task DisposeAsync() { _client.Dispose(); return Task.CompletedTask; } + + private async Task> CreateAcceptedUsersAsync(IEnumerable emails) + { + var acceptedUsers = new List(); + + foreach (var email in emails) + { + await _factory.LoginWithNewAccount(email); + + var acceptedOrgUser = await OrganizationTestHelpers.CreateUserAsync(_factory, _organization.Id, email, + OrganizationUserType.User, userStatusType: OrganizationUserStatusType.Accepted); + + acceptedUsers.Add(acceptedOrgUser); + } + + return acceptedUsers; + } + + private async Task VerifyDefaultCollectionCreatedAsync(OrganizationUser orgUser) + { + var collectionRepository = _factory.GetService(); + var collections = await collectionRepository.GetManyByUserIdAsync(orgUser.UserId!.Value); + Assert.Single(collections); + Assert.Equal(_mockEncryptedString, collections.First().Name); + } + + private async Task VerifyUserConfirmedAsync(OrganizationUser orgUser, string expectedKey) + { + await VerifyMultipleUsersConfirmedAsync(new List<(OrganizationUser orgUser, string key)> { (orgUser, expectedKey) }); + } + + private async Task VerifyMultipleUsersConfirmedAsync(List<(OrganizationUser orgUser, string key)> acceptedOrganizationUsers) + { + var orgUserRepository = _factory.GetService(); + for (int i = 0; i < acceptedOrganizationUsers.Count; i++) + { + var confirmedUser = await orgUserRepository.GetByIdAsync(acceptedOrganizationUsers[i].orgUser.Id); + Assert.Equal(OrganizationUserStatusType.Confirmed, confirmedUser.Status); + Assert.Equal(acceptedOrganizationUsers[i].key, confirmedUser.Key); + } + } + + private async Task VerifyMultipleUsersHaveDefaultCollectionsAsync(List acceptedOrganizationUsers) + { + var collectionRepository = _factory.GetService(); + foreach (var acceptedOrganizationUser in acceptedOrganizationUsers) + { + var collections = await collectionRepository.GetManyByUserIdAsync(acceptedOrganizationUser.UserId!.Value); + Assert.Single(collections); + Assert.Equal(_mockEncryptedString, collections.First().Name); + } + } } diff --git a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs index ae4e27267d..fb5c9bbc56 100644 --- a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs +++ b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs @@ -1,6 +1,7 @@ using System.Diagnostics; using Bit.Api.IntegrationTest.Factories; using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; @@ -148,4 +149,23 @@ public static class OrganizationTestHelpers await groupRepository.CreateAsync(group, new List()); return group; } + + /// + /// Enables the Organization Data Ownership policy for the specified organization. + /// + public static async Task EnableOrganizationDataOwnershipPolicyAsync( + WebApplicationFactoryBase factory, + Guid organizationId) where T : class + { + var policyRepository = factory.GetService(); + + var policy = new Policy + { + OrganizationId = organizationId, + Type = PolicyType.OrganizationDataOwnership, + Enabled = true + }; + + await policyRepository.CreateAsync(policy); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs index a6709cd10b..b0815d9f35 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs @@ -473,7 +473,7 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); sutProvider.GetDependency() - .GetAsync(user.Id) + .GetByOrganizationAsync(organization.Id) .Returns(new OrganizationDataOwnershipPolicyRequirement( OrganizationDataOwnershipState.Enabled, [organization.Id])); @@ -482,15 +482,10 @@ public class ConfirmOrganizationUserCommandTests await sutProvider.GetDependency() .Received(1) - .CreateAsync( - Arg.Is(c => c.Name == collectionName && - c.OrganizationId == organization.Id && - c.Type == CollectionType.DefaultUserCollection), - Arg.Is>(groups => groups == null), - Arg.Is>(u => - u.Count() == 1 && - u.First().Id == orgUser.Id && - u.First().Manage == true)); + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Contains(orgUser.Id)), + collectionName); } [Theory, BitAutoData] @@ -510,7 +505,7 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); sutProvider.GetDependency() - .GetAsync(user.Id) + .GetByOrganizationAsync(org.Id) .Returns(new OrganizationDataOwnershipPolicyRequirement( OrganizationDataOwnershipState.Enabled, [org.Id])); @@ -538,7 +533,7 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); sutProvider.GetDependency() - .GetAsync(user.Id) + .GetByOrganizationAsync(org.Id) .Returns(new OrganizationDataOwnershipPolicyRequirement( OrganizationDataOwnershipState.Enabled, [Guid.NewGuid()])); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs index 56b6740678..da8f7319d5 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs @@ -79,4 +79,73 @@ public class PolicyRequirementQueryTests Assert.Empty(requirement.Policies); } + + [Theory, BitAutoData] + public async Task GetByOrganizationAsync_IgnoresOtherPolicyTypes(Guid organizationId) + { + var policyRepository = Substitute.For(); + var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = Guid.NewGuid() }; + var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = Guid.NewGuid() }; + // Force the repository to return both policies even though that is not the expected result + policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg) + .Returns([thisPolicy, otherPolicy]); + + var factory = new TestPolicyRequirementFactory(_ => true); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirement = await sut.GetByOrganizationAsync(organizationId); + + await policyRepository.Received(1).GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg); + + Assert.Contains(thisPolicy, requirement.Policies.Cast()); + Assert.DoesNotContain(otherPolicy, requirement.Policies.Cast()); + } + + [Theory, BitAutoData] + public async Task GetByOrganizationAsync_CallsEnforceCallback(Guid organizationId) + { + var policyRepository = Substitute.For(); + var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = Guid.NewGuid() }; + var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = Guid.NewGuid() }; + policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg).Returns([thisPolicy, otherPolicy]); + + var callback = Substitute.For>(); + callback(Arg.Any()).Returns(x => x.Arg() == thisPolicy); + + var factory = new TestPolicyRequirementFactory(callback); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirement = await sut.GetByOrganizationAsync(organizationId); + + Assert.Contains(thisPolicy, requirement.Policies.Cast()); + Assert.DoesNotContain(otherPolicy, requirement.Policies.Cast()); + callback.Received()(Arg.Is(p => p == thisPolicy)); + callback.Received()(Arg.Is(p => p == otherPolicy)); + } + + [Theory, BitAutoData] + public async Task GetByOrganizationAsync_ThrowsIfNoFactoryRegistered(Guid organizationId) + { + var policyRepository = Substitute.For(); + var sut = new PolicyRequirementQuery(policyRepository, []); + + var exception = await Assert.ThrowsAsync(() + => sut.GetByOrganizationAsync(organizationId)); + + Assert.Contains("No Requirement Factory found", exception.Message); + } + + [Theory, BitAutoData] + public async Task GetByOrganizationAsync_HandlesNoPolicies(Guid organizationId) + { + var policyRepository = Substitute.For(); + policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg).Returns([]); + + var factory = new TestPolicyRequirementFactory(x => x.IsProvider); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirement = await sut.GetByOrganizationAsync(organizationId); + + Assert.Empty(requirement.Policies); + } }