1
0
mirror of https://github.com/bitwarden/server synced 2026-02-27 01:43:46 +00:00

Merge branch 'main' into auth/pm-30810/http-redirect-cloud

This commit is contained in:
Patrick-Pimentel-Bitwarden
2026-02-10 14:23:16 -05:00
committed by GitHub
459 changed files with 51296 additions and 4110 deletions

View File

@@ -3,6 +3,7 @@ using AutoFixture;
using AutoFixture.Xunit2;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
namespace Bit.Core.Test.AdminConsole.AutoFixture;
@@ -10,19 +11,30 @@ internal class PolicyCustomization : ICustomization
{
public PolicyType Type { get; set; }
public bool Enabled { get; set; }
public string? Data { get; set; }
public PolicyCustomization(PolicyType type, bool enabled)
public PolicyCustomization(PolicyType type, bool enabled, string? data)
{
Type = type;
Enabled = enabled;
Data = data;
}
public void Customize(IFixture fixture)
{
var orgId = Guid.NewGuid();
fixture.Customize<Policy>(composer => composer
.With(o => o.OrganizationId, Guid.NewGuid())
.With(o => o.OrganizationId, orgId)
.With(o => o.Type, Type)
.With(o => o.Enabled, Enabled));
.With(o => o.Enabled, Enabled)
.With(o => o.Data, Data));
fixture.Customize<PolicyStatus>(composer => composer
.With(o => o.OrganizationId, orgId)
.With(o => o.Type, Type)
.With(o => o.Enabled, Enabled)
.With(o => o.Data, Data));
}
}
@@ -30,15 +42,17 @@ public class PolicyAttribute : CustomizeAttribute
{
private readonly PolicyType _type;
private readonly bool _enabled;
private readonly string? _data;
public PolicyAttribute(PolicyType type, bool enabled = true)
public PolicyAttribute(PolicyType type, bool enabled = true, string? data = null)
{
_type = type;
_enabled = enabled;
_data = data;
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
return new PolicyCustomization(_type, _enabled);
return new PolicyCustomization(_type, _enabled, _data);
}
}

View File

@@ -1,14 +1,16 @@
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Test.AdminConsole.AutoFixture;
using Bit.Core.Test.AutoFixture.OrganizationUserFixtures;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -29,11 +31,12 @@ public class AdminRecoverAccountCommandTests
Organization organization,
OrganizationUser organizationUser,
User user,
[Policy(PolicyType.ResetPassword, true)] PolicyStatus policy,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidPolicy(sutProvider, organization, policy);
SetupValidOrganizationUser(organizationUser, organization.Id);
SetupValidUser(sutProvider, user, organizationUser);
SetupSuccessfulPasswordUpdate(sutProvider, user, newMasterPassword);
@@ -87,25 +90,18 @@ public class AdminRecoverAccountCommandTests
Assert.Equal("Organization does not allow password reset.", exception.Message);
}
public static IEnumerable<object[]> InvalidPolicies => new object[][]
{
[new Policy { Type = PolicyType.ResetPassword, Enabled = false }], [null]
};
[Theory]
[BitMemberAutoData(nameof(InvalidPolicies))]
[BitAutoData]
public async Task RecoverAccountAsync_InvalidPolicy_ThrowsBadRequest(
Policy resetPasswordPolicy,
string newMasterPassword,
string key,
Organization organization,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword)
.Returns(resetPasswordPolicy);
SetupValidPolicy(sutProvider, organization, policy);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
@@ -171,11 +167,12 @@ public class AdminRecoverAccountCommandTests
Organization organization,
string newMasterPassword,
string key,
[Policy(PolicyType.ResetPassword, true)] PolicyStatus policy,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidPolicy(sutProvider, organization, policy);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
@@ -190,11 +187,12 @@ public class AdminRecoverAccountCommandTests
string key,
Organization organization,
OrganizationUser organizationUser,
[Policy(PolicyType.ResetPassword, true)] PolicyStatus policy,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidPolicy(sutProvider, organization, policy);
SetupValidOrganizationUser(organizationUser, organization.Id);
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(organizationUser.UserId!.Value)
@@ -213,11 +211,12 @@ public class AdminRecoverAccountCommandTests
Organization organization,
OrganizationUser organizationUser,
User user,
[Policy(PolicyType.ResetPassword, true)] PolicyStatus policy,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidPolicy(sutProvider, organization, policy);
SetupValidOrganizationUser(organizationUser, organization.Id);
user.UsesKeyConnector = true;
sutProvider.GetDependency<IUserService>()
@@ -238,11 +237,10 @@ public class AdminRecoverAccountCommandTests
.Returns(organization);
}
private static void SetupValidPolicy(SutProvider<AdminRecoverAccountCommand> sutProvider, Organization organization)
private static void SetupValidPolicy(SutProvider<AdminRecoverAccountCommand> sutProvider, Organization organization, PolicyStatus policy)
{
var policy = new Policy { Type = PolicyType.ResetPassword, Enabled = true };
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.ResetPassword)
.Returns(policy);
}

View File

@@ -282,6 +282,7 @@ public class VerifyOrganizationDomainCommandTests
await sutProvider.GetDependency<IMailService>().Received().SendClaimedDomainUserEmailAsync(
Arg.Is<ClaimedUserDomainClaimedEmails>(x =>
x.EmailList.Count(e => e.EndsWith(domain.DomainName)) == mockedUsers.Count &&
x.Organization.Id == organization.Id));
x.Organization.Id == organization.Id &&
x.DomainName == domain.DomainName));
}
}

View File

@@ -7,7 +7,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimed
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
@@ -120,7 +119,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
[Organization(useAutomaticUserConfirmation: true, planType: PlanType.EnterpriseAnnually)] Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
User user,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy)
[Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy)
{
// Arrange
organizationUser.UserId = user.Id;
@@ -137,8 +136,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
Key = "test-key"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(autoConfirmPolicy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
@@ -280,7 +279,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
[Organization(useAutomaticUserConfirmation: true)] Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
Guid userId,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy)
[Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy)
{
// Arrange
organizationUser.UserId = userId;
@@ -303,8 +302,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
PolicyType = PolicyType.TwoFactorAuthentication
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(autoConfirmPolicy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
@@ -334,7 +333,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
[Organization(useAutomaticUserConfirmation: true)] Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
User user,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy)
[Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy)
{
// Arrange
organizationUser.UserId = user.Id;
@@ -351,8 +350,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
Key = "test-key"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(autoConfirmPolicy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
@@ -389,7 +388,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
[Organization(useAutomaticUserConfirmation: true)] Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
User user,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy)
[Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy)
{
// Arrange
organizationUser.UserId = user.Id;
@@ -406,8 +405,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
Key = "test-key"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(autoConfirmPolicy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
@@ -448,7 +447,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
[Organization(useAutomaticUserConfirmation: true)] Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
User user,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy)
[Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy)
{
// Arrange
organizationUser.UserId = user.Id;
@@ -465,8 +464,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
Key = "test-key"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(autoConfirmPolicy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
@@ -501,7 +500,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
SutProvider<AutomaticallyConfirmOrganizationUsersValidator> sutProvider,
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
Guid userId)
Guid userId,
[Policy(PolicyType.AutomaticUserConfirmation, false)] PolicyStatus policy)
{
// Arrange
organizationUser.UserId = userId;
@@ -518,9 +518,9 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
Key = "test-key"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns((Policy)null);
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(policy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
.TwoFactorIsEnabledAsync(Arg.Any<IEnumerable<Guid>>())
@@ -545,7 +545,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
[Organization(useAutomaticUserConfirmation: false)] Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
Guid userId,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy)
[Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy)
{
// Arrange
organizationUser.UserId = userId;
@@ -562,8 +562,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
Key = "test-key"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(autoConfirmPolicy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
@@ -589,7 +589,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
[Organization(useAutomaticUserConfirmation: true)] Organization organization,
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser,
User user,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy)
[Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy)
{
// Arrange
organizationUser.UserId = user.Id;
@@ -606,8 +606,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests
Key = "test-key"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation)
.Returns(autoConfirmPolicy);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()

View File

@@ -10,7 +10,6 @@ using Bit.Core.AdminConsole.Utilities.v2;
using Bit.Core.AdminConsole.Utilities.v2.Validation;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
@@ -204,14 +203,10 @@ public class AutomaticallyConfirmUsersCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateAsync(
Arg.Is<Collection>(c =>
c.OrganizationId == organization.Id &&
c.Name == defaultCollectionName &&
c.Type == CollectionType.DefaultUserCollection),
Arg.Is<IEnumerable<CollectionAccessSelection>>(groups => groups == null),
Arg.Is<IEnumerable<CollectionAccessSelection>>(access =>
access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null));
.CreateDefaultCollectionsAsync(
organization.Id,
Arg.Is<IEnumerable<Guid>>(ids => ids.Single() == organizationUser.Id),
defaultCollectionName);
}
[Theory]
@@ -253,9 +248,7 @@ public class AutomaticallyConfirmUsersCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateAsync(Arg.Any<Collection>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory]
@@ -291,9 +284,7 @@ public class AutomaticallyConfirmUsersCommandTests
var collectionException = new Exception("Collection creation failed");
sutProvider.GetDependency<ICollectionRepository>()
.CreateAsync(Arg.Any<Collection>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>())
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>())
.ThrowsAsync(collectionException);
// Act

View File

@@ -13,7 +13,6 @@ using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
@@ -493,15 +492,10 @@ public class ConfirmOrganizationUserCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateAsync(
Arg.Is<Collection>(c =>
c.Name == collectionName &&
c.OrganizationId == organization.Id &&
c.Type == CollectionType.DefaultUserCollection),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Is<IEnumerable<CollectionAccessSelection>>(cu =>
cu.Single().Id == orgUser.Id &&
cu.Single().Manage));
.CreateDefaultCollectionsAsync(
organization.Id,
Arg.Is<IEnumerable<Guid>>(ids => ids.Single() == orgUser.Id),
collectionName);
}
[Theory, BitAutoData]
@@ -522,7 +516,7 @@ public class ConfirmOrganizationUserCommandTests
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -539,24 +533,15 @@ public class ConfirmOrganizationUserCommandTests
sutProvider.GetDependency<IOrganizationUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser });
sutProvider.GetDependency<IUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { user });
var policyDetails = new PolicyDetails
{
OrganizationId = org.Id,
OrganizationUserId = orgUser.Id,
IsProvider = false,
OrganizationUserStatus = orgUser.Status,
OrganizationUserType = orgUser.Type,
PolicyType = PolicyType.OrganizationDataOwnership
};
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(orgUser.UserId!.Value)
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails]));
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, []));
await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]

View File

@@ -1,7 +1,9 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.Repositories;
@@ -9,6 +11,7 @@ using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Test.AdminConsole.AutoFixture;
using Bit.Core.Test.AutoFixture.OrganizationFixtures;
using Bit.Core.Tokens;
using Bit.Test.Common.AutoFixture;
@@ -31,6 +34,7 @@ public class SendOrganizationInvitesCommandTests
Organization organization,
SsoConfig ssoConfig,
OrganizationUser invite,
[Policy(PolicyType.RequireSso, false)] PolicyStatus policy,
SutProvider<SendOrganizationInvitesCommand> sutProvider)
{
// Setup FakeDataProtectorTokenFactory for creating new tokens - this must come first in order to avoid resetting mocks
@@ -45,7 +49,9 @@ public class SendOrganizationInvitesCommandTests
sutProvider.GetDependency<ISsoConfigRepository>().GetByOrganizationIdAsync(organization.Id).Returns(ssoConfig);
// Return null policy to mimic new org that's never turned on the require sso policy
sutProvider.GetDependency<IPolicyRepository>().GetManyByOrganizationIdAsync(organization.Id).ReturnsNull();
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(organization.Id, PolicyType.RequireSso)
.Returns(policy);
// Mock tokenable factory to return a token that expires in 5 days
sutProvider.GetDependency<IOrgUserInviteTokenableFactory>()

View File

@@ -37,7 +37,7 @@ public class RestoreOrganizationUserCommandTests
Sponsored = 0,
Users = 1
});
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
@@ -81,7 +81,7 @@ public class RestoreOrganizationUserCommandTests
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("you cannot restore yourself", exception.Message.ToLowerInvariant());
@@ -107,7 +107,7 @@ public class RestoreOrganizationUserCommandTests
RestoreUser_Setup(organization, restoringUser, organizationUser, sutProvider);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, restoringUser.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, restoringUser.Id, null));
Assert.Contains("only owners can restore other owners", exception.Message.ToLowerInvariant());
@@ -133,7 +133,7 @@ public class RestoreOrganizationUserCommandTests
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("already active", exception.Message.ToLowerInvariant());
@@ -172,7 +172,7 @@ public class RestoreOrganizationUserCommandTests
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(organizationUser.UserId.Value).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("test@bitwarden.com belongs to an organization that doesn't allow them to join multiple organizations", exception.Message.ToLowerInvariant());
@@ -216,7 +216,7 @@ public class RestoreOrganizationUserCommandTests
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(organizationUser.UserId.Value).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("test@bitwarden.com is not compliant with the two-step login policy", exception.Message.ToLowerInvariant());
@@ -272,7 +272,7 @@ public class RestoreOrganizationUserCommandTests
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(organizationUser.UserId.Value).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("test@bitwarden.com is not compliant with the two-step login policy", exception.Message.ToLowerInvariant());
@@ -309,7 +309,7 @@ public class RestoreOrganizationUserCommandTests
Sponsored = 0,
Users = 1
});
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
@@ -349,7 +349,7 @@ public class RestoreOrganizationUserCommandTests
}
]));
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
@@ -395,7 +395,7 @@ public class RestoreOrganizationUserCommandTests
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(organizationUser.UserId.Value).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("test@bitwarden.com is not compliant with the single organization policy", exception.Message.ToLowerInvariant());
@@ -447,7 +447,7 @@ public class RestoreOrganizationUserCommandTests
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(organizationUser.UserId.Value).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("test@bitwarden.com is not compliant with the single organization and two-step login policy", exception.Message.ToLowerInvariant());
@@ -509,7 +509,7 @@ public class RestoreOrganizationUserCommandTests
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(organizationUser.UserId.Value).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Contains("test@bitwarden.com is not compliant with the single organization and two-step login policy", exception.Message.ToLowerInvariant());
@@ -548,7 +548,7 @@ public class RestoreOrganizationUserCommandTests
.TwoFactorIsEnabledAsync(Arg.Is<IEnumerable<Guid>>(i => i.Contains(organizationUser.UserId.Value)))
.Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) });
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
@@ -599,7 +599,7 @@ public class RestoreOrganizationUserCommandTests
.Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) });
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id));
() => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null));
Assert.Equal("User is an owner/admin of another free organization. Please have them upgrade to a paid plan to restore their account.", exception.Message);
}
@@ -651,7 +651,7 @@ public class RestoreOrganizationUserCommandTests
.TwoFactorIsEnabledAsync(Arg.Is<IEnumerable<Guid>>(i => i.Contains(organizationUser.UserId.Value)))
.Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) });
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null);
await organizationUserRepository
.Received(1)
@@ -707,7 +707,7 @@ public class RestoreOrganizationUserCommandTests
.TwoFactorIsEnabledAsync(Arg.Is<IEnumerable<Guid>>(i => i.Contains(organizationUser.UserId.Value)))
.Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) });
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id);
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null);
await organizationUserRepository
.Received(1)
@@ -715,6 +715,39 @@ public class RestoreOrganizationUserCommandTests
Arg.Is<OrganizationUserStatusType>(x => x != OrganizationUserStatusType.Revoked));
}
[Theory, BitAutoData]
public async Task RestoreUser_InvitedUserInFreeOrganization_Success(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
organization.PlanType = PlanType.Free;
organizationUser.UserId = null;
organizationUser.Key = null;
organizationUser.Status = OrganizationUserStatusType.Revoked;
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts
{
Sponsored = 0,
Users = 1
});
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, "");
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.RestoreAsync(organizationUser.Id, OrganizationUserStatusType.Invited);
await sutProvider.GetDependency<IEventService>()
.Received(1)
.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored);
await sutProvider.GetDependency<IPushNotificationService>()
.DidNotReceiveWithAnyArgs()
.PushSyncOrgKeysAsync(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task RestoreUsers_Success(Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
@@ -749,7 +782,7 @@ public class RestoreOrganizationUserCommandTests
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, new[] { orgUser1.Id, orgUser2.Id }, owner.Id, userService);
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, new[] { orgUser1.Id, orgUser2.Id }, owner.Id, userService, null);
// Assert
Assert.Equal(2, result.Count);
@@ -810,7 +843,7 @@ public class RestoreOrganizationUserCommandTests
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService);
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService, null);
// Assert
Assert.Equal(3, result.Count);
@@ -881,7 +914,7 @@ public class RestoreOrganizationUserCommandTests
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService);
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService, null);
// Assert
Assert.Equal(3, result.Count);
@@ -959,7 +992,7 @@ public class RestoreOrganizationUserCommandTests
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService);
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService, null);
// Assert
Assert.Equal(3, result.Count);
@@ -1023,7 +1056,7 @@ public class RestoreOrganizationUserCommandTests
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService);
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService, null);
// Assert
Assert.Single(result);
@@ -1074,7 +1107,7 @@ public class RestoreOrganizationUserCommandTests
.Returns([new OrganizationUserPolicyDetails { OrganizationId = organization.Id, PolicyType = PolicyType.TwoFactorAuthentication }]);
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService);
var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService, null);
Assert.Single(result);
Assert.Equal(string.Empty, result[0].Item2);
@@ -1105,5 +1138,408 @@ public class RestoreOrganizationUserCommandTests
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(organization.Id).Returns(requestingOrganizationUser != null && requestingOrganizationUser.Type is OrganizationUserType.Owner);
sutProvider.GetDependency<ICurrentContext>().ManageUsers(organization.Id).Returns(requestingOrganizationUser != null && (requestingOrganizationUser.Type is OrganizationUserType.Owner or OrganizationUserType.Admin));
// Setup default disabled OrganizationDataOwnershipPolicyRequirement for any user
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(Arg.Any<Guid>())
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, []));
}
private static void SetupOrganizationDataOwnershipPolicy(
SutProvider<RestoreOrganizationUserCommand> sutProvider,
Guid userId,
Guid organizationId,
OrganizationUserStatusType orgUserStatus,
bool policyEnabled)
{
var policyDetails = policyEnabled
? new List<PolicyDetails>
{
new()
{
OrganizationId = organizationId,
OrganizationUserId = Guid.NewGuid(),
OrganizationUserStatus = orgUserStatus,
PolicyType = PolicyType.OrganizationDataOwnership
}
}
: new List<PolicyDetails>();
var policyRequirement = new OrganizationDataOwnershipPolicyRequirement(
policyEnabled ? OrganizationDataOwnershipState.Enabled : OrganizationDataOwnershipState.Disabled,
policyDetails);
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(userId)
.Returns(policyRequirement);
}
#region Single User Restore - Default Collection Tests
[Theory, BitAutoData]
public async Task RestoreUser_WithDataOwnershipPolicyEnabled_AndConfirmedUser_CreatesDefaultCollection(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
string defaultCollectionName,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
organizationUser.Email = null; // This causes user to restore to Confirmed status
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
SetupOrganizationDataOwnershipPolicy(
sutProvider,
organizationUser.UserId!.Value,
organization.Id,
OrganizationUserStatusType.Revoked,
policyEnabled: true);
// Act
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName);
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateDefaultCollectionsAsync(
organization.Id,
Arg.Is<IEnumerable<Guid>>(ids => ids.Single() == organizationUser.Id),
defaultCollectionName);
}
[Theory, BitAutoData]
public async Task RestoreUser_WithDataOwnershipPolicyDisabled_DoesNotCreateDefaultCollection(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
string defaultCollectionName,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
organizationUser.Email = null; // This causes user to restore to Confirmed status
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
SetupOrganizationDataOwnershipPolicy(
sutProvider,
organizationUser.UserId!.Value,
organization.Id,
OrganizationUserStatusType.Revoked,
policyEnabled: false);
// Act
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName);
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task RestoreUser_WithNullDefaultCollectionName_DoesNotCreateDefaultCollection(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
organizationUser.Email = null; // This causes user to restore to Confirmed status
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
SetupOrganizationDataOwnershipPolicy(
sutProvider,
organizationUser.UserId!.Value,
organization.Id,
OrganizationUserStatusType.Revoked,
policyEnabled: true);
// Act
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null);
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory]
[BitAutoData("")]
[BitAutoData(" ")]
public async Task RestoreUser_WithEmptyOrWhitespaceDefaultCollectionName_DoesNotCreateDefaultCollection(
string defaultCollectionName,
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
organizationUser.Email = null; // This causes user to restore to Confirmed status
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
SetupOrganizationDataOwnershipPolicy(
sutProvider,
organizationUser.UserId!.Value,
organization.Id,
OrganizationUserStatusType.Revoked,
policyEnabled: true);
// Act
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName);
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task RestoreUser_UserRestoredToInvitedStatus_DoesNotCreateDefaultCollection(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
string defaultCollectionName,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
organization.PlanType = PlanType.EnterpriseAnnually; // Non-Free plan to avoid ownership check requiring UserId
organizationUser.Email = "test@example.com"; // Non-null email means user restores to Invited status
organizationUser.UserId = null; // User not linked to account yet
organizationUser.Key = null;
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
// Act
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName);
// Assert - User was restored to Invited status, so no collection should be created
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task RestoreUser_WithNoUserId_DoesNotCreateDefaultCollection(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser,
string defaultCollectionName,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
organization.PlanType = PlanType.EnterpriseAnnually; // Non-Free plan to avoid ownership check requiring UserId
organizationUser.UserId = null; // No linked user account
organizationUser.Email = "test@example.com";
organizationUser.Key = null;
RestoreUser_Setup(organization, owner, organizationUser, sutProvider);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
// Act
await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName);
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
#endregion
#region Bulk User Restore - Default Collection Tests
[Theory, BitAutoData]
public async Task RestoreUsers_Bulk_WithDataOwnershipPolicy_CreatesCollectionsForEligibleUsers(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2,
string defaultCollectionName,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
RestoreUser_Setup(organization, owner, orgUser1, sutProvider);
var organizationUserRepository = sutProvider.GetDependency<IOrganizationUserRepository>();
var userService = Substitute.For<IUserService>();
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
// orgUser1: Will restore to Confirmed (Email = null)
orgUser1.Email = null;
orgUser1.OrganizationId = organization.Id;
// orgUser2: Will restore to Invited (Email not null)
orgUser2.Email = "test@example.com";
orgUser2.UserId = null;
orgUser2.Key = null;
orgUser2.OrganizationId = organization.Id;
organizationUserRepository
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id)))
.Returns([orgUser1, orgUser2]);
// Setup bulk policy query - returns org user IDs with policy enabled
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id)
.Returns([orgUser1.Id]);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
.TwoFactorIsEnabledAsync(Arg.Any<IEnumerable<Guid>>())
.Returns(new List<(Guid userId, bool twoFactorIsEnabled)>
{
(orgUser1.UserId!.Value, true)
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(
organization.Id,
[orgUser1.Id, orgUser2.Id],
owner.Id,
userService,
defaultCollectionName);
// Assert - Only orgUser1 should have a collection created (Confirmed with policy enabled)
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateDefaultCollectionsAsync(
organization.Id,
Arg.Is<IEnumerable<Guid>>(ids => ids.Single() == orgUser1.Id),
defaultCollectionName);
}
[Theory, BitAutoData]
public async Task RestoreUsers_Bulk_WithMixedPolicyStates_OnlyCreatesForEnabledPolicy(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2,
string defaultCollectionName,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
RestoreUser_Setup(organization, owner, orgUser1, sutProvider);
var organizationUserRepository = sutProvider.GetDependency<IOrganizationUserRepository>();
var userService = Substitute.For<IUserService>();
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
// Both users will restore to Confirmed
orgUser1.Email = null;
orgUser1.OrganizationId = organization.Id;
orgUser2.Email = null;
orgUser2.OrganizationId = organization.Id;
organizationUserRepository
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id)))
.Returns([orgUser1, orgUser2]);
// Setup bulk policy query - only orgUser1 has policy enabled
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id)
.Returns([orgUser1.Id]);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
.TwoFactorIsEnabledAsync(Arg.Any<IEnumerable<Guid>>())
.Returns(new List<(Guid userId, bool twoFactorIsEnabled)>
{
(orgUser1.UserId!.Value, true),
(orgUser2.UserId!.Value, true)
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(
organization.Id,
[orgUser1.Id, orgUser2.Id],
owner.Id,
userService,
defaultCollectionName);
// Assert - Only orgUser1 should have a collection created (policy enabled)
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateDefaultCollectionsAsync(
organization.Id,
Arg.Is<IEnumerable<Guid>>(ids => ids.Single() == orgUser1.Id),
defaultCollectionName);
}
[Theory, BitAutoData]
public async Task RestoreUsers_Bulk_WithNullCollectionName_DoesNotCreateAnyCollections(
Organization organization,
[OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1,
[OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2,
SutProvider<RestoreOrganizationUserCommand> sutProvider)
{
// Arrange
RestoreUser_Setup(organization, owner, orgUser1, sutProvider);
var organizationUserRepository = sutProvider.GetDependency<IOrganizationUserRepository>();
var userService = Substitute.For<IUserService>();
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)
.Returns(true);
// Both users will restore to Confirmed
orgUser1.Email = null;
orgUser1.OrganizationId = organization.Id;
orgUser2.Email = null;
orgUser2.OrganizationId = organization.Id;
organizationUserRepository
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id)))
.Returns([orgUser1, orgUser2]);
// Setup bulk policy query - both users have policy enabled
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id)
.Returns([orgUser1.Id, orgUser2.Id]);
sutProvider.GetDependency<ITwoFactorIsEnabledQuery>()
.TwoFactorIsEnabledAsync(Arg.Any<IEnumerable<Guid>>())
.Returns(new List<(Guid userId, bool twoFactorIsEnabled)>
{
(orgUser1.UserId!.Value, true),
(orgUser2.UserId!.Value, true)
});
// Act
var result = await sutProvider.Sut.RestoreUsersAsync(
organization.Id,
[orgUser1.Id, orgUser2.Id],
owner.Id,
userService,
null); // Null collection name
// Assert - No collections should be created
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
#endregion
}

