diff --git a/src/Core/Auth/Repositories/IAuthRequestRepository.cs b/src/Core/Auth/Repositories/IAuthRequestRepository.cs index b414b2206b..6662dd15fc 100644 --- a/src/Core/Auth/Repositories/IAuthRequestRepository.cs +++ b/src/Core/Auth/Repositories/IAuthRequestRepository.cs @@ -9,4 +9,5 @@ public interface IAuthRequestRepository : IRepository Task> GetManyByUserIdAsync(Guid userId); Task> GetManyPendingByOrganizationIdAsync(Guid organizationId); Task> GetManyAdminApprovalRequestsByManyIdsAsync(Guid organizationId, IEnumerable ids); + Task UpdateManyAsync(IEnumerable authRequests); } diff --git a/src/Infrastructure.Dapper/Auth/Repositories/AuthRequestRepository.cs b/src/Infrastructure.Dapper/Auth/Repositories/AuthRequestRepository.cs index 67e636b4dd..df68c06d05 100644 --- a/src/Infrastructure.Dapper/Auth/Repositories/AuthRequestRepository.cs +++ b/src/Infrastructure.Dapper/Auth/Repositories/AuthRequestRepository.cs @@ -1,4 +1,5 @@ using System.Data; +using System.Text.Json; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Models.Data; using Bit.Core.Repositories; @@ -74,4 +75,20 @@ public class AuthRequestRepository : Repository, IAuthRequest return results.ToList(); } } + + public async Task UpdateManyAsync(IEnumerable authRequests) + { + if (!authRequests.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[dbo].[AuthRequest_UpdateMany]", + new { jsonData = JsonSerializer.Serialize(authRequests) }, + commandType: CommandType.StoredProcedure); + } + } } diff --git a/src/Infrastructure.EntityFramework/Auth/Repositories/AuthRequestRepository.cs b/src/Infrastructure.EntityFramework/Auth/Repositories/AuthRequestRepository.cs index af3ae195dc..11e5b3f65c 100644 --- a/src/Infrastructure.EntityFramework/Auth/Repositories/AuthRequestRepository.cs +++ b/src/Infrastructure.EntityFramework/Auth/Repositories/AuthRequestRepository.cs @@ -69,4 +69,29 @@ public class AuthRequestRepository : Repository authRequests) + { + if (!authRequests.Any()) + { + return; + } + + var entities = new List(); + foreach (var authRequest in authRequests) + { + if (!authRequest.Id.Equals(default)) + { + var entity = Mapper.Map(authRequest); + entities.Add(entity); + } + } + + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + dbContext.UpdateRange(entities); + await dbContext.SaveChangesAsync(); + } + } } diff --git a/src/Sql/Auth/dbo/Stored Procedures/AuthRequest_UpdateMany.sql b/src/Sql/Auth/dbo/Stored Procedures/AuthRequest_UpdateMany.sql new file mode 100644 index 0000000000..227abbb3e1 --- /dev/null +++ b/src/Sql/Auth/dbo/Stored Procedures/AuthRequest_UpdateMany.sql @@ -0,0 +1,45 @@ +CREATE PROCEDURE AuthRequest_UpdateMany + @jsonData NVARCHAR(MAX) +AS +BEGIN + UPDATE AR + SET + [Id] = ARI.[Id], + [UserId] = ARI.[UserId], + [Type] = ARI.[Type], + [RequestDeviceIdentifier] = ARI.[RequestDeviceIdentifier], + [RequestDeviceType] = ARI.[RequestDeviceType], + [RequestIpAddress] = ARI.[RequestIpAddress], + [ResponseDeviceId] = ARI.[ResponseDeviceId], + [AccessCode] = ARI.[AccessCode], + [PublicKey] = ARI.[PublicKey], + [Key] = ARI.[Key], + [MasterPasswordHash] = ARI.[MasterPasswordHash], + [Approved] = ARI.[Approved], + [CreationDate] = ARI.[CreationDate], + [ResponseDate] = ARI.[ResponseDate], + [AuthenticationDate] = ARI.[AuthenticationDate], + [OrganizationId] = ARI.[OrganizationId] + FROM + [dbo].[AuthRequest] AR + INNER JOIN + OPENJSON(@jsonData) + WITH ( + Id UNIQUEIDENTIFIER '$.Id', + UserId UNIQUEIDENTIFIER '$.UserId', + Type SMALLINT '$.Type', + RequestDeviceIdentifier NVARCHAR(50) '$.RequestDeviceIdentifier', + RequestDeviceType SMALLINT '$.RequestDeviceType', + RequestIpAddress VARCHAR(50) '$.RequestIpAddress', + ResponseDeviceId UNIQUEIDENTIFIER '$.ResponseDeviceId', + AccessCode VARCHAR(25) '$.AccessCode', + PublicKey VARCHAR(MAX) '$.PublicKey', + [Key] VARCHAR(MAX) '$.Key', + MasterPasswordHash VARCHAR(MAX) '$.MasterPasswordHash', + Approved BIT '$.Approved', + CreationDate DATETIME2 '$.CreationDate', + ResponseDate DATETIME2 '$.ResponseDate', + AuthenticationDate DATETIME2 '$.AuthenticationDate', + OrganizationId UNIQUEIDENTIFIER '$.OrganizationId' + ) ARI ON AR.Id = ARI.Id; +END diff --git a/test/Infrastructure.IntegrationTest/Auth/Repositories/AuthRequestRepositoryTests.cs b/test/Infrastructure.IntegrationTest/Auth/Repositories/AuthRequestRepositoryTests.cs index 6c5bf135e1..835d1da74c 100644 --- a/test/Infrastructure.IntegrationTest/Auth/Repositories/AuthRequestRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/Auth/Repositories/AuthRequestRepositoryTests.cs @@ -72,6 +72,113 @@ public class AuthRequestRepositoryTests Assert.Equal(4, numberOfDeleted); } + [DatabaseTheory, DatabaseData] + public async Task UpdateManyAsync_Works( + IAuthRequestRepository authRequestRepository, + IUserRepository userRepository) + { + // Create two distinct real users for foreign key requirements + var user1 = await userRepository.CreateAsync(new User + { + Name = "First Test User", + Email = $"test+{Guid.NewGuid()}@email.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var user2 = await userRepository.CreateAsync(new User + { + Name = "Second Test User", + Email = $"test+{Guid.NewGuid()}@email.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var user3 = await userRepository.CreateAsync(new User + { + Name = "Third Test User", + Email = $"test+{Guid.NewGuid()}@email.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + // Create two different and still valid (not expired or responded to) auth requests + var authRequests = new List + { + await authRequestRepository.CreateAsync(CreateAuthRequest(user1.Id, AuthRequestType.AdminApproval, DateTime.UtcNow.AddMinutes(-5))), + await authRequestRepository.CreateAsync(CreateAuthRequest(user3.Id, AuthRequestType.AdminApproval, DateTime.UtcNow.AddMinutes(-7))), + await authRequestRepository.CreateAsync(CreateAuthRequest(user2.Id, AuthRequestType.AdminApproval, DateTime.UtcNow.AddMinutes(-10))), + // This last auth request is not created manually, and will be + // used to make sure entity framework's `UpdateRange` method + // doesn't create requests too. + CreateAuthRequest(user2.Id, AuthRequestType.AdminApproval, DateTime.UtcNow.AddMinutes(-11)) + }; + + // Update some properties on two auth request, but leave the other one + // alone to be a control value + var authRequestToBeUpdated1 = authRequests[0]; + var authRequestToBeUpdated2 = authRequests[1]; + var authRequestNotToBeUpdated = authRequests[2]; + authRequests[0].Approved = true; + authRequests[0].ResponseDate = DateTime.UtcNow.AddMinutes(-1); + authRequests[0].Key = "UPDATED_KEY_1"; + authRequests[0].MasterPasswordHash = "UPDATED_MASTERPASSWORDHASH_1"; + + authRequests[1].Approved = false; + authRequests[1].ResponseDate = DateTime.UtcNow.AddMinutes(-2); + + // Run the method being tested + await authRequestRepository.UpdateManyAsync(authRequests); + + // Define what "Equality" really means in this context + // This includes stripping milliseconds off of dates, because we can't + // reliably compare that deep + static DateTime? TrimMilliseconds(DateTime? dt) + { + if (!dt.HasValue) + { + return null; + } + return new DateTime(dt.Value.Year, dt.Value.Month, dt.Value.Day, dt.Value.Hour, dt.Value.Minute, dt.Value.Second, 0, dt.Value.Kind); + } + + bool AuthRequestEquals(AuthRequest x, AuthRequest y) + { + return + x.Id == y.Id && + x.UserId == y.UserId && + x.Type == y.Type && + x.RequestDeviceIdentifier == y.RequestDeviceIdentifier && + x.RequestDeviceType == y.RequestDeviceType && + x.RequestIpAddress == y.RequestIpAddress && + x.ResponseDeviceId == y.ResponseDeviceId && + x.AccessCode == y.AccessCode && + x.PublicKey == y.PublicKey && + x.Key == y.Key && + x.MasterPasswordHash == y.MasterPasswordHash && + x.Approved == y.Approved && + TrimMilliseconds(x.CreationDate) == TrimMilliseconds(y.CreationDate) && + TrimMilliseconds(x.ResponseDate) == TrimMilliseconds(y.ResponseDate) && + TrimMilliseconds(x.AuthenticationDate) == TrimMilliseconds(y.AuthenticationDate) && + x.OrganizationId == y.OrganizationId; + } + + // Assert that the unchanged auth request is still unchanged + var skippedAuthRequest = await authRequestRepository.GetByIdAsync(authRequestNotToBeUpdated.Id); + Assert.True(AuthRequestEquals(skippedAuthRequest, authRequestNotToBeUpdated)); + + // Assert that the values updated on the changed auth requests were updated, and no others + var updatedAuthRequest1 = await authRequestRepository.GetByIdAsync(authRequestToBeUpdated1.Id); + Assert.True(AuthRequestEquals(authRequestToBeUpdated1, updatedAuthRequest1)); + var updatedAuthRequest2 = await authRequestRepository.GetByIdAsync(authRequestToBeUpdated2.Id); + Assert.True(AuthRequestEquals(authRequestToBeUpdated2, updatedAuthRequest2)); + + // Assert that the auth request we never created is not created by + // the update method. + var uncreatedAuthRequest = await authRequestRepository.GetByIdAsync(authRequests[3].Id); + Assert.Null(uncreatedAuthRequest); + } + private static AuthRequest CreateAuthRequest(Guid userId, AuthRequestType authRequestType, DateTime creationDate, bool? approved = null, DateTime? responseDate = null) { return new AuthRequest diff --git a/util/Migrator/DbScripts/2024-05-05_00_UpdateManyAuthRequests.sql b/util/Migrator/DbScripts/2024-05-05_00_UpdateManyAuthRequests.sql new file mode 100644 index 0000000000..227abbb3e1 --- /dev/null +++ b/util/Migrator/DbScripts/2024-05-05_00_UpdateManyAuthRequests.sql @@ -0,0 +1,45 @@ +CREATE PROCEDURE AuthRequest_UpdateMany + @jsonData NVARCHAR(MAX) +AS +BEGIN + UPDATE AR + SET + [Id] = ARI.[Id], + [UserId] = ARI.[UserId], + [Type] = ARI.[Type], + [RequestDeviceIdentifier] = ARI.[RequestDeviceIdentifier], + [RequestDeviceType] = ARI.[RequestDeviceType], + [RequestIpAddress] = ARI.[RequestIpAddress], + [ResponseDeviceId] = ARI.[ResponseDeviceId], + [AccessCode] = ARI.[AccessCode], + [PublicKey] = ARI.[PublicKey], + [Key] = ARI.[Key], + [MasterPasswordHash] = ARI.[MasterPasswordHash], + [Approved] = ARI.[Approved], + [CreationDate] = ARI.[CreationDate], + [ResponseDate] = ARI.[ResponseDate], + [AuthenticationDate] = ARI.[AuthenticationDate], + [OrganizationId] = ARI.[OrganizationId] + FROM + [dbo].[AuthRequest] AR + INNER JOIN + OPENJSON(@jsonData) + WITH ( + Id UNIQUEIDENTIFIER '$.Id', + UserId UNIQUEIDENTIFIER '$.UserId', + Type SMALLINT '$.Type', + RequestDeviceIdentifier NVARCHAR(50) '$.RequestDeviceIdentifier', + RequestDeviceType SMALLINT '$.RequestDeviceType', + RequestIpAddress VARCHAR(50) '$.RequestIpAddress', + ResponseDeviceId UNIQUEIDENTIFIER '$.ResponseDeviceId', + AccessCode VARCHAR(25) '$.AccessCode', + PublicKey VARCHAR(MAX) '$.PublicKey', + [Key] VARCHAR(MAX) '$.Key', + MasterPasswordHash VARCHAR(MAX) '$.MasterPasswordHash', + Approved BIT '$.Approved', + CreationDate DATETIME2 '$.CreationDate', + ResponseDate DATETIME2 '$.ResponseDate', + AuthenticationDate DATETIME2 '$.AuthenticationDate', + OrganizationId UNIQUEIDENTIFIER '$.OrganizationId' + ) ARI ON AR.Id = ARI.Id; +END