View File

@@ -38,7 +38,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -60,7 +60,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -86,7 +86,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.DidNotReceive()
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
Arg.Any<Guid>(),
Arg.Any<IEnumerable<Guid>>(),
Arg.Any<string>());
@@ -172,10 +172,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
// Assert
// Assert - Should call with all user IDs (repository does internal filtering)
await collectionRepository
.Received(1)
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3),
_defaultUserCollectionName);
@@ -210,7 +210,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
private static IPolicyRepository ArrangePolicyRepository(IEnumerable<OrganizationPolicyDetails> policyDetails)
@@ -251,7 +251,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
[Theory, BitAutoData]
@@ -273,7 +273,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
[Theory, BitAutoData]
@@ -299,7 +299,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
default,
default,
default);
@@ -336,10 +336,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Act
await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
// Assert
// Assert - Should call with all user IDs (repository does internal filtering)
await collectionRepository
.Received(1)
.UpsertDefaultCollectionsAsync(
.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3),
_defaultUserCollectionName);
@@ -367,6 +367,6 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
}

View File

@@ -19,12 +19,17 @@ public class PolicyDataValidatorTests
[Fact]
public void ValidateAndSerialize_ValidData_ReturnsSerializedJson()
{
var data = new Dictionary<string, object> { { "minLength", 12 } };
var data = new Dictionary<string, object>
{
{ "minLength", 12 },
{ "minComplexity", 4 }
};
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minLength\":12", result);
Assert.Contains("\"minComplexity\":4", result);
}
[Fact]
@@ -56,4 +61,122 @@ public class PolicyDataValidatorTests
Assert.IsType<OrganizationModelOwnershipPolicyModel>(result);
}
[Fact]
public void ValidateAndSerialize_ExcessiveMinLength_ThrowsBadRequestException()
{
var data = new Dictionary<string, object> { { "minLength", 129 } };
var exception = Assert.Throws<BadRequestException>(() =>
PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword));
Assert.Contains("Invalid data for MasterPassword policy", exception.Message);
}
[Fact]
public void ValidateAndSerialize_ExcessiveMinComplexity_ThrowsBadRequestException()
{
var data = new Dictionary<string, object> { { "minComplexity", 5 } };
var exception = Assert.Throws<BadRequestException>(() =>
PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword));
Assert.Contains("Invalid data for MasterPassword policy", exception.Message);
}
[Fact]
public void ValidateAndSerialize_MinLengthAtMinimum_Succeeds()
{
var data = new Dictionary<string, object> { { "minLength", 12 } };
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minLength\":12", result);
}
[Fact]
public void ValidateAndSerialize_MinLengthAtMaximum_Succeeds()
{
var data = new Dictionary<string, object> { { "minLength", 128 } };
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minLength\":128", result);
}
[Fact]
public void ValidateAndSerialize_MinLengthBelowMinimum_ThrowsBadRequestException()
{
var data = new Dictionary<string, object> { { "minLength", 11 } };
var exception = Assert.Throws<BadRequestException>(() =>
PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword));
Assert.Contains("Invalid data for MasterPassword policy", exception.Message);
}
[Fact]
public void ValidateAndSerialize_MinComplexityAtMinimum_Succeeds()
{
var data = new Dictionary<string, object> { { "minComplexity", 0 } };
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minComplexity\":0", result);
}
[Fact]
public void ValidateAndSerialize_MinComplexityAtMaximum_Succeeds()
{
var data = new Dictionary<string, object> { { "minComplexity", 4 } };
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minComplexity\":4", result);
}
[Fact]
public void ValidateAndSerialize_MinComplexityBelowMinimum_ThrowsBadRequestException()
{
var data = new Dictionary<string, object> { { "minComplexity", -1 } };
var exception = Assert.Throws<BadRequestException>(() =>
PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword));
Assert.Contains("Invalid data for MasterPassword policy", exception.Message);
}
[Fact]
public void ValidateAndSerialize_NullMinLength_Succeeds()
{
var data = new Dictionary<string, object>
{
{ "minComplexity", 2 }
// minLength is omitted, should be null
};
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minComplexity\":2", result);
}
[Fact]
public void ValidateAndSerialize_MultipleInvalidFields_ThrowsBadRequestException()
{
var data = new Dictionary<string, object>
{
{ "minLength", 200 },
{ "minComplexity", 10 }
};
var exception = Assert.Throws<BadRequestException>(() =>
PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword));
Assert.Contains("Invalid data for MasterPassword policy", exception.Message);
}
}

View File

@@ -29,7 +29,9 @@ internal class RegisterFinishRequestModelCustomization : ICustomization
.With(o => o.OrgInviteToken, OrgInviteToken)
.With(o => o.OrgSponsoredFreeFamilyPlanToken, OrgSponsoredFreeFamilyPlanToken)
.With(o => o.AcceptEmergencyAccessInviteToken, AcceptEmergencyAccessInviteToken)
.With(o => o.ProviderInviteToken, ProviderInviteToken));
.With(o => o.ProviderInviteToken, ProviderInviteToken)
.Without(o => o.MasterPasswordAuthentication)
.Without(o => o.MasterPasswordUnlock));
}
}

View File

@@ -1,5 +1,6 @@
using Bit.Core.Auth.Models.Api.Request.Accounts;
using Bit.Core.Enums;
using Bit.Core.KeyManagement.Models.Api.Request;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
@@ -7,6 +8,17 @@ namespace Bit.Core.Test.Auth.Models.Api.Request.Accounts;
public class RegisterFinishRequestModelTests
{
private static List<System.ComponentModel.DataAnnotations.ValidationResult> Validate(RegisterFinishRequestModel model)
{
var results = new List<System.ComponentModel.DataAnnotations.ValidationResult>();
System.ComponentModel.DataAnnotations.Validator.TryValidateObject(
model,
new System.ComponentModel.DataAnnotations.ValidationContext(model),
results,
true);
return results;
}
[Theory]
[BitAutoData]
public void GetTokenType_Returns_EmailVerification(string email, string masterPasswordHash,
@@ -170,4 +182,175 @@ public class RegisterFinishRequestModelTests
Assert.Equal(userAsymmetricKeys.PublicKey, result.PublicKey);
Assert.Equal(userAsymmetricKeys.EncryptedPrivateKey, result.PrivateKey);
}
[Fact]
public void Validate_WhenBothAuthAndRootHashProvidedButNotEqual_ReturnsMismatchError()
{
var model = new RegisterFinishRequestModel
{
Email = "user@example.com",
MasterPasswordHash = "root-hash",
UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" },
// Provide both unlock and authentication with valid KDF so only the mismatch rule fires
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterKeyWrappedUserKey = "wrapped",
Salt = "salt"
},
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterPasswordAuthenticationHash = "auth-hash", // different than root
Salt = "salt"
},
// Provide any valid token so we don't fail token validation
EmailVerificationToken = "token"
};
var results = Validate(model);
Assert.Contains(results, r =>
r.ErrorMessage == $"{nameof(MasterPasswordAuthenticationDataRequestModel.MasterPasswordAuthenticationHash)} and root level {nameof(RegisterFinishRequestModel.MasterPasswordHash)} provided and are not equal. Only provide one.");
}
[Fact]
public void Validate_WhenAuthProvidedButUnlockMissing_ReturnsUnlockMissingError()
{
var model = new RegisterFinishRequestModel
{
Email = "user@example.com",
UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" },
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterPasswordAuthenticationHash = "auth-hash",
Salt = "salt"
},
EmailVerificationToken = "token"
};
var results = Validate(model);
Assert.Contains(results, r => r.ErrorMessage == "MasterPasswordUnlock not found on RequestModel");
}
[Fact]
public void Validate_WhenUnlockProvidedButAuthMissing_ReturnsAuthMissingError()
{
var model = new RegisterFinishRequestModel
{
Email = "user@example.com",
UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" },
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterKeyWrappedUserKey = "wrapped",
Salt = "salt"
},
EmailVerificationToken = "token"
};
var results = Validate(model);
Assert.Contains(results, r => r.ErrorMessage == "MasterPasswordAuthentication not found on RequestModel");
}
[Fact]
public void Validate_WhenNeitherAuthNorUnlock_AndRootKdfMissing_ReturnsBothRootKdfErrors()
{
var model = new RegisterFinishRequestModel
{
Email = "user@example.com",
UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" },
// No MasterPasswordUnlock, no MasterPasswordAuthentication
// No root Kdf and KdfIterations to trigger both errors
EmailVerificationToken = "token"
};
var results = Validate(model);
Assert.Contains(results, r => r.ErrorMessage == $"{nameof(RegisterFinishRequestModel.Kdf)} not found on RequestModel");
Assert.Contains(results, r => r.ErrorMessage == $"{nameof(RegisterFinishRequestModel.KdfIterations)} not found on RequestModel");
}
[Fact]
public void Validate_WhenAuthAndRootHashBothMissing_ReturnsMissingHashErrorOnly()
{
var model = new RegisterFinishRequestModel
{
Email = "user@example.com",
UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" },
// Both MasterPasswordAuthentication and MasterPasswordHash are missing
MasterPasswordAuthentication = null,
MasterPasswordHash = null,
// Provide valid root KDF to avoid root KDF errors
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default,
EmailVerificationToken = "token" // avoid token error
};
var results = Validate(model);
// Only the new missing hash error should be present
Assert.Single(results);
Assert.Equal($"{nameof(MasterPasswordAuthenticationDataRequestModel.MasterPasswordAuthenticationHash)} and {nameof(RegisterFinishRequestModel.MasterPasswordHash)} not found on request, one needs to be defined.", results[0].ErrorMessage);
Assert.Contains(nameof(MasterPasswordAuthenticationDataRequestModel.MasterPasswordAuthenticationHash), results[0].MemberNames);
Assert.Contains(nameof(RegisterFinishRequestModel.MasterPasswordHash), results[0].MemberNames);
}
[Fact]
public void Validate_WhenAllFieldsValidWithSubModels_IsValid()
{
var model = new RegisterFinishRequestModel
{
Email = "user@example.com",
UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" },
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterKeyWrappedUserKey = "wrapped",
Salt = "salt"
},
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterPasswordAuthenticationHash = "auth-hash",
Salt = "salt"
},
EmailVerificationToken = "token"
};
var results = Validate(model);
Assert.Empty(results);
}
[Fact]
public void Validate_WhenNoValidRegistrationTokenProvided_ReturnsTokenErrorOnly()
{
var model = new RegisterFinishRequestModel
{
Email = "user@example.com",
UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" },
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterKeyWrappedUserKey = "wrapped",
Salt = "salt"
},
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default },
MasterPasswordAuthenticationHash = "auth-hash",
Salt = "salt"
}
// No token fields set
};
var results = Validate(model);
Assert.Single(results);
Assert.Equal("No valid registration token provided", results[0].ErrorMessage);
}
}

View File

@@ -2,9 +2,9 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
@@ -13,6 +13,7 @@ using Bit.Core.Auth.Services;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Test.AdminConsole.AutoFixture;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
@@ -163,7 +164,8 @@ public class SsoConfigServiceTests
[Theory, BitAutoData]
public async Task SaveAsync_KeyConnector_SingleOrgNotEnabled_Throws(SutProvider<SsoConfigService> sutProvider,
Organization organization)
Organization organization,
[Policy(PolicyType.SingleOrg, false)] PolicyStatus policy)
{
var utcNow = DateTime.UtcNow;
@@ -180,6 +182,9 @@ public class SsoConfigServiceTests
RevisionDate = utcNow.AddDays(-10),
};
sutProvider.GetDependency<IPolicyQuery>().RunAsync(
Arg.Any<Guid>(), PolicyType.SingleOrg).Returns(policy);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.SaveAsync(ssoConfig, organization));
@@ -191,7 +196,9 @@ public class SsoConfigServiceTests
[Theory, BitAutoData]
public async Task SaveAsync_KeyConnector_SsoPolicyNotEnabled_Throws(SutProvider<SsoConfigService> sutProvider,
Organization organization)
Organization organization,
[Policy(PolicyType.SingleOrg, true)] PolicyStatus singleOrgPolicy,
[Policy(PolicyType.RequireSso, false)] PolicyStatus requireSsoPolicy)
{
var utcNow = DateTime.UtcNow;
@@ -208,11 +215,10 @@ public class SsoConfigServiceTests
RevisionDate = utcNow.AddDays(-10),
};
sutProvider.GetDependency<IPolicyRepository>().GetByOrganizationIdTypeAsync(
Arg.Any<Guid>(), PolicyType.SingleOrg).Returns(new Policy
{
Enabled = true
});
sutProvider.GetDependency<IPolicyQuery>().RunAsync(
Arg.Any<Guid>(), PolicyType.SingleOrg).Returns(singleOrgPolicy);
sutProvider.GetDependency<IPolicyQuery>().RunAsync(
Arg.Any<Guid>(), PolicyType.RequireSso).Returns(requireSsoPolicy);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.SaveAsync(ssoConfig, organization));
@@ -225,7 +231,8 @@ public class SsoConfigServiceTests
[Theory, BitAutoData]
public async Task SaveAsync_KeyConnector_SsoConfigNotEnabled_Throws(SutProvider<SsoConfigService> sutProvider,
Organization organization)
Organization organization,
[Policy(PolicyType.SingleOrg, true)] PolicyStatus policy)
{
var utcNow = DateTime.UtcNow;
@@ -242,11 +249,8 @@ public class SsoConfigServiceTests
RevisionDate = utcNow.AddDays(-10),
};
sutProvider.GetDependency<IPolicyRepository>().GetByOrganizationIdTypeAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).Returns(new Policy
{
Enabled = true
});
sutProvider.GetDependency<IPolicyQuery>().RunAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).Returns(policy);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.SaveAsync(ssoConfig, organization));
@@ -259,7 +263,8 @@ public class SsoConfigServiceTests
[Theory, BitAutoData]
public async Task SaveAsync_KeyConnector_KeyConnectorAbilityNotEnabled_Throws(SutProvider<SsoConfigService> sutProvider,
Organization organization)
Organization organization,
[Policy(PolicyType.SingleOrg, true)] PolicyStatus policy)
{
var utcNow = DateTime.UtcNow;
@@ -277,11 +282,8 @@ public class SsoConfigServiceTests
RevisionDate = utcNow.AddDays(-10),
};
sutProvider.GetDependency<IPolicyRepository>().GetByOrganizationIdTypeAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).Returns(new Policy
{
Enabled = true,
});
sutProvider.GetDependency<IPolicyQuery>().RunAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).Returns(policy);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.SaveAsync(ssoConfig, organization));
@@ -294,7 +296,8 @@ public class SsoConfigServiceTests
[Theory, BitAutoData]
public async Task SaveAsync_KeyConnector_Success(SutProvider<SsoConfigService> sutProvider,
Organization organization)
Organization organization,
[Policy(PolicyType.SingleOrg, true)] PolicyStatus policy)
{
var utcNow = DateTime.UtcNow;
@@ -312,11 +315,8 @@ public class SsoConfigServiceTests
RevisionDate = utcNow.AddDays(-10),
};
sutProvider.GetDependency<IPolicyRepository>().GetByOrganizationIdTypeAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).Returns(new Policy
{
Enabled = true,
});
sutProvider.GetDependency<IPolicyQuery>().RunAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).Returns(policy);
await sutProvider.Sut.SaveAsync(ssoConfig, organization);

View File

@@ -0,0 +1,253 @@
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.UserFeatures.EmergencyAccess.Commands;
using Bit.Core.Auth.UserFeatures.EmergencyAccess.Mail;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Repositories;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Auth.UserFeatures.EmergencyAccess;
[SutProviderCustomize]
public class DeleteEmergencyAccessCommandTests
{
/// <summary>
/// Verifies that attempting to delete a non-existent emergency access record
/// throws a <see cref="BadRequestException"/> and does not call delete or send email.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteByIdGrantorIdAsync_EmergencyAccessNotFound_ThrowsBadRequest(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
Guid emergencyAccessId,
Guid grantorId)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetDetailsByIdGrantorIdAsync(emergencyAccessId, grantorId)
.Returns((EmergencyAccessDetails)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.DeleteByIdGrantorIdAsync(emergencyAccessId, grantorId));
Assert.Contains("Emergency Access not valid.", exception.Message);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.DidNotReceiveWithAnyArgs()
.DeleteAsync(default);
await sutProvider.GetDependency<IMailer>()
.DidNotReceiveWithAnyArgs()
.SendEmail<EmergencyAccessRemoveGranteesMailView>(default);
}
/// <summary>
/// Verifies successful deletion of an emergency access record by ID and grantor ID,
/// and ensures that a notification email is sent to the grantor.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteByIdGrantorIdAsync_DeletesEmergencyAccessAndSendsEmail(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
EmergencyAccessDetails emergencyAccessDetails)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetDetailsByIdGrantorIdAsync(emergencyAccessDetails.Id, emergencyAccessDetails.GrantorId)
.Returns(emergencyAccessDetails);
var result = await sutProvider.Sut.DeleteByIdGrantorIdAsync(emergencyAccessDetails.Id, emergencyAccessDetails.GrantorId);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.DeleteManyAsync(Arg.Any<ICollection<Guid>>());
await sutProvider.GetDependency<IMailer>()
.Received(1)
.SendEmail(Arg.Any<EmergencyAccessRemoveGranteesMail>());
}
/// <summary>
/// Verifies that when a grantor has no emergency access records, the method returns
/// an empty collection and does not attempt to delete or send email.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteAllByGrantorIdAsync_NoEmergencyAccessRecords_ReturnsEmptyCollection(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
Guid grantorId)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetManyDetailsByGrantorIdAsync(grantorId)
.Returns([]);
var result = await sutProvider.Sut.DeleteAllByGrantorIdAsync(grantorId);
Assert.NotNull(result);
Assert.Empty(result);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.DidNotReceiveWithAnyArgs()
.DeleteManyAsync(default);
await sutProvider.GetDependency<IMailer>()
.DidNotReceiveWithAnyArgs()
.SendEmail<EmergencyAccessRemoveGranteesMailView>(default);
}
/// <summary>
/// Verifies that when a grantor has multiple emergency access records, all records are deleted,
/// the details are returned, and a single notification email is sent to the grantor.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteAllByGrantorIdAsync_MultipleRecords_DeletesAllReturnsDetailsSendsSingleEmail(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
EmergencyAccessDetails emergencyAccessDetails1,
EmergencyAccessDetails emergencyAccessDetails2,
EmergencyAccessDetails emergencyAccessDetails3)
{
// Arrange
// link all details to the same grantor
emergencyAccessDetails2.GrantorId = emergencyAccessDetails1.GrantorId;
emergencyAccessDetails2.GrantorEmail = emergencyAccessDetails1.GrantorEmail;
emergencyAccessDetails3.GrantorId = emergencyAccessDetails1.GrantorId;
emergencyAccessDetails3.GrantorEmail = emergencyAccessDetails1.GrantorEmail;
var allDetails = new List<EmergencyAccessDetails>
{
emergencyAccessDetails1,
emergencyAccessDetails2,
emergencyAccessDetails3
};
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetManyDetailsByGrantorIdAsync(emergencyAccessDetails1.GrantorId)
.Returns(allDetails);
// Act
var result = await sutProvider.Sut.DeleteAllByGrantorIdAsync(emergencyAccessDetails1.GrantorId);
// Assert
Assert.NotNull(result);
Assert.Equal(3, result.Count);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.DeleteManyAsync(Arg.Any<ICollection<Guid>>());
await sutProvider.GetDependency<IMailer>()
.Received(1)
.SendEmail(Arg.Any<EmergencyAccessRemoveGranteesMail>());
}
/// <summary>
/// Verifies that when a grantor has a single emergency access record, it is deleted,
/// the details are returned, and a notification email is sent.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteAllByGrantorIdAsync_SingleRecord_DeletesAndReturnsDetailsSendsSingleEmail(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
EmergencyAccessDetails emergencyAccessDetails,
Guid grantorId)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetManyDetailsByGrantorIdAsync(grantorId)
.Returns([emergencyAccessDetails]);
var result = await sutProvider.Sut.DeleteAllByGrantorIdAsync(grantorId);
Assert.NotNull(result);
Assert.Single(result);
Assert.Equal(emergencyAccessDetails.Id, result.First().Id);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.DeleteManyAsync(Arg.Any<ICollection<Guid>>());
await sutProvider.GetDependency<IMailer>()
.Received(1)
.SendEmail(Arg.Any<EmergencyAccessRemoveGranteesMail>());
}
/// <summary>
/// Verifies that when a grantee has no emergency access records, the method returns
/// an empty collection and does not attempt to delete or send email.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteAllByGranteeIdAsync_NoEmergencyAccessRecords_ReturnsEmptyCollection(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
Guid granteeId)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetManyDetailsByGranteeIdAsync(granteeId)
.Returns([]);
var result = await sutProvider.Sut.DeleteAllByGranteeIdAsync(granteeId);
Assert.NotNull(result);
Assert.Empty(result);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.DidNotReceiveWithAnyArgs()
.DeleteManyAsync(default);
await sutProvider.GetDependency<IMailer>()
.DidNotReceiveWithAnyArgs()
.SendEmail<EmergencyAccessRemoveGranteesMailView>(default);
}
/// <summary>
/// Verifies that when a grantee has a single emergency access record, it is deleted,
/// the details are returned, and a notification email is sent to the grantor.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteAllByGranteeIdAsync_SingleRecord_DeletesAndReturnsDetailsSendsSingleEmail(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
EmergencyAccessDetails emergencyAccessDetails,
Guid granteeId)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetManyDetailsByGranteeIdAsync(granteeId)
.Returns([emergencyAccessDetails]);
var result = await sutProvider.Sut.DeleteAllByGranteeIdAsync(granteeId);
Assert.NotNull(result);
Assert.Single(result);
Assert.Equal(emergencyAccessDetails.Id, result.First().Id);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.DeleteManyAsync(Arg.Any<ICollection<Guid>>());
await sutProvider.GetDependency<IMailer>()
.Received(1)
.SendEmail(Arg.Any<EmergencyAccessRemoveGranteesMail>());
}
/// <summary>
/// Verifies that when a grantee has multiple emergency access records from different grantors,
/// all records are deleted, the details are returned, and a single notification email is sent
/// to all affected grantors individually.
/// </summary>
[Theory, BitAutoData]
public async Task DeleteAllByGranteeIdAsync_MultipleRecords_DeletesAllReturnsDetailsSendsMultipleEmails(
SutProvider<DeleteEmergencyAccessCommand> sutProvider,
EmergencyAccessDetails emergencyAccessDetails1,
EmergencyAccessDetails emergencyAccessDetails2,
EmergencyAccessDetails emergencyAccessDetails3)
{
// link all details to the same grantee
emergencyAccessDetails2.GranteeId = emergencyAccessDetails1.GranteeId;
emergencyAccessDetails2.GranteeEmail = emergencyAccessDetails1.GranteeEmail;
emergencyAccessDetails3.GranteeId = emergencyAccessDetails1.GranteeId;
emergencyAccessDetails3.GranteeEmail = emergencyAccessDetails1.GranteeEmail;
var allDetails = new List<EmergencyAccessDetails>
{
emergencyAccessDetails1,
emergencyAccessDetails2,
emergencyAccessDetails3
};
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetManyDetailsByGranteeIdAsync((Guid)emergencyAccessDetails1.GranteeId)
.Returns(allDetails);
var result = await sutProvider.Sut.DeleteAllByGranteeIdAsync((Guid)emergencyAccessDetails1.GranteeId);
Assert.NotNull(result);
Assert.Equal(3, result.Count);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.DeleteManyAsync(Arg.Any<ICollection<Guid>>());
await sutProvider.GetDependency<IMailer>()
.Received(allDetails.Count)
.SendEmail(Arg.Any<EmergencyAccessRemoveGranteesMail>());
}
}

View File

@@ -0,0 +1,153 @@
using Bit.Core.Auth.UserFeatures.EmergencyAccess.Mail;
using Bit.Core.Models.Mail;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Mail.Mailer;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
namespace Bit.Core.Test.Auth.UserFeatures.EmergencyAccess;
[SutProviderCustomize]
public class EmergencyAccessMailTests
{
// Constant values for all Emergency Access emails
private const string _emergencyAccessHelpUrl = "https://bitwarden.com/help/emergency-access/";
private const string _emergencyAccessMailSubject = "Emergency contacts removed";
/// <summary>
/// Documents how to construct and send the emergency access removal email.
/// 1. Inject IMailer into their command/service
/// 2. Construct EmergencyAccessRemoveGranteesMail as shown below
/// 3. Call mailer.SendEmail(mail)
/// </summary>
[Theory, BitAutoData]
public async Task SendEmergencyAccessRemoveGranteesEmail_SingleGrantee_Success(
string grantorEmail,
string granteeEmail)
{
// Arrange
var logger = Substitute.For<ILogger<HandlebarMailRenderer>>();
var globalSettings = new GlobalSettings { SelfHosted = false };
var deliveryService = Substitute.For<IMailDeliveryService>();
var mailer = new Mailer(
new HandlebarMailRenderer(logger, globalSettings),
deliveryService);
var mail = new EmergencyAccessRemoveGranteesMail
{
ToEmails = [grantorEmail],
View = new EmergencyAccessRemoveGranteesMailView
{
RemovedGranteeEmails = [granteeEmail]
}
};
MailMessage sentMessage = null;
await deliveryService.SendEmailAsync(Arg.Do<MailMessage>(message =>
sentMessage = message
));
// Act
await mailer.SendEmail(mail);
// Assert
Assert.NotNull(sentMessage);
Assert.Contains(grantorEmail, sentMessage.ToEmails);
// Verify the content contains the grantee name
Assert.Contains(granteeEmail, sentMessage.TextContent);
Assert.Contains(granteeEmail, sentMessage.HtmlContent);
}
/// <summary>
/// Documents handling multiple removed grantees in a single email.
/// </summary>
[Theory, BitAutoData]
public async Task SendEmergencyAccessRemoveGranteesEmail_MultipleGrantees_RendersAllNames(
string grantorEmail)
{
// Arrange
var logger = Substitute.For<ILogger<HandlebarMailRenderer>>();
var globalSettings = new GlobalSettings { SelfHosted = false };
var deliveryService = Substitute.For<IMailDeliveryService>();
var mailer = new Mailer(
new HandlebarMailRenderer(logger, globalSettings),
deliveryService);
var granteeEmails = new[] { "Alice@test.dev", "Bob@test.dev", "Carol@test.dev" };
var mail = new EmergencyAccessRemoveGranteesMail
{
ToEmails = [grantorEmail],
View = new EmergencyAccessRemoveGranteesMailView
{
RemovedGranteeEmails = granteeEmails
}
};
MailMessage sentMessage = null;
await deliveryService.SendEmailAsync(Arg.Do<MailMessage>(message =>
sentMessage = message
));
// Act
await mailer.SendEmail(mail);
// Assert - All grantee names should appear in the email
Assert.NotNull(sentMessage);
foreach (var granteeEmail in granteeEmails)
{
Assert.Contains(granteeEmail, sentMessage.TextContent);
Assert.Contains(granteeEmail, sentMessage.HtmlContent);
}
}
/// <summary>
/// Validates the required GranteeNames for the email view model.
/// </summary>
[Theory, BitAutoData]
public void EmergencyAccessRemoveGranteesMailView_GranteeNames_AreRequired(
string grantorEmail)
{
// Arrange - Shows the minimum required to construct the email
var mail = new EmergencyAccessRemoveGranteesMail
{
ToEmails = [grantorEmail], // Required: who to send to
View = new EmergencyAccessRemoveGranteesMailView
{
// Required: at least one removed grantee name
RemovedGranteeEmails = ["Example Grantee"]
}
};
// Assert
Assert.NotNull(mail);
Assert.NotNull(mail.View);
Assert.NotEmpty(mail.View.RemovedGranteeEmails);
}
/// <summary>
/// Ensure consistency with help pages link and email subject.
/// </summary>
/// <param name="grantorEmail"></param>
/// <param name="granteeName"></param>
[Theory, BitAutoData]
public void EmergencyAccessRemoveGranteesMailView_SubjectAndHelpLink_MatchesExpectedValues(string grantorEmail, string granteeName)
{
// Arrange
var mail = new EmergencyAccessRemoveGranteesMail
{
ToEmails = [grantorEmail],
View = new EmergencyAccessRemoveGranteesMailView { RemovedGranteeEmails = [granteeName] }
};
// Assert
Assert.NotNull(mail);
Assert.NotNull(mail.View);
Assert.Equal(_emergencyAccessMailSubject, mail.Subject);
Assert.Equal(_emergencyAccessHelpUrl, EmergencyAccessRemoveGranteesMailView.EmergencyAccessHelpPageUrl);
}
}

View File

@@ -1,11 +1,10 @@
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Services;
using Bit.Core.Auth.UserFeatures.EmergencyAccess;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@@ -17,7 +16,7 @@ using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Auth.Services;
namespace Bit.Core.Test.Auth.UserFeatures.EmergencyAccess;
[SutProviderCustomize]
public class EmergencyAccessServiceTests
@@ -68,13 +67,13 @@ public class EmergencyAccessServiceTests
Assert.Equal(EmergencyAccessStatusType.Invited, result.Status);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.CreateAsync(Arg.Any<EmergencyAccess>());
.CreateAsync(Arg.Any<Core.Auth.Entities.EmergencyAccess>());
sutProvider.GetDependency<IDataProtectorTokenFactory<EmergencyAccessInviteTokenable>>()
.Received(1)
.Protect(Arg.Any<EmergencyAccessInviteTokenable>());
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendEmergencyAccessInviteEmailAsync(Arg.Any<EmergencyAccess>(), Arg.Any<string>(), Arg.Any<string>());
.SendEmergencyAccessInviteEmailAsync(Arg.Any<Core.Auth.Entities.EmergencyAccess>(), Arg.Any<string>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -98,7 +97,7 @@ public class EmergencyAccessServiceTests
User invitingUser,
Guid emergencyAccessId)
{
EmergencyAccess emergencyAccess = null;
Core.Auth.Entities.EmergencyAccess emergencyAccess = null;
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
@@ -119,7 +118,7 @@ public class EmergencyAccessServiceTests
User invitingUser,
Guid emergencyAccessId)
{
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Status = EmergencyAccessStatusType.Invited,
GrantorId = Guid.NewGuid(),
@@ -148,7 +147,7 @@ public class EmergencyAccessServiceTests
User invitingUser,
Guid emergencyAccessId)
{
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Status = statusType,
GrantorId = invitingUser.Id,
@@ -172,7 +171,7 @@ public class EmergencyAccessServiceTests
User invitingUser,
Guid emergencyAccessId)
{
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Status = EmergencyAccessStatusType.Invited,
GrantorId = invitingUser.Id,
@@ -194,7 +193,7 @@ public class EmergencyAccessServiceTests
public async Task AcceptUserAsync_EmergencyAccessNull_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider, User acceptingUser, string token)
{
EmergencyAccess emergencyAccess = null;
Core.Auth.Entities.EmergencyAccess emergencyAccess = null;
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns(emergencyAccess);
@@ -209,7 +208,7 @@ public class EmergencyAccessServiceTests
public async Task AcceptUserAsync_CannotUnprotectToken_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User acceptingUser,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string token)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
@@ -230,8 +229,8 @@ public class EmergencyAccessServiceTests
public async Task AcceptUserAsync_TokenDataInvalid_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User acceptingUser,
EmergencyAccess emergencyAccess,
EmergencyAccess wrongEmergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess wrongEmergencyAccess,
string token)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
@@ -257,7 +256,7 @@ public class EmergencyAccessServiceTests
public async Task AcceptUserAsync_AcceptedStatus_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User acceptingUser,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string token)
{
emergencyAccess.Status = EmergencyAccessStatusType.Accepted;
@@ -284,7 +283,7 @@ public class EmergencyAccessServiceTests
public async Task AcceptUserAsync_NotInvitedStatus_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User acceptingUser,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string token)
{
emergencyAccess.Status = EmergencyAccessStatusType.Confirmed;
@@ -311,7 +310,7 @@ public class EmergencyAccessServiceTests
public async Task AcceptUserAsync_EmergencyAccessEmailDoesNotMatch_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User acceptingUser,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string token)
{
emergencyAccess.Status = EmergencyAccessStatusType.Invited;
@@ -339,7 +338,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider,
User acceptingUser,
User invitingUser,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string token)
{
emergencyAccess.Status = EmergencyAccessStatusType.Invited;
@@ -364,7 +363,7 @@ public class EmergencyAccessServiceTests
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.ReplaceAsync(Arg.Is<EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.Accepted));
.ReplaceAsync(Arg.Is<Core.Auth.Entities.EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.Accepted));
await sutProvider.GetDependency<IMailService>()
.Received(1)
@@ -375,11 +374,11 @@ public class EmergencyAccessServiceTests
public async Task DeleteAsync_EmergencyAccessNull_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User invitingUser,
EmergencyAccess emergencyAccess)
Core.Auth.Entities.EmergencyAccess emergencyAccess)
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.DeleteAsync(emergencyAccess.Id, invitingUser.Id));
@@ -391,7 +390,7 @@ public class EmergencyAccessServiceTests
public async Task DeleteAsync_EmergencyAccessGrantorIdNotEqual_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User invitingUser,
EmergencyAccess emergencyAccess)
Core.Auth.Entities.EmergencyAccess emergencyAccess)
{
emergencyAccess.GrantorId = Guid.NewGuid();
sutProvider.GetDependency<IEmergencyAccessRepository>()
@@ -408,7 +407,7 @@ public class EmergencyAccessServiceTests
public async Task DeleteAsync_EmergencyAccessGranteeIdNotEqual_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
User invitingUser,
EmergencyAccess emergencyAccess)
Core.Auth.Entities.EmergencyAccess emergencyAccess)
{
emergencyAccess.GranteeId = Guid.NewGuid();
sutProvider.GetDependency<IEmergencyAccessRepository>()
@@ -425,7 +424,7 @@ public class EmergencyAccessServiceTests
public async Task DeleteAsync_EmergencyAccessIsDeleted_Success(
SutProvider<EmergencyAccessService> sutProvider,
User user,
EmergencyAccess emergencyAccess)
Core.Auth.Entities.EmergencyAccess emergencyAccess)
{
emergencyAccess.GranteeId = user.Id;
emergencyAccess.GrantorId = user.Id;
@@ -443,7 +442,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task ConfirmUserAsync_EmergencyAccessNull_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string key,
User grantorUser)
{
@@ -451,7 +450,7 @@ public class EmergencyAccessServiceTests
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated;
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.ConfirmUserAsync(emergencyAccess.Id, key, grantorUser.Id));
@@ -463,7 +462,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task ConfirmUserAsync_EmergencyAccessStatusIsNotAccepted_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string key,
User grantorUser)
{
@@ -484,7 +483,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task ConfirmUserAsync_EmergencyAccessGrantorIdNotEqualToConfirmingUserId_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string key,
User grantorUser)
{
@@ -505,7 +504,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider, User confirmingUser, string key)
{
confirmingUser.UsesKeyConnector = true;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Status = EmergencyAccessStatusType.Accepted,
GrantorId = confirmingUser.Id,
@@ -530,7 +529,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task ConfirmUserAsync_ConfirmsAndReplacesEmergencyAccess_Success(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
string key,
User grantorUser,
User granteeUser)
@@ -553,7 +552,7 @@ public class EmergencyAccessServiceTests
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.ReplaceAsync(Arg.Is<EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.Confirmed));
.ReplaceAsync(Arg.Is<Core.Auth.Entities.EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.Confirmed));
await sutProvider.GetDependency<IMailService>()
.Received(1)
@@ -564,7 +563,7 @@ public class EmergencyAccessServiceTests
public async Task SaveAsync_PremiumCannotUpdate_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider, User savingUser)
{
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Type = EmergencyAccessType.Takeover,
GrantorId = savingUser.Id,
@@ -586,7 +585,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider, User savingUser)
{
savingUser.Premium = true;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Type = EmergencyAccessType.Takeover,
GrantorId = new Guid(),
@@ -611,7 +610,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider, User grantorUser)
{
grantorUser.UsesKeyConnector = true;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Type = EmergencyAccessType.Takeover,
GrantorId = grantorUser.Id,
@@ -633,7 +632,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider, User grantorUser)
{
grantorUser.UsesKeyConnector = true;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Type = EmergencyAccessType.View,
GrantorId = grantorUser.Id,
@@ -655,7 +654,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider, User grantorUser)
{
grantorUser.UsesKeyConnector = false;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Type = EmergencyAccessType.Takeover,
GrantorId = grantorUser.Id,
@@ -678,7 +677,7 @@ public class EmergencyAccessServiceTests
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser));
@@ -692,7 +691,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task InitiateAsync_EmergencyAccessGranteeIdNotEqual_ThrowBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User initiatingUser)
{
emergencyAccess.GranteeId = new Guid();
@@ -712,7 +711,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task InitiateAsync_EmergencyAccessStatusIsNotConfirmed_ThrowBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User initiatingUser)
{
emergencyAccess.GranteeId = initiatingUser.Id;
@@ -735,7 +734,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider, User initiatingUser, User grantor)
{
grantor.UsesKeyConnector = true;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Status = EmergencyAccessStatusType.Confirmed,
GranteeId = initiatingUser.Id,
@@ -764,7 +763,7 @@ public class EmergencyAccessServiceTests
SutProvider<EmergencyAccessService> sutProvider, User initiatingUser, User grantor)
{
grantor.UsesKeyConnector = true;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Status = EmergencyAccessStatusType.Confirmed,
GranteeId = initiatingUser.Id,
@@ -783,14 +782,14 @@ public class EmergencyAccessServiceTests
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.ReplaceAsync(Arg.Is<EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated));
.ReplaceAsync(Arg.Is<Core.Auth.Entities.EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated));
}
[Theory, BitAutoData]
public async Task InitiateAsync_RequestIsCorrect_Success(
SutProvider<EmergencyAccessService> sutProvider, User initiatingUser, User grantor)
{
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
Status = EmergencyAccessStatusType.Confirmed,
GranteeId = initiatingUser.Id,
@@ -809,7 +808,7 @@ public class EmergencyAccessServiceTests
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.ReplaceAsync(Arg.Is<EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated));
.ReplaceAsync(Arg.Is<Core.Auth.Entities.EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated));
}
[Theory, BitAutoData]
@@ -818,7 +817,7 @@ public class EmergencyAccessServiceTests
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.ApproveAsync(new Guid(), null));
@@ -829,7 +828,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task ApproveAsync_EmergencyAccessGrantorIdNotEquatToApproving_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User grantorUser)
{
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated;
@@ -851,7 +850,7 @@ public class EmergencyAccessServiceTests
public async Task ApproveAsync_EmergencyAccessStatusNotRecoveryInitiated_ThrowsBadRequest(
EmergencyAccessStatusType statusType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User grantorUser)
{
emergencyAccess.GrantorId = grantorUser.Id;
@@ -869,7 +868,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task ApproveAsync_Success(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User grantorUser,
User granteeUser)
{
@@ -885,20 +884,20 @@ public class EmergencyAccessServiceTests
await sutProvider.Sut.ApproveAsync(emergencyAccess.Id, grantorUser);
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.ReplaceAsync(Arg.Is<EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.RecoveryApproved));
.ReplaceAsync(Arg.Is<Core.Auth.Entities.EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.RecoveryApproved));
}
[Theory, BitAutoData]
public async Task RejectAsync_EmergencyAccessIdNull_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User GrantorUser)
{
emergencyAccess.GrantorId = GrantorUser.Id;
emergencyAccess.Status = EmergencyAccessStatusType.Accepted;
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.RejectAsync(emergencyAccess.Id, GrantorUser));
@@ -909,7 +908,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task RejectAsync_EmergencyAccessGrantorIdNotEqualToRequestUser_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User GrantorUser)
{
emergencyAccess.Status = EmergencyAccessStatusType.Accepted;
@@ -930,7 +929,7 @@ public class EmergencyAccessServiceTests
public async Task RejectAsync_EmergencyAccessStatusNotValid_ThrowsBadRequest(
EmergencyAccessStatusType statusType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User GrantorUser)
{
emergencyAccess.GrantorId = GrantorUser.Id;
@@ -951,7 +950,7 @@ public class EmergencyAccessServiceTests
public async Task RejectAsync_Success(
EmergencyAccessStatusType statusType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User GrantorUser,
User GranteeUser)
{
@@ -968,7 +967,7 @@ public class EmergencyAccessServiceTests
await sutProvider.GetDependency<IEmergencyAccessRepository>()
.Received(1)
.ReplaceAsync(Arg.Is<EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.Confirmed));
.ReplaceAsync(Arg.Is<Core.Auth.Entities.EmergencyAccess>(x => x.Status == EmergencyAccessStatusType.Confirmed));
}
[Theory, BitAutoData]
@@ -977,7 +976,7 @@ public class EmergencyAccessServiceTests
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.GetPoliciesAsync(default, default));
@@ -992,7 +991,7 @@ public class EmergencyAccessServiceTests
public async Task GetPoliciesAsync_RequestNotValidStatusType_ThrowsBadRequest(
EmergencyAccessStatusType statusType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;
@@ -1010,7 +1009,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task GetPoliciesAsync_RequestNotValidType_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;
@@ -1032,7 +1031,7 @@ public class EmergencyAccessServiceTests
public async Task GetPoliciesAsync_OrganizationUserTypeNotOwner_ReturnsNull(
OrganizationUserType userType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser,
User grantorUser,
OrganizationUser grantorOrganizationUser)
@@ -1062,7 +1061,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task GetPoliciesAsync_OrganizationUserEmpty_ReturnsNull(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser,
User grantorUser)
{
@@ -1090,7 +1089,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task GetPoliciesAsync_ReturnsNotNull(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser,
User grantorUser,
OrganizationUser grantorOrganizationUser)
@@ -1127,7 +1126,7 @@ public class EmergencyAccessServiceTests
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.TakeoverAsync(default, default));
@@ -1138,7 +1137,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task TakeoverAsync_RequestNotValid_GranteeNotEqualToRequestingUser_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved;
@@ -1161,7 +1160,7 @@ public class EmergencyAccessServiceTests
public async Task TakeoverAsync_RequestNotValid_StatusType_ThrowsBadRequest(
EmergencyAccessStatusType statusType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;
@@ -1180,7 +1179,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task TakeoverAsync_RequestNotValid_TypeIsView_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;
@@ -1203,7 +1202,7 @@ public class EmergencyAccessServiceTests
User grantor)
{
grantor.UsesKeyConnector = true;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
GrantorId = grantor.Id,
GranteeId = granteeUser.Id,
@@ -1232,7 +1231,7 @@ public class EmergencyAccessServiceTests
User grantor)
{
grantor.UsesKeyConnector = false;
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
GrantorId = grantor.Id,
GranteeId = granteeUser.Id,
@@ -1260,7 +1259,7 @@ public class EmergencyAccessServiceTests
{
sutProvider.GetDependency<IEmergencyAccessRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((EmergencyAccess)null);
.Returns((Core.Auth.Entities.EmergencyAccess)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.PasswordAsync(default, default, default, default));
@@ -1271,7 +1270,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task PasswordAsync_RequestNotValid_GranteeNotEqualToRequestingUser_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved;
@@ -1294,7 +1293,7 @@ public class EmergencyAccessServiceTests
public async Task PasswordAsync_RequestNotValid_StatusType_ThrowsBadRequest(
EmergencyAccessStatusType statusType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;
@@ -1313,7 +1312,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task PasswordAsync_RequestNotValid_TypeIsView_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;
@@ -1332,7 +1331,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task PasswordAsync_NonOrgUser_Success(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser,
User grantorUser,
string key,
@@ -1367,7 +1366,7 @@ public class EmergencyAccessServiceTests
public async Task PasswordAsync_OrgUser_NotOrganizationOwner_RemovedFromOrganization_Success(
OrganizationUserType userType,
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser,
User grantorUser,
OrganizationUser organizationUser,
@@ -1408,7 +1407,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task PasswordAsync_OrgUser_IsOrganizationOwner_NotRemovedFromOrganization_Success(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser,
User grantorUser,
OrganizationUser organizationUser,
@@ -1459,7 +1458,7 @@ public class EmergencyAccessServiceTests
Enabled = true
}
});
var emergencyAccess = new EmergencyAccess
var emergencyAccess = new Core.Auth.Entities.EmergencyAccess
{
GrantorId = grantor.Id,
GranteeId = requestingUser.Id,
@@ -1484,7 +1483,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task ViewAsync_EmergencyAccessTypeNotView_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;
@@ -1500,7 +1499,7 @@ public class EmergencyAccessServiceTests
[Theory, BitAutoData]
public async Task GetAttachmentDownloadAsync_EmergencyAccessTypeNotView_ThrowsBadRequest(
SutProvider<EmergencyAccessService> sutProvider,
EmergencyAccess emergencyAccess,
Core.Auth.Entities.EmergencyAccess emergencyAccess,
User granteeUser)
{
emergencyAccess.GranteeId = granteeUser.Id;

View File

@@ -1,8 +1,8 @@
using System.Text;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables;
@@ -14,6 +14,7 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterpri
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Test.AdminConsole.AutoFixture;
using Bit.Core.Tokens;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
@@ -23,6 +24,7 @@ using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.WebUtilities;
using NSubstitute;
using Xunit;
using EmergencyAccessEntity = Bit.Core.Auth.Entities.EmergencyAccess;
namespace Bit.Core.Test.Auth.UserFeatures.Registration;
@@ -241,7 +243,8 @@ public class RegisterUserCommandTests
[BitAutoData(true, "sampleInitiationPath")]
[BitAutoData(true, "Secrets Manager trial")]
public async Task RegisterUserViaOrganizationInviteToken_ComplexHappyPath_Succeeds(bool addUserReferenceData, string initiationPath,
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId, Policy twoFactorPolicy)
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId,
[Policy(PolicyType.TwoFactorAuthentication, true)] PolicyStatus policy)
{
// Arrange
sutProvider.GetDependency<IGlobalSettings>()
@@ -267,10 +270,9 @@ public class RegisterUserCommandTests
.GetByIdAsync(orgUserId)
.Returns(orgUser);
twoFactorPolicy.Enabled = true;
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication)
.Returns(twoFactorPolicy);
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication)
.Returns(policy);
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user, masterPasswordHash)
@@ -286,9 +288,9 @@ public class RegisterUserCommandTests
.Received(1)
.GetByIdAsync(orgUserId);
await sutProvider.GetDependency<IPolicyRepository>()
await sutProvider.GetDependency<IPolicyQuery>()
.Received(1)
.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication);
.RunAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication);
sutProvider.GetDependency<IUserService>()
.Received(1)
@@ -431,7 +433,8 @@ public class RegisterUserCommandTests
[Theory]
[BitAutoData]
public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromDifferentOrg_ThrowsBadRequestException(
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId)
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId,
[Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy)
{
// Arrange
user.Email = "user@blocked-domain.com";
@@ -463,6 +466,10 @@ public class RegisterUserCommandTests
.HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com", orgUser.OrganizationId)
.Returns(true);
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns(policy);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId));
@@ -472,7 +479,8 @@ public class RegisterUserCommandTests
[Theory]
[BitAutoData]
public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromSameOrg_Succeeds(
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId)
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId,
[Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy)
{
// Arrange
user.Email = "user@company-domain.com";
@@ -509,6 +517,10 @@ public class RegisterUserCommandTests
.CreateUserAsync(user, masterPasswordHash)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns(policy);
// Act
var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId);
@@ -726,7 +738,7 @@ public class RegisterUserCommandTests
[BitAutoData]
public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_Succeeds(
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash,
EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId)
EmergencyAccessEntity emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId)
{
// Arrange
user.Email = $"test+{Guid.NewGuid()}@example.com";
@@ -767,7 +779,7 @@ public class RegisterUserCommandTests
[Theory]
[BitAutoData]
public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_InvalidToken_ThrowsBadRequestException(SutProvider<RegisterUserCommand> sutProvider, User user,
string masterPasswordHash, EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId)
string masterPasswordHash, EmergencyAccessEntity emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId)
{
// Arrange
user.Email = $"test+{Guid.NewGuid()}@example.com";
@@ -1112,7 +1124,7 @@ public class RegisterUserCommandTests
[BitAutoData]
public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_BlockedDomain_ThrowsBadRequestException(
SutProvider<RegisterUserCommand> sutProvider, User user, string masterPasswordHash,
EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId)
EmergencyAccessEntity emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId)
{
// Arrange
user.Email = "user@blocked-domain.com";
@@ -1245,6 +1257,7 @@ public class RegisterUserCommandTests
OrganizationUser orgUser,
string orgInviteToken,
string masterPasswordHash,
[Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
@@ -1259,9 +1272,9 @@ public class RegisterUserCommandTests
.GetByIdAsync(orgUser.Id)
.Returns(orgUser);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns((Policy)null);
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(orgUser.OrganizationId)
@@ -1331,6 +1344,7 @@ public class RegisterUserCommandTests
OrganizationUser orgUser,
string masterPasswordHash,
string orgInviteToken,
[Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
@@ -1346,9 +1360,9 @@ public class RegisterUserCommandTests
.GetByIdAsync(orgUser.Id)
.Returns(orgUser);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns((Policy)null);
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(orgUser.OrganizationId)

View File

@@ -4,6 +4,7 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Test.Billing.Extensions;
using Braintree;
@@ -22,6 +23,7 @@ using static StripeConstants;
public class UpdatePaymentMethodCommandTests
{
private readonly IBraintreeGateway _braintreeGateway = Substitute.For<IBraintreeGateway>();
private readonly IBraintreeService _braintreeService = Substitute.For<IBraintreeService>();
private readonly IGlobalSettings _globalSettings = Substitute.For<IGlobalSettings>();
private readonly ISetupIntentCache _setupIntentCache = Substitute.For<ISetupIntentCache>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
@@ -32,6 +34,7 @@ public class UpdatePaymentMethodCommandTests
{
_command = new UpdatePaymentMethodCommand(
_braintreeGateway,
_braintreeService,
_globalSettings,
Substitute.For<ILogger<UpdatePaymentMethodCommand>>(),
_setupIntentCache,
@@ -375,7 +378,6 @@ public class UpdatePaymentMethodCommandTests
_subscriberService.GetCustomer(organization).Returns(customer);
var customerGateway = Substitute.For<ICustomerGateway>();
var braintreeCustomer = Substitute.For<Braintree.Customer>();
braintreeCustomer.Id.Returns("braintree_customer_id");
var existing = Substitute.For<PayPalAccount>();
@@ -383,7 +385,10 @@ public class UpdatePaymentMethodCommandTests
existing.IsDefault.Returns(true);
existing.Token.Returns("EXISTING");
braintreeCustomer.PaymentMethods.Returns([existing]);
customerGateway.FindAsync("braintree_customer_id").Returns(braintreeCustomer);
_braintreeService.GetCustomer(customer).Returns(braintreeCustomer);
var customerGateway = Substitute.For<ICustomerGateway>();
_braintreeGateway.Customer.Returns(customerGateway);
var paymentMethodGateway = Substitute.For<IPaymentMethodGateway>();
@@ -471,4 +476,75 @@ public class UpdatePaymentMethodCommandTests
Arg.Is<CustomerUpdateOptions>(options =>
options.Metadata[MetadataKeys.BraintreeCustomerId] == "braintree_customer_id"));
}
[Fact]
public async Task Run_PayPal_MissingBraintreeCustomer_CreatesNewBraintreeCustomer_ReturnsMaskedPayPalAccount()
{
var organization = new Organization
{
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer
{
Address = new Address
{
Country = "US",
PostalCode = "12345"
},
Id = "cus_123",
Metadata = new Dictionary<string, string>
{
[MetadataKeys.BraintreeCustomerId] = "missing_braintree_customer_id"
}
};
_subscriberService.GetCustomer(organization).Returns(customer);
// BraintreeService.GetCustomer returns null when the Braintree customer doesn't exist
_braintreeService.GetCustomer(customer).Returns((Braintree.Customer?)null);
_globalSettings.BaseServiceUri.Returns(new GlobalSettings.BaseServiceUriSettings(new GlobalSettings())
{
CloudRegion = "US"
});
var customerGateway = Substitute.For<ICustomerGateway>();
var braintreeCustomer = Substitute.For<Braintree.Customer>();
braintreeCustomer.Id.Returns("new_braintree_customer_id");
var payPalAccount = Substitute.For<PayPalAccount>();
payPalAccount.Email.Returns("user@gmail.com");
payPalAccount.IsDefault.Returns(true);
payPalAccount.Token.Returns("NONCE");
braintreeCustomer.PaymentMethods.Returns([payPalAccount]);
var createResult = Substitute.For<Result<Braintree.Customer>>();
createResult.Target.Returns(braintreeCustomer);
customerGateway.CreateAsync(Arg.Is<CustomerRequest>(options =>
options.Id.StartsWith(organization.BraintreeCustomerIdPrefix() + organization.Id.ToString("N").ToLower()) &&
options.CustomFields[organization.BraintreeIdField()] == organization.Id.ToString() &&
options.CustomFields[organization.BraintreeCloudRegionField()] == "US" &&
options.Email == organization.BillingEmailAddress() &&
options.PaymentMethodNonce == "TOKEN")).Returns(createResult);
_braintreeGateway.Customer.Returns(customerGateway);
var result = await _command.Run(organization,
new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "TOKEN" },
new BillingAddress { Country = "US", PostalCode = "12345" });
Assert.True(result.IsT0);
var maskedPaymentMethod = result.AsT0;
Assert.True(maskedPaymentMethod.IsT2);
var maskedPayPalAccount = maskedPaymentMethod.AsT2;
Assert.Equal("user@gmail.com", maskedPayPalAccount.Email);
// Verify a new Braintree customer was created (not FindAsync called)
await customerGateway.DidNotReceive().FindAsync(Arg.Any<string>());
await customerGateway.Received(1).CreateAsync(Arg.Any<CustomerRequest>());
// Verify Stripe metadata was updated with the new Braintree customer ID
await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id,
Arg.Is<CustomerUpdateOptions>(options =>
options.Metadata[MetadataKeys.BraintreeCustomerId] == "new_braintree_customer_id"));
}
}

View File

@@ -3,9 +3,9 @@ using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Services;
using Bit.Core.Services;
using Bit.Core.Test.Billing.Extensions;
using Braintree;
using Microsoft.Extensions.Logging;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Stripe;
@@ -19,7 +19,7 @@ using static StripeConstants;
public class GetPaymentMethodQueryTests
{
private readonly IBraintreeGateway _braintreeGateway = Substitute.For<IBraintreeGateway>();
private readonly IBraintreeService _braintreeService = Substitute.For<IBraintreeService>();
private readonly ISetupIntentCache _setupIntentCache = Substitute.For<ISetupIntentCache>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
@@ -28,8 +28,7 @@ public class GetPaymentMethodQueryTests
public GetPaymentMethodQueryTests()
{
_query = new GetPaymentMethodQuery(
_braintreeGateway,
Substitute.For<ILogger<GetPaymentMethodQuery>>(),
_braintreeService,
_setupIntentCache,
_stripeAdapter,
_subscriberService);
@@ -75,6 +74,34 @@ public class GetPaymentMethodQueryTests
Assert.Null(maskedPaymentMethod);
}
[Fact]
public async Task Run_NoPaymentMethod_BraintreeCustomerNotFound_ReturnsNull()
{
var organization = new Organization
{
Id = Guid.NewGuid()
};
var customer = new Customer
{
InvoiceSettings = new CustomerInvoiceSettings(),
Metadata = new Dictionary<string, string>
{
[MetadataKeys.BraintreeCustomerId] = "non_existent_braintree_customer_id"
}
};
_subscriberService.GetCustomer(organization,
Arg.Is<CustomerGetOptions>(options =>
options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer);
_braintreeService.GetCustomer(customer).ReturnsNull();
var maskedPaymentMethod = await _query.Run(organization);
Assert.Null(maskedPaymentMethod);
}
[Fact]
public async Task Run_BankAccount_FromPaymentMethod_ReturnsMaskedBankAccount()
{
@@ -328,14 +355,12 @@ public class GetPaymentMethodQueryTests
Arg.Is<CustomerGetOptions>(options =>
options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer);
var customerGateway = Substitute.For<ICustomerGateway>();
var braintreeCustomer = Substitute.For<Braintree.Customer>();
var payPalAccount = Substitute.For<PayPalAccount>();
payPalAccount.Email.Returns("user@gmail.com");
payPalAccount.IsDefault.Returns(true);
braintreeCustomer.PaymentMethods.Returns([payPalAccount]);
customerGateway.FindAsync("braintree_customer_id").Returns(braintreeCustomer);
_braintreeGateway.Customer.Returns(customerGateway);
_braintreeService.GetCustomer(customer).Returns(braintreeCustomer);
var maskedPaymentMethod = await _query.Run(organization);

View File

@@ -0,0 +1,777 @@
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Test.Billing.Mocks.Plans;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
namespace Bit.Core.Test.Billing.Premium.Commands;
public class PreviewPremiumUpgradeProrationCommandTests
{
private readonly ILogger<PreviewPremiumUpgradeProrationCommand> _logger = Substitute.For<ILogger<PreviewPremiumUpgradeProrationCommand>>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly PreviewPremiumUpgradeProrationCommand _command;
public PreviewPremiumUpgradeProrationCommandTests()
{
_command = new PreviewPremiumUpgradeProrationCommand(
_logger,
_pricingClient,
_stripeAdapter);
}
[Theory, BitAutoData]
public async Task Run_UserWithoutPremium_ReturnsBadRequest(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
// Act
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("User does not have an active Premium subscription.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_UserWithoutGatewaySubscriptionId_ReturnsBadRequest(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = null;
// Act
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("User does not have an active Premium subscription.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_ReturnsProrationAmounts(User user, BillingAddress billingAddress)
{
// Arrange - Setup valid Premium user
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
// Setup Premium plans
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
// Setup current Stripe subscription
var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc);
var currentPeriodEnd = now.AddMonths(6);
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer
{
Id = "cus_123",
Discount = null
},
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new()
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = currentPeriodEnd
}
}
}
};
// Setup target organization plan
var targetPlan = new TeamsPlan(isAnnual: true);
// Setup invoice preview response
var invoice = new Invoice
{
Total = 5000, // $50.00
TotalTaxes = new List<InvoiceTotalTax>
{
new() { Amount = 500 } // $5.00
},
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem>
{
new() { Amount = 5000 } // $50.00 for new plan
}
},
PeriodEnd = now
};
// Configure mocks
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync(
"sub_123",
Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
var proration = result.AsT0;
Assert.Equal(50.00m, proration.NewPlanProratedAmount);
Assert.Equal(0m, proration.Credit);
Assert.Equal(5.00m, proration.Tax);
Assert.Equal(50.00m, proration.Total);
Assert.Equal(6, proration.NewPlanProratedMonths); // 6 months remaining
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_ExtractsProrationCredit(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
// Use fixed time to avoid DateTime.UtcNow differences
var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc);
var currentPeriodEnd = now.AddDays(45); // 1.5 months ~ 2 months rounded
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = currentPeriodEnd }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
// Invoice with negative line item (proration credit)
var invoice = new Invoice
{
Total = 4000, // $40.00
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 400 } }, // $4.00
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem>
{
new() { Amount = -1000 }, // -$10.00 credit from unused Premium
new() { Amount = 5000 } // $50.00 for new plan
}
},
PeriodEnd = now
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
var proration = result.AsT0;
Assert.Equal(50.00m, proration.NewPlanProratedAmount);
Assert.Equal(10.00m, proration.Credit); // Proration credit
Assert.Equal(4.00m, proration.Tax);
Assert.Equal(40.00m, proration.Total);
Assert.Equal(2, proration.NewPlanProratedMonths); // 45 days rounds to 2 months
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_AlwaysUsesOneSeat(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } },
Lines = new StripeList<InvoiceLineItem> { Data = new List<InvoiceLineItem> { new() { Amount = 5000 } } },
PeriodEnd = DateTime.UtcNow
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify that the subscription item quantity is always 1 and has Id
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.SubscriptionDetails.Items.Any(item =>
item.Id == "si_premium" &&
item.Price == targetPlan.PasswordManager.StripeSeatPlanId &&
item.Quantity == 1)));
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_DeletesPremiumSubscriptionItems(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_password_manager", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) },
new() { Id = "si_storage", Price = new Price { Id = "storage-gb-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } },
Lines = new StripeList<InvoiceLineItem> { Data = new List<InvoiceLineItem> { new() { Amount = 5000 } } },
PeriodEnd = DateTime.UtcNow
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify password manager item is modified and storage item is deleted
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
Arg.Is<InvoiceCreatePreviewOptions>(options =>
// Password manager item should be modified to new plan price, not deleted
options.SubscriptionDetails.Items.Any(item =>
item.Id == "si_password_manager" &&
item.Price == targetPlan.PasswordManager.StripeSeatPlanId &&
item.Deleted != true) &&
// Storage item should be deleted
options.SubscriptionDetails.Items.Any(item =>
item.Id == "si_storage" && item.Deleted == true)));
}
[Theory, BitAutoData]
public async Task Run_NonSeatBasedPlan_UsesStripePlanId(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) }
}
}
};
var targetPlan = new FamiliesPlan(); // families is non seat based
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } },
Lines = new StripeList<InvoiceLineItem> { Data = new List<InvoiceLineItem> { new() { Amount = 5000 } } },
PeriodEnd = DateTime.UtcNow
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
await _command.Run(user, PlanType.FamiliesAnnually, billingAddress);
// Assert - Verify non-seat-based plan uses StripePlanId with quantity 1 and modifies existing item
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.SubscriptionDetails.Items.Any(item =>
item.Id == "si_premium" &&
item.Price == targetPlan.PasswordManager.StripePlanId &&
item.Quantity == 1)));
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_CreatesCorrectInvoicePreviewOptions(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } },
Lines = new StripeList<InvoiceLineItem> { Data = new List<InvoiceLineItem> { new() { Amount = 5000 } } },
PeriodEnd = DateTime.UtcNow
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify all invoice preview options are correct
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.AutomaticTax.Enabled == true &&
options.Customer == "cus_123" &&
options.Subscription == "sub_123" &&
options.CustomerDetails.Address.Country == "US" &&
options.CustomerDetails.Address.PostalCode == "12345" &&
options.SubscriptionDetails.ProrationBehavior == "always_invoice"));
}
[Theory, BitAutoData]
public async Task Run_SeatBasedPlan_UsesStripeSeatPlanId(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) }
}
}
};
// Use Teams which is seat-based
var targetPlan = new TeamsPlan(isAnnual: true);
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } },
Lines = new StripeList<InvoiceLineItem> { Data = new List<InvoiceLineItem> { new() { Amount = 5000 } } },
PeriodEnd = DateTime.UtcNow
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert - Verify seat-based plan uses StripeSeatPlanId with quantity 1 and modifies existing item
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
Arg.Is<InvoiceCreatePreviewOptions>(options =>
options.SubscriptionDetails.Items.Any(item =>
item.Id == "si_premium" &&
item.Price == targetPlan.PasswordManager.StripeSeatPlanId &&
item.Quantity == 1)));
}
[Theory]
[InlineData(0, 1)] // Less than 15 days, minimum 1 month
[InlineData(1, 1)] // 1 day = 1 month minimum
[InlineData(14, 1)] // 14 days = 1 month minimum
[InlineData(15, 1)] // 15 days rounds to 1 month
[InlineData(30, 1)] // 30 days = 1 month
[InlineData(44, 1)] // 44 days rounds to 1 month
[InlineData(45, 2)] // 45 days rounds to 2 months
[InlineData(60, 2)] // 60 days = 2 months
[InlineData(90, 3)] // 90 days = 3 months
[InlineData(180, 6)] // 180 days = 6 months
[InlineData(365, 12)] // 365 days rounds to 12 months
public async Task Run_ValidUpgrade_CalculatesNewPlanProratedMonthsCorrectly(int daysRemaining, int expectedMonths)
{
// Arrange
var user = new User
{
Premium = true,
GatewaySubscriptionId = "sub_123",
GatewayCustomerId = "cus_123"
};
var billingAddress = new Core.Billing.Payment.Models.BillingAddress
{
Country = "US",
PostalCode = "12345"
};
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
// Use fixed time to avoid DateTime.UtcNow differences
var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc);
var currentPeriodEnd = now.AddDays(daysRemaining);
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = currentPeriodEnd }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
var invoice = new Invoice
{
Total = 5000,
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 500 } },
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Amount = 5000 } }
},
PeriodEnd = now
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
var proration = result.AsT0;
Assert.Equal(expectedMonths, proration.NewPlanProratedMonths);
}
[Theory, BitAutoData]
public async Task Run_ValidUpgrade_ReturnsNewPlanProratedAmountCorrectly(User user, BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "premium-annually",
Price = 10m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "storage-gb-annually",
Price = 4m,
Provided = 1
}
};
var premiumPlans = new List<PremiumPlan> { premiumPlan };
var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc);
var currentPeriodEnd = now.AddMonths(3);
var currentSubscription = new Subscription
{
Id = "sub_123",
Customer = new Customer { Id = "cus_123", Discount = null },
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = currentPeriodEnd }
}
}
};
var targetPlan = new TeamsPlan(isAnnual: true);
// Invoice showing new plan cost, credit, and net
var invoice = new Invoice
{
Total = 4500, // $45.00 net after $5 credit
TotalTaxes = new List<InvoiceTotalTax> { new() { Amount = 450 } }, // $4.50
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem>
{
new() { Amount = -500 }, // -$5.00 credit
new() { Amount = 5000 } // $50.00 for new plan
}
},
PeriodEnd = now
};
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>())
.Returns(currentSubscription);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
// Act
var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
var proration = result.AsT0;
Assert.Equal(50.00m, proration.NewPlanProratedAmount);
Assert.Equal(5.00m, proration.Credit);
Assert.Equal(4.50m, proration.Tax);
Assert.Equal(45.00m, proration.Total);
}
}

View File

@@ -1,6 +1,7 @@
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Entities;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Constants.StripeConstants;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable;
@@ -15,6 +17,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands;
public class UpdatePremiumStorageCommandTests
{
private readonly IBraintreeService _braintreeService = Substitute.For<IBraintreeService>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
@@ -33,13 +36,14 @@ public class UpdatePremiumStorageCommandTests
_pricingClient.ListPremiumPlans().Returns([premiumPlan]);
_command = new UpdatePremiumStorageCommand(
_braintreeService,
_stripeAdapter,
_userService,
_pricingClient,
Substitute.For<ILogger<UpdatePremiumStorageCommand>>());
}
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null)
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null, bool isPayPal = false)
{
var items = new List<SubscriptionItem>
{
@@ -63,9 +67,17 @@ public class UpdatePremiumStorageCommandTests
});
}
var customer = new Customer
{
Id = "cus_123",
Metadata = isPayPal ? new Dictionary<string, string> { { MetadataKeys.BraintreeCustomerId, "braintree_123" } } : new Dictionary<string, string>()
};
return new Subscription
{
Id = subscriptionId,
CustomerId = "cus_123",
Customer = customer,
Items = new StripeList<SubscriptionItem>
{
Data = items
@@ -97,7 +109,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, -5);
@@ -117,7 +129,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 100);
@@ -154,7 +166,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 0);
@@ -176,7 +188,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 4);
@@ -185,7 +197,7 @@ public class UpdatePremiumStorageCommandTests
Assert.True(result.IsT0);
// Verify subscription was fetched but NOT updated
await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123");
await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>());
await _stripeAdapter.DidNotReceive().UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
await _userService.DidNotReceive().SaveUserAsync(Arg.Any<User>());
}
@@ -200,7 +212,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 9);
@@ -233,7 +245,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 9);
@@ -262,7 +274,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 2);
@@ -291,7 +303,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 0);
@@ -320,7 +332,7 @@ public class UpdatePremiumStorageCommandTests
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4);
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
// Act
var result = await _command.Run(user, 99);
@@ -335,4 +347,200 @@ public class UpdatePremiumStorageCommandTests
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 100));
}
[Theory, BitAutoData]
public async Task Run_IncreaseStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 5;
user.Storage = 2L * 1024 * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 4, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 9);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated with CreateProrations
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
// Verify draft invoice was created
await _stripeAdapter.Received(1).CreateInvoiceAsync(
Arg.Is<InvoiceCreateOptions>(opts =>
opts.Customer == "cus_123" &&
opts.Subscription == "sub_123" &&
opts.AutoAdvance == false &&
opts.CollectionMethod == "charge_automatically"));
// Verify invoice was finalized
await _stripeAdapter.Received(1).FinalizeInvoiceAsync(
"in_draft",
Arg.Is<InvoiceFinalizeOptions>(opts =>
opts.AutoAdvance == false &&
opts.Expand.Contains("customer")));
// Verify Braintree payment was processed
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
// Verify user was saved
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u =>
u.Id == user.Id &&
u.MaxStorageGb == 10));
}
[Theory, BitAutoData]
public async Task Run_AddStorageFromZero_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 1;
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 9);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated with new storage item
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Price == "price_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 10));
}
[Theory, BitAutoData]
public async Task Run_DecreaseStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 10;
user.Storage = 2L * 1024 * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 2);
// Assert
Assert.True(result.IsT0);
// Verify subscription was updated
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 2 &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 3));
}
[Theory, BitAutoData]
public async Task Run_RemoveAllAdditionalStorage_PayPal_Success(User user)
{
// Arrange
user.Premium = true;
user.MaxStorageGb = 10;
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true);
_stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any<SubscriptionGetOptions>()).Returns(subscription);
var draftInvoice = new Invoice { Id = "in_draft" };
_stripeAdapter.CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>()).Returns(draftInvoice);
var finalizedInvoice = new Invoice
{
Id = "in_finalized",
Customer = new Customer { Id = "cus_123" }
};
_stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>()).Returns(finalizedInvoice);
// Act
var result = await _command.Run(user, 0);
// Assert
Assert.True(result.IsT0);
// Verify subscription item was deleted
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Deleted == true &&
opts.ProrationBehavior == "create_prorations"));
// Verify invoice creation and payment flow
await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any<InvoiceCreateOptions>());
await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any<InvoiceFinalizeOptions>());
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), finalizedInvoice);
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u => u.MaxStorageGb == 1));
}
}

View File

@@ -37,7 +37,6 @@ public class UpgradePremiumToOrganizationCommandTests
NameLocalizationKey = "";
DescriptionLocalizationKey = "";
CanBeUsedByBusiness = true;
TrialPeriodDays = null;
HasSelfHost = false;
HasPolicies = false;
HasGroups = false;
@@ -86,10 +85,8 @@ public class UpgradePremiumToOrganizationCommandTests
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
string? stripeStoragePlanId = null)
{
return new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId);
}
string? stripeStoragePlanId = null) =>
new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId);
private static PremiumPlan CreateTestPremiumPlan(
string seatPriceId = "premium-annually",
@@ -151,6 +148,9 @@ public class UpgradePremiumToOrganizationCommandTests
_applicationCacheService);
}
private static Core.Billing.Payment.Models.BillingAddress CreateTestBillingAddress() =>
new() { Country = "US", PostalCode = "12345" };
[Theory, BitAutoData]
public async Task Run_UserNotPremium_ReturnsBadRequest(User user)
{
@@ -158,7 +158,7 @@ public class UpgradePremiumToOrganizationCommandTests
user.Premium = false;
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
@@ -174,7 +174,7 @@ public class UpgradePremiumToOrganizationCommandTests
user.GatewaySubscriptionId = null;
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
@@ -190,7 +190,7 @@ public class UpgradePremiumToOrganizationCommandTests
user.GatewaySubscriptionId = "";
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
@@ -245,7 +245,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -253,9 +253,8 @@ public class UpgradePremiumToOrganizationCommandTests
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 2 && // 1 deleted + 1 seat (no storage)
opts.Items.Any(i => i.Deleted == true) &&
opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1)));
opts.Items.Count == 1 && // Only 1 item: modify existing password manager item (no storage to delete)
opts.Items.Any(i => i.Id == "si_premium" && i.Price == "teams-seat-annually" && i.Quantity == 1 && i.Deleted != true)));
await _organizationRepository.Received(1).CreateAsync(Arg.Is<Organization>(o =>
o.Name == "My Organization" &&
@@ -320,7 +319,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually);
var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -328,9 +327,8 @@ public class UpgradePremiumToOrganizationCommandTests
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 2 && // 1 deleted + 1 plan
opts.Items.Any(i => i.Deleted == true) &&
opts.Items.Any(i => i.Price == "families-plan-annually" && i.Quantity == 1)));
opts.Items.Count == 1 && // Only 1 item: modify existing password manager item (no storage to delete)
opts.Items.Any(i => i.Id == "si_premium" && i.Price == "families-plan-annually" && i.Quantity == 1 && i.Deleted != true)));
await _organizationRepository.Received(1).CreateAsync(Arg.Is<Organization>(o =>
o.Name == "My Families Org"));
@@ -383,7 +381,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -392,11 +390,6 @@ public class UpgradePremiumToOrganizationCommandTests
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousPremiumPriceId) &&
opts.Metadata[StripeConstants.MetadataKeys.PreviousPremiumPriceId] == "premium-annually" &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousPeriodEndDate) &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousAdditionalStorage) &&
opts.Metadata[StripeConstants.MetadataKeys.PreviousAdditionalStorage] == "0" &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.UserId) &&
opts.Metadata[StripeConstants.MetadataKeys.UserId] == string.Empty)); // Removes userId to unlink from User
}
@@ -453,19 +446,18 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
// Verify that BOTH legacy items (password manager + storage) are deleted by ID
// Verify that legacy password manager item is modified and legacy storage is deleted
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 3 && // 2 deleted (legacy PM + legacy storage) + 1 new seat
opts.Items.Count(i => i.Deleted == true && i.Id == "si_premium_legacy") == 1 && // Legacy PM deleted
opts.Items.Count(i => i.Deleted == true && i.Id == "si_storage_legacy") == 1 && // Legacy storage deleted
opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1)));
opts.Items.Count == 2 && // 1 modified (legacy PM to new price) + 1 deleted (legacy storage)
opts.Items.Count(i => i.Id == "si_premium_legacy" && i.Price == "teams-seat-annually" && i.Quantity == 1 && i.Deleted != true) == 1 && // Legacy PM modified
opts.Items.Count(i => i.Deleted == true && i.Id == "si_storage_legacy") == 1)); // Legacy storage deleted
}
[Theory, BitAutoData]
@@ -520,20 +512,19 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
// Verify that ONLY the premium password manager item is deleted (not other products)
// Note: We delete the specific premium item by ID, so other products are untouched
// Verify that ONLY the premium password manager item is modified (not other products)
// Note: We modify the specific premium item by ID, so other products are untouched
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 2 && // 1 deleted (premium password manager) + 1 new seat
opts.Items.Count(i => i.Deleted == true && i.Id == "si_premium") == 1 && // Premium item deleted by ID
opts.Items.Count(i => i.Id == "si_other_product") == 0 && // Other product NOT in update (untouched)
opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1)));
opts.Items.Count == 1 && // Only modify premium password manager item
opts.Items.Count(i => i.Id == "si_premium" && i.Price == "teams-seat-annually" && i.Quantity == 1 && i.Deleted != true) == 1 && // Premium item modified
opts.Items.Count(i => i.Id == "si_other_product") == 0)); // Other product NOT in update (untouched)
}
[Theory, BitAutoData]
@@ -589,7 +580,7 @@ public class UpgradePremiumToOrganizationCommandTests
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
@@ -598,10 +589,8 @@ public class UpgradePremiumToOrganizationCommandTests
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousAdditionalStorage) &&
opts.Metadata[StripeConstants.MetadataKeys.PreviousAdditionalStorage] == "5" &&
opts.Items.Count == 3 && // 2 deleted (premium + storage) + 1 new seat
opts.Items.Count(i => i.Deleted == true) == 2));
opts.Items.Count == 2 && // 1 modified (premium to new price) + 1 deleted (storage)
opts.Items.Count(i => i.Deleted == true) == 1));
}
[Theory, BitAutoData]
@@ -636,11 +625,385 @@ public class UpgradePremiumToOrganizationCommandTests
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Premium subscription item not found.", badRequest.Response);
Assert.Equal("Premium subscription password manager item not found.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_UpdatesCustomerBillingAddress(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
var billingAddress = new Core.Billing.Payment.Models.BillingAddress { Country = "US", PostalCode = "12345" };
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, billingAddress);
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateCustomerAsync(
"cus_123",
Arg.Is<CustomerUpdateOptions>(opts =>
opts.Address.Country == "US" &&
opts.Address.PostalCode == "12345"));
}
[Theory, BitAutoData]
public async Task Run_EnablesAutomaticTaxOnSubscription(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.AutomaticTax != null &&
opts.AutomaticTax.Enabled == true));
}
[Theory, BitAutoData]
public async Task Run_UsesAlwaysInvoiceProrationBehavior(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.ProrationBehavior == "always_invoice"));
}
[Theory, BitAutoData]
public async Task Run_ModifiesExistingSubscriptionItem_NotDeleteAndRecreate(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
// Verify that the subscription item was modified, not deleted
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
// Should have an item with the original ID being modified
opts.Items.Any(item =>
item.Id == "si_premium" &&
item.Price == "teams-seat-annually" &&
item.Quantity == 1 &&
item.Deleted != true)));
}
[Theory, BitAutoData]
public async Task Run_CreatesOrganizationWithCorrectSettings(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
await _organizationRepository.Received(1).CreateAsync(
Arg.Is<Organization>(org =>
org.Name == "My Organization" &&
org.BillingEmail == user.Email &&
org.PlanType == PlanType.TeamsAnnually &&
org.Seats == 1 &&
org.Gateway == GatewayType.Stripe &&
org.GatewayCustomerId == "cus_123" &&
org.GatewaySubscriptionId == "sub_123" &&
org.Enabled == true));
}
[Theory, BitAutoData]
public async Task Run_CreatesOrganizationApiKeyWithCorrectType(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
await _organizationApiKeyRepository.Received(1).CreateAsync(
Arg.Is<OrganizationApiKey>(apiKey =>
apiKey.Type == OrganizationApiKeyType.Default &&
!string.IsNullOrEmpty(apiKey.ApiKey)));
}
[Theory, BitAutoData]
public async Task Run_CreatesOrganizationUserAsOwnerWithAllPermissions(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" }
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>()).Returns(mockSubscription);
_stripeAdapter.UpdateCustomerAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(Task.FromResult(new Customer()));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress());
// Assert
Assert.True(result.IsT0);
await _organizationUserRepository.Received(1).CreateAsync(
Arg.Is<OrganizationUser>(orgUser =>
orgUser.UserId == user.Id &&
orgUser.Type == OrganizationUserType.Owner &&
orgUser.Status == OrganizationUserStatusType.Confirmed));
}
}

View File

@@ -0,0 +1,109 @@
using System.Text.Json;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Subscriptions.Entities;
using Xunit;
namespace Bit.Core.Test.Billing.Subscriptions.Entities;
public class SubscriptionDiscountTests
{
[Fact]
public void StripeProductIds_CanSerializeToJson()
{
// Arrange
var discount = new SubscriptionDiscount
{
StripeCouponId = "test-coupon",
StripeProductIds = new List<string> { "prod_123", "prod_456" },
Duration = "once",
StartDate = DateTime.UtcNow,
EndDate = DateTime.UtcNow.AddDays(30),
AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions
};
// Act
var json = JsonSerializer.Serialize(discount.StripeProductIds);
// Assert
Assert.Equal("[\"prod_123\",\"prod_456\"]", json);
}
[Fact]
public void StripeProductIds_CanDeserializeFromJson()
{
// Arrange
var json = "[\"prod_123\",\"prod_456\"]";
// Act
var result = JsonSerializer.Deserialize<List<string>>(json);
// Assert
Assert.NotNull(result);
Assert.Equal(2, result.Count);
Assert.Contains("prod_123", result);
Assert.Contains("prod_456", result);
}
[Fact]
public void StripeProductIds_HandlesNull()
{
// Arrange
var discount = new SubscriptionDiscount
{
StripeCouponId = "test-coupon",
StripeProductIds = null,
Duration = "once",
StartDate = DateTime.UtcNow,
EndDate = DateTime.UtcNow.AddDays(30),
AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions
};
// Act
var json = JsonSerializer.Serialize(discount.StripeProductIds);
// Assert
Assert.Equal("null", json);
}
[Fact]
public void StripeProductIds_HandlesEmptyCollection()
{
// Arrange
var discount = new SubscriptionDiscount
{
StripeCouponId = "test-coupon",
StripeProductIds = new List<string>(),
Duration = "once",
StartDate = DateTime.UtcNow,
EndDate = DateTime.UtcNow.AddDays(30),
AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions
};
// Act
var json = JsonSerializer.Serialize(discount.StripeProductIds);
// Assert
Assert.Equal("[]", json);
}
[Fact]
public void Validate_RejectsEndDateBeforeStartDate()
{
// Arrange
var discount = new SubscriptionDiscount
{
StripeCouponId = "test-coupon",
Duration = "once",
StartDate = DateTime.UtcNow.AddDays(30),
EndDate = DateTime.UtcNow, // EndDate before StartDate
AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions
};
// Act
var validationResults = discount.Validate(new System.ComponentModel.DataAnnotations.ValidationContext(discount)).ToList();
// Assert
Assert.Single(validationResults);
Assert.Contains("EndDate", validationResults[0].MemberNames);
}
}

View File

@@ -461,6 +461,77 @@ public class GetBitwardenSubscriptionQueryTests
Assert.Equal(PlanCadenceType.Annually, result.Cart.Cadence);
}
[Fact]
public async Task Run_UserOnLegacyPricing_ReturnsCostFromPricingService()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active, legacyPricing: true);
var premiumPlans = CreatePremiumPlans();
var availablePlan = premiumPlans.First(p => p.Available);
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
var previewInvoice = CreateInvoicePreview(totalTax: 150);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(previewInvoice);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(availablePlan.Seat.Price, result.Cart.PasswordManager.Seats.Cost);
Assert.Equal(1.50m, result.Cart.EstimatedTax);
}
[Fact]
public async Task Run_UserOnLegacyPricing_CallsPreviewInvoiceWithRebuiltSubscription()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active, legacyPricing: true);
var premiumPlans = CreatePremiumPlans();
var availablePlan = premiumPlans.First(p => p.Available);
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
var previewInvoice = CreateInvoicePreview();
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(previewInvoice);
await _query.Run(user);
await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(
Arg.Is<InvoiceCreatePreviewOptions>(opts =>
opts.Subscription == null &&
opts.AutomaticTax != null &&
opts.AutomaticTax.Enabled == true &&
opts.SubscriptionDetails != null &&
opts.SubscriptionDetails.Items.Any(i =>
i.Price == availablePlan.Seat.StripePriceId &&
i.Quantity == 1)));
}
[Fact]
public async Task Run_UserOnCurrentPricing_ReturnsCostFromSubscriptionItem()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active, legacyPricing: false);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(19.80m, result.Cart.PasswordManager.Seats.Cost);
}
#region Helper Methods
private static User CreateUser()
@@ -477,11 +548,14 @@ public class GetBitwardenSubscriptionQueryTests
private static Subscription CreateSubscription(
string status,
bool includeStorage = false,
bool legacyPricing = false,
DateTime? cancelAt = null,
DateTime? canceledAt = null,
string collectionMethod = "charge_automatically")
{
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var seatPriceId = legacyPricing ? "price_legacy_premium_seat" : "price_premium_seat";
var seatUnitAmount = legacyPricing ? 1000 : 1980;
var items = new List<SubscriptionItem>
{
new()
@@ -489,8 +563,8 @@ public class GetBitwardenSubscriptionQueryTests
Id = "si_premium_seat",
Price = new Price
{
Id = "price_premium_seat",
UnitAmountDecimal = 1000,
Id = seatPriceId,
UnitAmountDecimal = seatUnitAmount,
Product = new Product { Id = "prod_premium_seat" }
},
Quantity = 1,
@@ -521,6 +595,7 @@ public class GetBitwardenSubscriptionQueryTests
Id = "sub_test123",
Status = status,
Created = DateTime.UtcNow.AddMonths(-1),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Customer = new Customer
{
Id = "cus_test123",
@@ -548,6 +623,24 @@ public class GetBitwardenSubscriptionQueryTests
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "price_premium_seat",
Price = 19.80m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "price_storage",
Price = 4.0m,
Provided = 1
}
},
new()
{
Name = "Premium",
Available = false,
LegacyYear = 2024,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "price_legacy_premium_seat",
Price = 10.0m,
Provided = 1
},

View File

@@ -1,4 +1,7 @@
using Bit.Core.Billing.Enums;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Exceptions;
@@ -9,6 +12,7 @@ using Bit.Core.OrganizationFeatures.OrganizationSubscriptions;
using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Services;
using Bit.Core.Test.AdminConsole.AutoFixture;
using Bit.Core.Test.AutoFixture.OrganizationFixtures;
using Bit.Core.Test.Billing.Mocks;
using Bit.Test.Common.AutoFixture;
@@ -72,8 +76,12 @@ public class UpgradeOrganizationPlanCommandTests
[Theory]
[FreeOrganizationUpgradeCustomize, BitAutoData]
public async Task UpgradePlan_Passes(Organization organization, OrganizationUpgrade upgrade,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType));
upgrade.AdditionalSmSeats = 10;
@@ -100,6 +108,7 @@ public class UpgradeOrganizationPlanCommandTests
PlanType planType,
Organization organization,
OrganizationUpgrade organizationUpgrade,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
@@ -116,6 +125,9 @@ public class UpgradeOrganizationPlanCommandTests
organizationUpgrade.Plan = planType;
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(organizationUpgrade.Plan).Returns(MockPlans.Get(organizationUpgrade.Plan));
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts
{
@@ -141,15 +153,20 @@ public class UpgradeOrganizationPlanCommandTests
[BitAutoData(PlanType.TeamsAnnually)]
[BitAutoData(PlanType.TeamsStarter)]
public async Task UpgradePlan_SM_Passes(PlanType planType, Organization organization, OrganizationUpgrade upgrade,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType));
upgrade.Plan = planType;
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan));
var plan = MockPlans.Get(upgrade.Plan);
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType));
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
upgrade.AdditionalSeats = 15;
@@ -180,6 +197,7 @@ public class UpgradeOrganizationPlanCommandTests
[BitAutoData(PlanType.TeamsAnnually)]
[BitAutoData(PlanType.TeamsStarter)]
public async Task UpgradePlan_SM_NotEnoughSmSeats_Throws(PlanType planType, Organization organization, OrganizationUpgrade upgrade,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
upgrade.Plan = planType;
@@ -191,6 +209,10 @@ public class UpgradeOrganizationPlanCommandTests
organization.SmSeats = 2;
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType));
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts
@@ -214,7 +236,9 @@ public class UpgradeOrganizationPlanCommandTests
[BitAutoData(PlanType.TeamsAnnually, 51)]
[BitAutoData(PlanType.TeamsStarter, 51)]
public async Task UpgradePlan_SM_NotEnoughServiceAccounts_Throws(PlanType planType, int currentServiceAccounts,
Organization organization, OrganizationUpgrade upgrade, SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
Organization organization, OrganizationUpgrade upgrade,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
upgrade.Plan = planType;
upgrade.AdditionalSeats = 15;
@@ -226,6 +250,10 @@ public class UpgradeOrganizationPlanCommandTests
organization.SmServiceAccounts = currentServiceAccounts;
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType));
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts
@@ -251,6 +279,7 @@ public class UpgradeOrganizationPlanCommandTests
OrganizationUpgrade upgrade,
string newPublicKey,
string newPrivateKey,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
organization.PublicKey = null;
@@ -262,6 +291,9 @@ public class UpgradeOrganizationPlanCommandTests
publicKey: newPublicKey);
upgrade.AdditionalSeats = 10;
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
@@ -291,6 +323,7 @@ public class UpgradeOrganizationPlanCommandTests
public async Task UpgradePlan_WhenOrganizationAlreadyHasPublicAndPrivateKeys_DoesNotOverwriteWithNull(
Organization organization,
OrganizationUpgrade upgrade,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
// Arrange
@@ -304,6 +337,9 @@ public class UpgradeOrganizationPlanCommandTests
upgrade.Keys = null;
upgrade.AdditionalSeats = 10;
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
@@ -333,6 +369,7 @@ public class UpgradeOrganizationPlanCommandTests
public async Task UpgradePlan_WhenOrganizationAlreadyHasPublicAndPrivateKeys_DoesNotBackfillWithNewKeys(
Organization organization,
OrganizationUpgrade upgrade,
[Policy(PolicyType.ResetPassword, false)] PolicyStatus policy,
SutProvider<UpgradeOrganizationPlanCommand> sutProvider)
{
// Arrange
@@ -343,6 +380,9 @@ public class UpgradeOrganizationPlanCommandTests
organization.PublicKey = existingPublicKey;
organization.PrivateKey = existingPrivateKey;
sutProvider.GetDependency<IPolicyQuery>()
.RunAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>())
.Returns(policy);
upgrade.Plan = PlanType.TeamsAnnually;
upgrade.Keys = new PublicKeyEncryptionKeyPairData(

View File

@@ -0,0 +1,55 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
using Bit.Core.AdminConsole.Repositories;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Xunit;
namespace Bit.Core.Test.OrganizationFeatures.Policies;
[SutProviderCustomize]
public class PolicyQueryTests
{
[Theory, BitAutoData]
public async Task RunAsync_WithExistingPolicy_ReturnsPolicy(SutProvider<PolicyQuery> sutProvider,
Policy policy)
{
// Arrange
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policy.OrganizationId, policy.Type)
.Returns(policy);
// Act
var policyData = await sutProvider.Sut.RunAsync(policy.OrganizationId, policy.Type);
// Assert
Assert.Equal(policy.Data, policyData.Data);
Assert.Equal(policy.Type, policyData.Type);
Assert.Equal(policy.Enabled, policyData.Enabled);
Assert.Equal(policy.OrganizationId, policyData.OrganizationId);
}
[Theory, BitAutoData]
public async Task RunAsync_WithNonExistentPolicy_ReturnsDefaultDisabledPolicy(
SutProvider<PolicyQuery> sutProvider,
Guid organizationId,
PolicyType policyType)
{
// Arrange
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organizationId, policyType)
.ReturnsNull();
// Act
var policyData = await sutProvider.Sut.RunAsync(organizationId, policyType);
// Assert
Assert.Equal(organizationId, policyData.OrganizationId);
Assert.Equal(policyType, policyData.Type);
Assert.False(policyData.Enabled);
Assert.Null(policyData.Data);
}
}

View File

@@ -0,0 +1,195 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Mail.Enqueuing;
using Bit.Core.Services.Mail;
using Bit.Core.Settings;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Platform.Mail;
public class DomainClaimedEmailRenderTest
{
[Fact]
public async Task RenderDomainClaimedEmail_ToVerifyTemplate()
{
var globalSettings = new GlobalSettings
{
Mail = new GlobalSettings.MailSettings
{
ReplyToEmail = "no-reply@bitwarden.com",
Smtp = new GlobalSettings.MailSettings.SmtpSettings
{
Host = "localhost",
Port = 1025,
StartTls = false,
Ssl = false
}
},
SiteName = "Bitwarden"
};
var mailDeliveryService = Substitute.For<IMailDeliveryService>();
var mailEnqueuingService = new BlockingMailEnqueuingService();
var distributedCache = Substitute.For<IDistributedCache>();
var logger = Substitute.For<ILogger<HandlebarsMailService>>();
var mailService = new HandlebarsMailService(
globalSettings,
mailDeliveryService,
mailEnqueuingService,
distributedCache,
logger
);
var organization = new Organization
{
Id = Guid.NewGuid(),
Name = "Acme Corporation"
};
var testEmails = new List<string>
{
"alice@acme.com",
"bob@acme.com",
"charlie@acme.com"
};
var emailList = new ClaimedUserDomainClaimedEmails(
testEmails,
organization,
"acme.com"
);
await mailService.SendClaimedDomainUserEmailAsync(emailList);
await mailDeliveryService.Received(3).SendEmailAsync(Arg.Any<Bit.Core.Models.Mail.MailMessage>());
var calls = mailDeliveryService.ReceivedCalls()
.Where(call => call.GetMethodInfo().Name == "SendEmailAsync")
.ToList();
Assert.Equal(3, calls.Count);
foreach (var call in calls)
{
var mailMessage = call.GetArguments()[0] as Bit.Core.Models.Mail.MailMessage;
Assert.NotNull(mailMessage);
var recipient = mailMessage.ToEmails.First();
Assert.Contains("@acme.com", mailMessage.HtmlContent);
Assert.Contains(recipient, mailMessage.HtmlContent);
Assert.DoesNotContain("[at]", mailMessage.HtmlContent);
Assert.DoesNotContain("[dot]", mailMessage.HtmlContent);
}
}
[Fact(Skip = "For local development - requires MailCatcher at localhost:10250")]
public async Task SendDomainClaimedEmail_ToMailCatcher()
{
var globalSettings = new GlobalSettings
{
Mail = new GlobalSettings.MailSettings
{
ReplyToEmail = "no-reply@bitwarden.com",
Smtp = new GlobalSettings.MailSettings.SmtpSettings
{
Host = "localhost",
Port = 10250,
StartTls = false,
Ssl = false
}
},
SiteName = "Bitwarden"
};
var mailDeliveryLogger = Substitute.For<ILogger<MailKitSmtpMailDeliveryService>>();
var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, mailDeliveryLogger);
var mailEnqueuingService = new BlockingMailEnqueuingService();
var distributedCache = Substitute.For<IDistributedCache>();
var logger = Substitute.For<ILogger<HandlebarsMailService>>();
var mailService = new HandlebarsMailService(
globalSettings,
mailDeliveryService,
mailEnqueuingService,
distributedCache,
logger
);
var organization = new Organization
{
Id = Guid.NewGuid(),
Name = "Acme Corporation"
};
var testEmails = new List<string>
{
"alice@acme.com",
"bob@acme.com"
};
var emailList = new ClaimedUserDomainClaimedEmails(
testEmails,
organization,
"acme.com"
);
await mailService.SendClaimedDomainUserEmailAsync(emailList);
}
[Fact(Skip = "This test sends actual emails and is for manual template verification only")]
public async Task RenderDomainClaimedEmail_WithSpecialCharacters()
{
var globalSettings = new GlobalSettings
{
Mail = new GlobalSettings.MailSettings
{
Smtp = new GlobalSettings.MailSettings.SmtpSettings
{
Host = "localhost",
Port = 1025,
StartTls = false,
Ssl = false
}
},
SiteName = "Bitwarden"
};
var mailDeliveryService = Substitute.For<IMailDeliveryService>();
var mailEnqueuingService = new BlockingMailEnqueuingService();
var distributedCache = Substitute.For<IDistributedCache>();
var logger = Substitute.For<ILogger<HandlebarsMailService>>();
var mailService = new HandlebarsMailService(
globalSettings,
mailDeliveryService,
mailEnqueuingService,
distributedCache,
logger
);
var organization = new Organization
{
Id = Guid.NewGuid(),
Name = "Test Corp & Co."
};
var testEmails = new List<string>
{
"test.user+tag@example.com"
};
var emailList = new ClaimedUserDomainClaimedEmails(
testEmails,
organization,
"example.com"
);
await mailService.SendClaimedDomainUserEmailAsync(emailList);
}
}

View File

@@ -254,21 +254,6 @@ public class HandlebarsMailServiceTests
}
}
[Fact]
public async Task SendSendEmailOtpEmailAsync_SendsEmail()
{
// Arrange
var email = "test@example.com";
var token = "aToken";
var subject = string.Format("Your Bitwarden Send verification code is {0}", token);
// Act
await _sut.SendSendEmailOtpEmailAsync(email, token, subject);
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Any<MailMessage>());
}
[Fact]
public async Task SendIndividualUserWelcomeEmailAsync_SendsCorrectEmail()
{

View File

@@ -0,0 +1,118 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Services;
using Bit.Core.Services;
using Bit.Core.Settings;
using Braintree;
using Braintree.Exceptions;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
using BraintreeService = Bit.Core.Services.Implementations.BraintreeService;
using Customer = Stripe.Customer;
namespace Bit.Core.Test.Services.Implementations;
public class BraintreeServiceTests
{
private readonly ICustomerGateway _customerGateway;
private readonly BraintreeService _sut;
public BraintreeServiceTests()
{
var braintreeGateway = Substitute.For<IBraintreeGateway>();
_customerGateway = Substitute.For<ICustomerGateway>();
braintreeGateway.Customer.Returns(_customerGateway);
var globalSettings = Substitute.For<IGlobalSettings>();
var logger = Substitute.For<ILogger<BraintreeService>>();
var mailService = Substitute.For<IMailService>();
var stripeAdapter = Substitute.For<IStripeAdapter>();
_sut = new BraintreeService(
braintreeGateway,
globalSettings,
logger,
mailService,
stripeAdapter);
}
#region GetCustomer
[Fact]
public async Task GetCustomer_NoBraintreeCustomerIdInMetadata_ReturnsNull()
{
// Arrange
var stripeCustomer = new Customer
{
Id = "cus_123",
Metadata = new Dictionary<string, string>()
};
// Act
var result = await _sut.GetCustomer(stripeCustomer);
// Assert
Assert.Null(result);
await _customerGateway.DidNotReceiveWithAnyArgs().FindAsync(Arg.Any<string>());
}
[Fact]
public async Task GetCustomer_BraintreeCustomerFound_ReturnsCustomer()
{
// Arrange
const string braintreeCustomerId = "bt_customer_123";
var stripeCustomer = new Customer
{
Id = "cus_123",
Metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomerId
}
};
var braintreeCustomer = Substitute.For<Braintree.Customer>();
_customerGateway
.FindAsync(braintreeCustomerId)
.Returns(braintreeCustomer);
// Act
var result = await _sut.GetCustomer(stripeCustomer);
// Assert
Assert.NotNull(result);
Assert.Same(braintreeCustomer, result);
await _customerGateway.Received(1).FindAsync(braintreeCustomerId);
}
[Fact]
public async Task GetCustomer_BraintreeCustomerNotFound_LogsWarningAndReturnsNull()
{
// Arrange
const string braintreeCustomerId = "bt_non_existent_customer";
var stripeCustomer = new Customer
{
Id = "cus_123",
Metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomerId
}
};
_customerGateway
.FindAsync(braintreeCustomerId)
.Returns<Braintree.Customer>(_ => throw new NotFoundException());
// Act
var result = await _sut.GetCustomer(stripeCustomer);
// Assert
Assert.Null(result);
await _customerGateway.Received(1).FindAsync(braintreeCustomerId);
}
#endregion
}

View File

@@ -135,6 +135,43 @@ public class ImportCiphersAsyncCommandTests
Assert.Equal("You cannot import items into your personal vault because you are a member of an organization which forbids it.", exception.Message);
}
[Theory, BitAutoData]
public async Task ImportIntoIndividualVaultAsync_FavoriteCiphers_PersistsFavoriteInfo(
Guid importingUserId,
List<CipherDetails> ciphers,
SutProvider<ImportCiphersCommand> sutProvider
)
{
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyRequirements)
.Returns(true);
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(importingUserId)
.Returns(new OrganizationDataOwnershipPolicyRequirement(
OrganizationDataOwnershipState.Disabled,
[]));
sutProvider.GetDependency<IFolderRepository>()
.GetManyByUserIdAsync(importingUserId)
.Returns(new List<Folder>());
var folders = new List<Folder>();
var folderRelationships = new List<KeyValuePair<int, int>>();
ciphers.ForEach(c =>
{
c.UserId = importingUserId;
c.Favorite = true;
});
await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId);
await sutProvider.GetDependency<ICipherRepository>()
.Received(1)
.CreateAsync(importingUserId, Arg.Is<IEnumerable<Cipher>>(ciphers => ciphers.All(c => c.Favorites == $"{{\"{importingUserId.ToString().ToUpperInvariant()}\":true}}")), Arg.Any<List<Folder>>());
}
[Theory, BitAutoData]
public async Task ImportIntoOrganizationalVaultAsync_Success(
Organization organization,
@@ -289,4 +326,101 @@ public class ImportCiphersAsyncCommandTests
await sutProvider.GetDependency<IPushNotificationService>().Received(1).PushSyncVaultAsync(importingUserId);
}
[Theory, BitAutoData]
public async Task ImportIntoIndividualVaultAsync_WithArchivedCiphers_PreservesArchiveStatus(
Guid importingUserId,
List<CipherDetails> ciphers,
SutProvider<ImportCiphersCommand> sutProvider)
{
var archivedDate = DateTime.UtcNow.AddDays(-1);
ciphers[0].UserId = importingUserId;
ciphers[0].ArchivedDate = archivedDate;
sutProvider.GetDependency<IPolicyService>()
.AnyPoliciesApplicableToUserAsync(importingUserId, PolicyType.OrganizationDataOwnership)
.Returns(false);
sutProvider.GetDependency<IFolderRepository>()
.GetManyByUserIdAsync(importingUserId)
.Returns(new List<Folder>());
var folders = new List<Folder>();
var folderRelationships = new List<KeyValuePair<int, int>>();
await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId);
await sutProvider.GetDependency<ICipherRepository>()
.Received(1)
.CreateAsync(importingUserId,
Arg.Is<List<CipherDetails>>(c =>
c[0].Archives != null &&
c[0].Archives.Contains(importingUserId.ToString().ToUpperInvariant()) &&
c[0].Archives.Contains(archivedDate.ToString("yyyy-MM-ddTHH:mm:ss.fffffffZ"))),
Arg.Any<List<Folder>>());
}
/*
* Archive functionality is a per-user function. When importing archived ciphers into an organization vault,
* the Archives field should be set for the importing user only. This allows the importing user to see
* items as archived, while other organization members will not see them as archived.
*/
[Theory, BitAutoData]
public async Task ImportIntoOrganizationalVaultAsync_WithArchivedCiphers_SetsArchivesForImportingUserOnly(
Organization organization,
Guid importingUserId,
OrganizationUser importingOrganizationUser,
List<Collection> collections,
List<CipherDetails> ciphers,
SutProvider<ImportCiphersCommand> sutProvider)
{
var archivedDate = DateTime.UtcNow.AddDays(-1);
organization.MaxCollections = null;
importingOrganizationUser.OrganizationId = organization.Id;
foreach (var collection in collections)
{
collection.OrganizationId = organization.Id;
}
foreach (var cipher in ciphers)
{
cipher.OrganizationId = organization.Id;
}
ciphers[0].ArchivedDate = archivedDate;
ciphers[0].Archives = null;
KeyValuePair<int, int>[] collectionRelationships = {
new(0, 0),
new(1, 1),
new(2, 2)
};
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(organization.Id, importingUserId)
.Returns(importingOrganizationUser);
sutProvider.GetDependency<ICollectionRepository>()
.GetManyByOrganizationIdAsync(organization.Id)
.Returns(new List<Collection>());
await sutProvider.Sut.ImportIntoOrganizationalVaultAsync(collections, ciphers, collectionRelationships, importingUserId);
await sutProvider.GetDependency<ICipherRepository>()
.Received(1)
.CreateAsync(
Arg.Is<List<CipherDetails>>(c =>
c[0].ArchivedDate == archivedDate &&
c[0].Archives != null &&
c[0].Archives.Contains(importingUserId.ToString().ToUpperInvariant()) &&
c[0].Archives.Contains(archivedDate.ToString("yyyy-MM-ddTHH:mm:ss.fffffffZ"))),
Arg.Any<IEnumerable<Collection>>(),
Arg.Any<IEnumerable<CollectionCipher>>(),
Arg.Any<IEnumerable<CollectionUser>>());
}
}

View File

@@ -11,6 +11,7 @@ using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
@@ -28,7 +29,6 @@ public class NonAnonymousSendCommandTests
private readonly ISendRepository _sendRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPushNotificationService _pushNotificationService;
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendValidationService _sendValidationService;
private readonly IFeatureService _featureService;
private readonly ICurrentContext _currentContext;
@@ -42,7 +42,6 @@ public class NonAnonymousSendCommandTests
_sendRepository = Substitute.For<ISendRepository>();
_sendFileStorageService = Substitute.For<ISendFileStorageService>();
_pushNotificationService = Substitute.For<IPushNotificationService>();
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
_featureService = Substitute.For<IFeatureService>();
_sendValidationService = Substitute.For<ISendValidationService>();
_currentContext = Substitute.For<ICurrentContext>();
@@ -53,7 +52,6 @@ public class NonAnonymousSendCommandTests
_sendRepository,
_sendFileStorageService,
_pushNotificationService,
_sendAuthorizationService,
_sendValidationService,
_sendCoreHelperService,
_logger
@@ -1093,4 +1091,329 @@ public class NonAnonymousSendCommandTests
Assert.Equal("File received does not match expected file length.", exception.Message);
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_WithTextSend_ThrowsBadRequest()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.Text,
UserId = Guid.NewGuid()
};
var fileId = "somefile123";
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
_nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId));
Assert.Equal("Can only get a download URL for a file type of Send", exception.Message);
// Verify no storage service methods were called
await _sendFileStorageService.DidNotReceive()
.GetSendFileDownloadUrlAsync(Arg.Any<Send>(), Arg.Any<string>());
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_WithDisabledSend_ReturnsDenied()
{
// Arrange
var fileId = "file123";
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
UserId = Guid.NewGuid(),
Disabled = true,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
AccessCount = 0,
MaxAccessCount = null
};
// Act
var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId);
// Assert
Assert.Null(url);
Assert.Equal(SendAccessResult.Denied, result);
// Verify no repository updates occurred
await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any<Send>());
await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any<Send>());
await _sendFileStorageService.DidNotReceive()
.GetSendFileDownloadUrlAsync(Arg.Any<Send>(), Arg.Any<string>());
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_WithMaxAccessCountReached_ReturnsDenied()
{
// Arrange
var fileId = "file123";
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
UserId = Guid.NewGuid(),
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
AccessCount = 5,
MaxAccessCount = 5
};
// Act
var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId);
// Assert
Assert.Null(url);
Assert.Equal(SendAccessResult.Denied, result);
// Verify no repository updates occurred
await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any<Send>());
await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any<Send>());
await _sendFileStorageService.DidNotReceive()
.GetSendFileDownloadUrlAsync(Arg.Any<Send>(), Arg.Any<string>());
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_WithExpiredSend_ReturnsDenied()
{
// Arrange
var fileId = "file123";
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
UserId = Guid.NewGuid(),
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = DateTime.UtcNow.AddDays(-1), // Expired yesterday
AccessCount = 0,
MaxAccessCount = null
};
// Act
var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId);
// Assert
Assert.Null(url);
Assert.Equal(SendAccessResult.Denied, result);
// Verify no repository updates occurred
await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any<Send>());
await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any<Send>());
await _sendFileStorageService.DidNotReceive()
.GetSendFileDownloadUrlAsync(Arg.Any<Send>(), Arg.Any<string>());
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_WithDeletionDatePassed_ReturnsDenied()
{
// Arrange
var fileId = "file123";
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
UserId = Guid.NewGuid(),
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(-1), // Deletion date has passed
ExpirationDate = null,
AccessCount = 0,
MaxAccessCount = null
};
// Act
var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId);
// Assert
Assert.Null(url);
Assert.Equal(SendAccessResult.Denied, result);
// Verify no repository updates occurred
await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any<Send>());
await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any<Send>());
await _sendFileStorageService.DidNotReceive()
.GetSendFileDownloadUrlAsync(Arg.Any<Send>(), Arg.Any<string>());
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_WithValidSend_ReturnsUrlAndIncrementsAccessCount()
{
// Arrange
var fileId = "file123";
var expectedUrl = "https://download.example.com/file123";
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
UserId = Guid.NewGuid(),
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
AccessCount = 3,
MaxAccessCount = 10
};
_sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId).Returns(expectedUrl);
// Act
var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId);
// Assert
Assert.Equal(expectedUrl, url);
Assert.Equal(SendAccessResult.Granted, result);
// Verify access count was incremented
Assert.Equal(4, send.AccessCount);
// Verify repository was updated
await _sendRepository.Received(1).ReplaceAsync(send);
await _pushNotificationService.Received(1).PushSyncSendUpdateAsync(send);
// Verify file storage service was called
await _sendFileStorageService.Received(1).GetSendFileDownloadUrlAsync(send, fileId);
}
[Fact]
public void SendCanBeAccessed_WithDisabledSend_ReturnsFalse()
{
// Arrange
var send = new Send
{
Disabled = true,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
AccessCount = 0,
MaxAccessCount = null
};
// Act
var result = INonAnonymousSendCommand.SendCanBeAccessed(send);
// Assert
Assert.False(result);
}
[Fact]
public void SendCanBeAccessed_WithMaxAccessCountReached_ReturnsFalse()
{
// Arrange
var send = new Send
{
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
AccessCount = 10,
MaxAccessCount = 10
};
// Act
var result = INonAnonymousSendCommand.SendCanBeAccessed(send);
// Assert
Assert.False(result);
}
[Fact]
public void SendCanBeAccessed_WithExpiredSend_ReturnsFalse()
{
// Arrange
var send = new Send
{
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = DateTime.UtcNow.AddDays(-1),
AccessCount = 0,
MaxAccessCount = null
};
// Act
var result = INonAnonymousSendCommand.SendCanBeAccessed(send);
// Assert
Assert.False(result);
}
[Fact]
public void SendCanBeAccessed_WithDeletionDatePassed_ReturnsFalse()
{
// Arrange
var send = new Send
{
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(-1),
ExpirationDate = null,
AccessCount = 0,
MaxAccessCount = null
};
// Act
var result = INonAnonymousSendCommand.SendCanBeAccessed(send);
// Assert
Assert.False(result);
}
[Fact]
public void SendCanBeAccessed_WithValidSend_ReturnsTrue()
{
// Arrange
var send = new Send
{
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = DateTime.UtcNow.AddDays(7),
AccessCount = 5,
MaxAccessCount = 10
};
// Act
var result = INonAnonymousSendCommand.SendCanBeAccessed(send);
// Assert
Assert.True(result);
}
[Fact]
public void SendCanBeAccessed_WithNullMaxAccessCount_ReturnsTrue()
{
// Arrange
var send = new Send
{
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
AccessCount = 100,
MaxAccessCount = null
};
// Act
var result = INonAnonymousSendCommand.SendCanBeAccessed(send);
// Assert
Assert.True(result);
}
[Fact]
public void SendCanBeAccessed_WithNullExpirationDate_ReturnsTrue()
{
// Arrange
var send = new Send
{
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
AccessCount = 0,
MaxAccessCount = 10
};
// Act
var result = INonAnonymousSendCommand.SendCanBeAccessed(send);
// Assert
Assert.True(result);
}
}

View File

@@ -43,7 +43,7 @@ public class SendAuthenticationQueryTests
}
[Theory]
[MemberData(nameof(EmailParsingTestCases))]
[MemberData(nameof(EmailsParsingTestCases))]
public async Task GetAuthenticationMethod_WithEmails_ParsesEmailsCorrectly(string emailString, string[] expectedEmails)
{
// Arrange
@@ -56,7 +56,7 @@ public class SendAuthenticationQueryTests
// Assert
var emailOtp = Assert.IsType<EmailOtp>(result);
Assert.Equal(expectedEmails, emailOtp.Emails);
Assert.Equal(expectedEmails, emailOtp.emails);
}
[Fact]
@@ -64,7 +64,7 @@ public class SendAuthenticationQueryTests
{
// Arrange
var sendId = Guid.NewGuid();
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: "hashedpassword", AuthType.Email);
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: "person@company.com", password: "hashedpassword", AuthType.Email);
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
@@ -108,18 +108,201 @@ public class SendAuthenticationQueryTests
yield return new object[] { null, typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 5, maxAccessCount: 5, emails: null, password: null, AuthType.None), typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 6, maxAccessCount: 5, emails: null, password: null, AuthType.None), typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: null, AuthType.Email), typeof(EmailOtp) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: "person@company.com", password: null, AuthType.Email), typeof(EmailOtp) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: "hashedpassword", AuthType.Password), typeof(ResourcePassword) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: null, AuthType.None), typeof(NotAuthenticated) };
}
public static IEnumerable<object[]> EmailParsingTestCases()
[Fact]
public async Task GetAuthenticationMethod_WithDisabledSend_ReturnsNeverAuthenticate()
{
yield return new object[] { "test@example.com", new[] { "test@example.com" } };
yield return new object[] { "test1@example.com,test2@example.com", new[] { "test1@example.com", "test2@example.com" } };
yield return new object[] { " test@example.com , other@example.com ", new[] { "test@example.com", "other@example.com" } };
yield return new object[] { "test@example.com,,other@example.com", new[] { "test@example.com", "other@example.com" } };
yield return new object[] { " , test@example.com, ,other@example.com, ", new[] { "test@example.com", "other@example.com" } };
// Arrange
var sendId = Guid.NewGuid();
var send = new Send
{
Id = sendId,
AccessCount = 0,
MaxAccessCount = 10,
Emails = "person@company.com",
Password = null,
AuthType = AuthType.Email,
Disabled = true,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null
};
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId);
// Assert
Assert.IsType<NeverAuthenticate>(result);
}
[Fact]
public async Task GetAuthenticationMethod_WithExpiredSend_ReturnsNeverAuthenticate()
{
// Arrange
var sendId = Guid.NewGuid();
var send = new Send
{
Id = sendId,
AccessCount = 0,
MaxAccessCount = 10,
Emails = "person@company.com",
Password = null,
AuthType = AuthType.Email,
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = DateTime.UtcNow.AddDays(-1) // Expired yesterday
};
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId);
// Assert
Assert.IsType<NeverAuthenticate>(result);
}
[Fact]
public async Task GetAuthenticationMethod_WithDeletionDatePassed_ReturnsNeverAuthenticate()
{
// Arrange
var sendId = Guid.NewGuid();
var send = new Send
{
Id = sendId,
AccessCount = 0,
MaxAccessCount = 10,
Emails = "person@company.com",
Password = null,
AuthType = AuthType.Email,
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(-1), // Should have been deleted yesterday
ExpirationDate = null
};
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId);
// Assert
Assert.IsType<NeverAuthenticate>(result);
}
[Fact]
public async Task GetAuthenticationMethod_WithDeletionDateEqualToNow_ReturnsNeverAuthenticate()
{
// Arrange
var sendId = Guid.NewGuid();
var now = DateTime.UtcNow;
var send = new Send
{
Id = sendId,
AccessCount = 0,
MaxAccessCount = 10,
Emails = "person@company.com",
Password = null,
AuthType = AuthType.Email,
Disabled = false,
DeletionDate = now, // DeletionDate <= DateTime.UtcNow
ExpirationDate = null
};
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId);
// Assert
Assert.IsType<NeverAuthenticate>(result);
}
[Fact]
public async Task GetAuthenticationMethod_WithAccessCountEqualToMaxAccessCount_ReturnsNeverAuthenticate()
{
// Arrange
var sendId = Guid.NewGuid();
var send = new Send
{
Id = sendId,
AccessCount = 5,
MaxAccessCount = 5,
Emails = "person@company.com",
Password = null,
AuthType = AuthType.Email,
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null
};
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId);
// Assert
Assert.IsType<NeverAuthenticate>(result);
}
[Fact]
public async Task GetAuthenticationMethod_WithNullMaxAccessCount_DoesNotRestrictAccess()
{
// Arrange
var sendId = Guid.NewGuid();
var send = new Send
{
Id = sendId,
AccessCount = 1000,
MaxAccessCount = null, // No limit
Emails = "person@company.com",
Password = null,
AuthType = AuthType.Email,
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null
};
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId);
// Assert
Assert.IsType<EmailOtp>(result);
}
[Fact]
public async Task GetAuthenticationMethod_WithNullExpirationDate_DoesNotExpire()
{
// Arrange
var sendId = Guid.NewGuid();
var send = new Send
{
Id = sendId,
AccessCount = 0,
MaxAccessCount = 10,
Emails = "person@company.com",
Password = null,
AuthType = AuthType.Email,
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null // No expiration
};
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId);
// Assert
Assert.IsType<EmailOtp>(result);
}
public static IEnumerable<object[]> EmailsParsingTestCases()
{
yield return new object[] { "person@company.com", new[] { "person@company.com" } };
yield return new object[] { "person1@company.com,person2@company.com", new[] { "person1@company.com", "person2@company.com" } };
yield return new object[] { " person1@company.com , person2@company.com ", new[] { "person1@company.com", "person2@company.com" } };
yield return new object[] { "person1@company.com,,person2@company.com", new[] { "person1@company.com", "person2@company.com" } };
yield return new object[] { " , person1@company.com, ,person2@company.com, ", new[] { "person1@company.com", "person2@company.com" } };
}
private static Send CreateSend(int accessCount, int? maxAccessCount, string? emails, string? password, AuthType? authType)
@@ -131,7 +314,10 @@ public class SendAuthenticationQueryTests
MaxAccessCount = maxAccessCount,
Emails = emails,
Password = password,
AuthType = authType
AuthType = authType,
Disabled = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null
};
}
}

View File

@@ -0,0 +1,120 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Pricing.Premium;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Tools.Services;
[SutProviderCustomize]
public class SendValidationServiceTests
{
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_OrgGrantedPremiumUser_UsesPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
User user)
{
// Arrange
send.UserId = user.Id;
send.OrganizationId = null;
send.Type = SendType.File;
user.Premium = false;
user.Storage = 1024L * 1024L * 1024L; // 1 GB used
user.EmailVerified = true;
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = false;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(user.Id).Returns(user);
sutProvider.GetDependency<IUserService>().CanAccessPremium(user).Returns(true);
var premiumPlan = new Plan
{
Storage = new Purchasable { Provided = 5 }
};
sutProvider.GetDependency<IPricingClient>().GetAvailablePremiumPlan().Returns(premiumPlan);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert
await sutProvider.GetDependency<IPricingClient>().Received(1).GetAvailablePremiumPlan();
Assert.True(result > 0);
}
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_IndividualPremium_DoesNotCallPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
User user)
{
// Arrange
send.UserId = user.Id;
send.OrganizationId = null;
send.Type = SendType.File;
user.Premium = true;
user.MaxStorageGb = 10;
user.EmailVerified = true;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(user.Id).Returns(user);
sutProvider.GetDependency<IUserService>().CanAccessPremium(user).Returns(true);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert - should NOT call pricing service for individual premium users
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
}
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_SelfHosted_DoesNotCallPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
User user)
{
// Arrange
send.UserId = user.Id;
send.OrganizationId = null;
send.Type = SendType.File;
user.Premium = false;
user.EmailVerified = true;
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = true;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(user.Id).Returns(user);
sutProvider.GetDependency<IUserService>().CanAccessPremium(user).Returns(true);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert - should NOT call pricing service for self-hosted
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
}
[Theory, BitAutoData]
public async Task StorageRemainingForSendAsync_OrgSend_DoesNotCallPricingService(
SutProvider<SendValidationService> sutProvider,
Send send,
Organization org)
{
// Arrange
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
org.MaxStorageGb = 100;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(org.Id).Returns(org);
// Act
var result = await sutProvider.Sut.StorageRemainingForSendAsync(send);
// Assert - should NOT call pricing service for org sends
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
}
}

View File

@@ -0,0 +1,84 @@
using Bit.Core.Utilities;
using Xunit;
namespace Bit.Core.Test.Utilities;
public class DomainNameValidatorAttributeTests
{
[Theory]
[InlineData("example.com")] // basic domain
[InlineData("sub.example.com")] // subdomain
[InlineData("sub.sub2.example.com")] // multiple subdomains
[InlineData("example-dash.com")] // domain with dash
[InlineData("123example.com")] // domain starting with number
[InlineData("example123.com")] // domain with numbers
[InlineData("e.com")] // short domain
[InlineData("very-long-subdomain-name.example.com")] // long subdomain
[InlineData("wörldé.com")] // unicode domain (IDN)
public void IsValid_ReturnsTrueWhenValid(string domainName)
{
var sut = new DomainNameValidatorAttribute();
var actual = sut.IsValid(domainName);
Assert.True(actual);
}
[Theory]
[InlineData("<script>alert('xss')</script>")] // XSS attempt
[InlineData("example.com<script>")] // XSS suffix
[InlineData("<img src=x>")] // HTML tag
[InlineData("example.com\t")] // trailing tab
[InlineData("\texample.com")] // leading tab
[InlineData("exam\tple.com")] // middle tab
[InlineData("example.com\n")] // newline
[InlineData("example.com\r")] // carriage return
[InlineData("example.com\b")] // backspace
[InlineData("exam ple.com")] // space in domain
[InlineData("example.com ")] // trailing space (after trim, becomes valid, but with space it's invalid)
[InlineData(" example.com")] // leading space (after trim, becomes valid, but with space it's invalid)
[InlineData("example&.com")] // ampersand
[InlineData("example'.com")] // single quote
[InlineData("example\".com")] // double quote
[InlineData(".example.com")] // starts with dot
[InlineData("example.com.")] // ends with dot
[InlineData("example..com")] // double dot
[InlineData("-example.com")] // starts with dash
[InlineData("example-.com")] // label ends with dash
[InlineData("")] // empty string
[InlineData(" ")] // whitespace only
[InlineData("http://example.com")] // URL scheme
[InlineData("example.com/path")] // path component
[InlineData("user@example.com")] // email format
public void IsValid_ReturnsFalseWhenInvalid(string domainName)
{
var sut = new DomainNameValidatorAttribute();
var actual = sut.IsValid(domainName);
Assert.False(actual);
}
[Fact]
public void IsValid_ReturnsTrueWhenNull()
{
var sut = new DomainNameValidatorAttribute();
var actual = sut.IsValid(null);
// Null validation should be handled by [Required] attribute
Assert.True(actual);
}
[Fact]
public void IsValid_ReturnsFalseWhenTooLong()
{
var sut = new DomainNameValidatorAttribute();
// Create a domain name longer than 253 characters
var longDomain = new string('a', 250) + ".com";
var actual = sut.IsValid(longDomain);
Assert.False(actual);
}
}

View File

@@ -98,7 +98,7 @@ public class EnumerationProtectionHelpersTests
var hmacKey = RandomNumberGenerator.GetBytes(32);
var salt1 = "user1@example.com";
var salt2 = "user2@example.com";
var range = 100;
var range = 10_000;
// Act
var result1 = EnumerationProtectionHelpers.GetIndexForInputHash(hmacKey, salt1, range);
@@ -117,7 +117,7 @@ public class EnumerationProtectionHelpersTests
var hmacKey1 = RandomNumberGenerator.GetBytes(32);
var hmacKey2 = RandomNumberGenerator.GetBytes(32);
var salt = "test@example.com";
var range = 100;
var range = 10_000;
// Act
var result1 = EnumerationProtectionHelpers.GetIndexForInputHash(hmacKey1, salt, range);

View File

@@ -74,8 +74,7 @@ public class LoggerFactoryExtensionsTests
logger.LogWarning("This is a test");
// Writing to the file is buffered, give it a little time to flush
await Task.Delay(5);
await provider.DisposeAsync();
var logFile = Assert.Single(tempDir.EnumerateFiles("Logs/*.log"));
@@ -90,13 +89,67 @@ public class LoggerFactoryExtensionsTests
logFileContents
);
}
[Fact]
public async Task AddSerilogFileLogging_LegacyConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile()
{
await AssertSmallFileAsync((tempDir, config) =>
{
config["GlobalSettings:LogDirectory"] = tempDir;
config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning";
});
}
[Fact]
public async Task AddSerilogFileLogging_NewConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile()
{
await AssertSmallFileAsync((tempDir, config) =>
{
config["Logging:PathFormat"] = Path.Combine(tempDir, "log.txt");
config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning";
});
}
private static async Task AssertSmallFileAsync(Action<string, Dictionary<string, string?>> configure)
{
using var tempDir = new TempDirectory();
var config = new Dictionary<string, string?>();
configure(tempDir.Directory, config);
var provider = GetServiceProvider(config, "Production");
var loggerFactory = provider.GetRequiredService<ILoggerFactory>();
var microsoftLogger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Testing");
for (var i = 0; i < 100; i++)
{
microsoftLogger.LogInformation("Tons of useless information");
}
var otherLogger = loggerFactory.CreateLogger("Bitwarden");
for (var i = 0; i < 5; i++)
{
otherLogger.LogInformation("Mildly more useful information but not as frequent.");
}
await provider.DisposeAsync();
var logFiles = Directory.EnumerateFiles(tempDir.Directory, "*.txt", SearchOption.AllDirectories);
var logFile = Assert.Single(logFiles);
using var fr = File.OpenRead(logFile);
Assert.InRange(fr.Length, 0, 1024);
}
private static IEnumerable<ILoggerProvider> GetProviders(Dictionary<string, string?> initialData, string environment = "Production")
{
var provider = GetServiceProvider(initialData, environment);
return provider.GetServices<ILoggerProvider>();
}
private static IServiceProvider GetServiceProvider(Dictionary<string, string?> initialData, string environment)
private static ServiceProvider GetServiceProvider(Dictionary<string, string?> initialData, string environment)
{
var config = new ConfigurationBuilder()
.AddInMemoryCollection(initialData)

View File

@@ -6,6 +6,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Pricing.Premium;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@@ -2228,10 +2230,6 @@ public class CipherServiceTests
.PushSyncCiphersAsync(deletingUserId);
}
[Theory]
[OrganizationCipherCustomize]
[BitAutoData]
@@ -2387,6 +2385,186 @@ public class CipherServiceTests
ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id))));
}
[Theory, BitAutoData]
public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_UsesStorageFromPricingClient(
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Setup user WITHOUT personal premium (Premium = false), but with org-granted premium access
var user = new User
{
Id = savingUserId,
Premium = false, // User does not have personal premium
MaxStorageGb = null, // No personal storage allocation
Storage = 0 // No storage used yet
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
// User has premium access through their organization
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
// Mock GlobalSettings to indicate cloud (not self-hosted)
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = false;
// Mock the PricingClient to return a premium plan with 1 GB of storage
var premiumPlan = new Plan
{
Name = "Premium",
Available = true,
Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 },
Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 }
};
sutProvider.GetDependency<IPricingClient>()
.GetAvailablePremiumPlan()
.Returns(premiumPlan);
sutProvider.GetDependency<IAttachmentStorageService>()
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<IAttachmentStorageService>()
.ValidateFileAsync(cipher, Arg.Any<CipherAttachment.MetaData>(), Arg.Any<long>())
.Returns((true, 100L));
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<ICipherRepository>()
.ReplaceAsync(Arg.Any<CipherDetails>())
.Returns(Task.CompletedTask);
// Act
await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate);
// Assert - PricingClient was called to get the premium plan storage
await sutProvider.GetDependency<IPricingClient>().Received(1).GetAvailablePremiumPlan();
// Assert - Attachment was uploaded successfully
await sutProvider.GetDependency<IAttachmentStorageService>().Received(1)
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>());
}
[Theory, BitAutoData]
public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_ExceedsStorage_ThrowsBadRequest(
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Setup user WITHOUT personal premium, with org-granted access, but storage is full
var user = new User
{
Id = savingUserId,
Premium = false,
MaxStorageGb = null,
Storage = 1073741824 // 1 GB already used (equals the provided storage)
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = false;
// Premium plan provides 1 GB of storage
var premiumPlan = new Plan
{
Name = "Premium",
Available = true,
Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 },
Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 }
};
sutProvider.GetDependency<IPricingClient>()
.GetAvailablePremiumPlan()
.Returns(premiumPlan);
// Act & Assert - Should throw because storage is full
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate));
Assert.Contains("Not enough storage available", exception.Message);
}
[Theory, BitAutoData]
public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_SelfHosted_UsesConstantStorage(
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Setup user WITHOUT personal premium, but with org-granted premium access
var user = new User
{
Id = savingUserId,
Premium = false,
MaxStorageGb = null,
Storage = 0
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
// Mock GlobalSettings to indicate self-hosted
sutProvider.GetDependency<Bit.Core.Settings.GlobalSettings>().SelfHosted = true;
sutProvider.GetDependency<IAttachmentStorageService>()
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<IAttachmentStorageService>()
.ValidateFileAsync(cipher, Arg.Any<CipherAttachment.MetaData>(), Arg.Any<long>())
.Returns((true, 100L));
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<ICipherRepository>()
.ReplaceAsync(Arg.Any<CipherDetails>())
.Returns(Task.CompletedTask);
// Act
await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate);
// Assert - PricingClient should NOT be called for self-hosted
await sutProvider.GetDependency<IPricingClient>().DidNotReceive().GetAvailablePremiumPlan();
// Assert - Attachment was uploaded successfully
await sutProvider.GetDependency<IAttachmentStorageService>().Received(1)
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>());
}
private async Task AssertNoActionsAsync(SutProvider<CipherService> sutProvider)
{
await sutProvider.GetDependency<ICipherRepository>().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default);