1
0
mirror of https://github.com/bitwarden/server synced 2025-12-24 12:13:17 +00:00

Merge branch 'main' into dbops/dbops-31/csv-import

This commit is contained in:
mkincaid-bw
2025-11-07 15:57:56 -08:00
committed by GitHub
38 changed files with 1990 additions and 256 deletions

View File

@@ -41,6 +41,10 @@
matchUpdateTypes: ["patch"],
dependencyDashboardApproval: false,
},
{
matchSourceUrls: ["https://github.com/bitwarden/sdk-internal"],
groupName: "sdk-internal",
},
{
matchManagers: ["dockerfile", "docker-compose"],
commitMessagePrefix: "[deps] BRE:",

View File

@@ -1,7 +1,4 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using System.Security.Claims;
using System.Security.Claims;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
@@ -167,6 +164,8 @@ public class AccountController : Controller
{
var context = await _interaction.GetAuthorizationContextAsync(returnUrl);
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
if (!context.Parameters.AllKeys.Contains("domain_hint") ||
string.IsNullOrWhiteSpace(context.Parameters["domain_hint"]))
{
@@ -182,6 +181,7 @@ public class AccountController : Controller
var domainHint = context.Parameters["domain_hint"];
var organization = await _organizationRepository.GetByIdentifierAsync(domainHint);
#nullable restore
if (organization == null)
{
@@ -263,30 +263,33 @@ public class AccountController : Controller
// See if the user has logged in with this SSO provider before and has already been provisioned.
// This is signified by the user existing in the User table and the SSOUser table for the SSO provider they're using.
var (user, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result);
var (possibleSsoLinkedUser, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result);
// We will look these up as required (lazy resolution) to avoid multiple DB hits.
Organization organization = null;
OrganizationUser orgUser = null;
Organization? organization = null;
OrganizationUser? orgUser = null;
// The user has not authenticated with this SSO provider before.
// They could have an existing Bitwarden account in the User table though.
if (user == null)
if (possibleSsoLinkedUser == null)
{
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
// If we're manually linking to SSO, the user's external identifier will be passed as query string parameter.
var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier")
? result.Properties.Items["user_identifier"]
: null;
var (provisionedUser, foundOrganization, foundOrCreatedOrgUser) =
await AutoProvisionUserAsync(
var (resolvedUser, foundOrganization, foundOrCreatedOrgUser) =
await CreateUserAndOrgUserConditionallyAsync(
provider,
providerUserId,
claims,
userIdentifier,
ssoConfigData);
#nullable restore
user = provisionedUser;
possibleSsoLinkedUser = resolvedUser;
if (preventOrgUserLoginIfStatusInvalid)
{
@@ -297,9 +300,10 @@ public class AccountController : Controller
if (preventOrgUserLoginIfStatusInvalid)
{
if (user == null) throw new Exception(_i18nService.T("UserShouldBeFound"));
User resolvedSsoLinkedUser = possibleSsoLinkedUser
?? throw new Exception(_i18nService.T("UserShouldBeFound"));
await PreventOrgUserLoginIfStatusInvalidAsync(organization, provider, orgUser, user);
await PreventOrgUserLoginIfStatusInvalidAsync(organization, provider, orgUser, resolvedSsoLinkedUser);
// This allows us to collect any additional claims or properties
// for the specific protocols used and store them in the local auth cookie.
@@ -314,19 +318,20 @@ public class AccountController : Controller
// Issue authentication cookie for user
await HttpContext.SignInAsync(
new IdentityServerUser(user.Id.ToString())
new IdentityServerUser(resolvedSsoLinkedUser.Id.ToString())
{
DisplayName = user.Email,
DisplayName = resolvedSsoLinkedUser.Email,
IdentityProvider = provider,
AdditionalClaims = additionalLocalClaims.ToArray()
}, localSignInProps);
}
else
{
// PM-24579: remove this else block with feature flag removal.
// Either the user already authenticated with the SSO provider, or we've just provisioned them.
// Either way, we have associated the SSO login with a Bitwarden user.
// We will now sign the Bitwarden user in.
if (user != null)
if (possibleSsoLinkedUser != null)
{
// This allows us to collect any additional claims or properties
// for the specific protocols used and store them in the local auth cookie.
@@ -341,9 +346,9 @@ public class AccountController : Controller
// Issue authentication cookie for user
await HttpContext.SignInAsync(
new IdentityServerUser(user.Id.ToString())
new IdentityServerUser(possibleSsoLinkedUser.Id.ToString())
{
DisplayName = user.Email,
DisplayName = possibleSsoLinkedUser.Email,
IdentityProvider = provider,
AdditionalClaims = additionalLocalClaims.ToArray()
}, localSignInProps);
@@ -353,8 +358,11 @@ public class AccountController : Controller
// Delete temporary cookie used during external authentication
await HttpContext.SignOutAsync(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme);
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
// Retrieve return URL
var returnUrl = result.Properties.Items["return_url"] ?? "~/";
#nullable restore
// Check if external login is in the context of an OIDC request
var context = await _interaction.GetAuthorizationContextAsync(returnUrl);
@@ -373,6 +381,8 @@ public class AccountController : Controller
return Redirect(returnUrl);
}
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
[HttpGet]
public async Task<IActionResult> LogoutAsync(string logoutId)
{
@@ -407,15 +417,22 @@ public class AccountController : Controller
return Redirect("~/");
}
}
#nullable restore
/// <summary>
/// Attempts to map the external identity to a Bitwarden user, through the SsoUser table, which holds the `externalId`.
/// The claims on the external identity are used to determine an `externalId`, and that is used to find the appropriate `SsoUser` and `User` records.
/// </summary>
private async Task<(User user, string provider, string providerUserId, IEnumerable<Claim> claims,
SsoConfigurationData config)>
FindUserFromExternalProviderAsync(AuthenticateResult result)
private async Task<(
User? possibleSsoUser,
string provider,
string providerUserId,
IEnumerable<Claim> claims,
SsoConfigurationData config
)> FindUserFromExternalProviderAsync(AuthenticateResult result)
{
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
var provider = result.Properties.Items["scheme"];
var orgId = new Guid(provider);
var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgId);
@@ -458,6 +475,7 @@ public class AccountController : Controller
externalUser.FindFirst("upn") ??
externalUser.FindFirst("eppn") ??
throw new Exception(_i18nService.T("UnknownUserId"));
#nullable restore
// Remove the user id claim so we don't include it as an extra claim if/when we provision the user
var claims = externalUser.Claims.ToList();
@@ -466,13 +484,15 @@ public class AccountController : Controller
// find external user
var providerUserId = userIdClaim.Value;
var user = await _userRepository.GetBySsoUserAsync(providerUserId, orgId);
var possibleSsoUser = await _userRepository.GetBySsoUserAsync(providerUserId, orgId);
return (user, provider, providerUserId, claims, ssoConfigData);
return (possibleSsoUser, provider, providerUserId, claims, ssoConfigData);
}
/// <summary>
/// Provision an SSO-linked Bitwarden user.
/// This function seeks to set up the org user record or create a new user record based on the conditions
/// below.
///
/// This handles three different scenarios:
/// 1. Creating an SsoUser link for an existing User and OrganizationUser
/// - User is a member of the organization, but hasn't authenticated with the org's SSO provider before.
@@ -488,8 +508,7 @@ public class AccountController : Controller
/// <param name="ssoConfigData">The SSO configuration for the organization.</param>
/// <returns>Guaranteed to return the user to sign in as well as the found organization and org user.</returns>
/// <exception cref="Exception">An exception if the user cannot be provisioned as requested.</exception>
private async Task<(User user, Organization foundOrganization, OrganizationUser foundOrgUser)>
AutoProvisionUserAsync(
private async Task<(User resolvedUser, Organization foundOrganization, OrganizationUser foundOrgUser)> CreateUserAndOrgUserConditionallyAsync(
string provider,
string providerUserId,
IEnumerable<Claim> claims,
@@ -497,10 +516,11 @@ public class AccountController : Controller
SsoConfigurationData ssoConfigData
)
{
// Try to get the email from the claims as we don't know if we have a user record yet.
var name = GetName(claims, ssoConfigData.GetAdditionalNameClaimTypes());
var email = TryGetEmailAddress(claims, ssoConfigData, providerUserId);
User existingUser = null;
User? possibleExistingUser;
if (string.IsNullOrWhiteSpace(userIdentifier))
{
if (string.IsNullOrWhiteSpace(email))
@@ -508,51 +528,74 @@ public class AccountController : Controller
throw new Exception(_i18nService.T("CannotFindEmailClaim"));
}
existingUser = await _userRepository.GetByEmailAsync(email);
possibleExistingUser = await _userRepository.GetByEmailAsync(email);
}
else
{
existingUser = await GetUserFromManualLinkingDataAsync(userIdentifier);
possibleExistingUser = await GetUserFromManualLinkingDataAsync(userIdentifier);
}
// Try to find the org (we error if we can't find an org)
var organization = await TryGetOrganizationByProviderAsync(provider);
// Find the org (we error if we can't find an org because no org is not valid)
var organization = await GetOrganizationByProviderAsync(provider);
// Try to find an org user (null org user possible and valid here)
var orgUser = await TryGetOrganizationUserByUserAndOrgOrEmail(existingUser, organization.Id, email);
var possibleOrgUser = await GetOrganizationUserByUserAndOrgIdOrEmailAsync(possibleExistingUser, organization.Id, email);
//----------------------------------------------------
// Scenario 1: We've found the user in the User table
//----------------------------------------------------
if (existingUser != null)
if (possibleExistingUser != null)
{
if (existingUser.UsesKeyConnector &&
(orgUser == null || orgUser.Status == OrganizationUserStatusType.Invited))
User guaranteedExistingUser = possibleExistingUser;
if (guaranteedExistingUser.UsesKeyConnector &&
(possibleOrgUser == null || possibleOrgUser.Status == OrganizationUserStatusType.Invited))
{
throw new Exception(_i18nService.T("UserAlreadyExistsKeyConnector"));
}
// If the user already exists in Bitwarden, we require that the user already be in the org,
// and that they are either Accepted or Confirmed.
if (orgUser == null)
OrganizationUser guaranteedOrgUser = possibleOrgUser ?? throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess"));
/*
* ----------------------------------------------------
* Critical Code Check Here
*
* We want to ensure a user is not in the invited state
* explicitly. User's in the invited state should not
* be able to authenticate via SSO.
*
* See internal doc called "Added Context for SSO Login
* Flows" for further details.
* ----------------------------------------------------
*/
if (guaranteedOrgUser.Status == OrganizationUserStatusType.Invited)
{
// Org User is not created - no invite has been sent
throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess"));
// Org User is invited must accept via email first
throw new Exception(
_i18nService.T("AcceptInviteBeforeUsingSSO", organization.DisplayName()));
}
EnsureAcceptedOrConfirmedOrgUserStatus(orgUser.Status, organization.DisplayName());
// If the user already exists in Bitwarden, we require that the user already be in the org,
// and that they are either Accepted or Confirmed.
EnforceAllowedOrgUserStatus(
guaranteedOrgUser.Status,
allowedStatuses: [
OrganizationUserStatusType.Accepted,
OrganizationUserStatusType.Confirmed
],
organization.DisplayName());
// Since we're in the auto-provisioning logic, this means that the user exists, but they have not
// authenticated with the org's SSO provider before now (otherwise we wouldn't be auto-provisioning them).
// We've verified that the user is Accepted or Confnirmed, so we can create an SsoUser link and proceed
// with authentication.
await CreateSsoUserRecordAsync(providerUserId, existingUser.Id, organization.Id, orgUser);
await CreateSsoUserRecordAsync(providerUserId, guaranteedExistingUser.Id, organization.Id, guaranteedOrgUser);
return (existingUser, organization, orgUser);
return (guaranteedExistingUser, organization, guaranteedOrgUser);
}
// Before any user creation - if Org User doesn't exist at this point - make sure there are enough seats to add one
if (orgUser == null && organization.Seats.HasValue)
if (possibleOrgUser == null && organization.Seats.HasValue)
{
var occupiedSeats =
await _organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id);
@@ -584,6 +627,11 @@ public class AccountController : Controller
}
// If the email domain is verified, we can mark the email as verified
if (string.IsNullOrWhiteSpace(email))
{
throw new Exception(_i18nService.T("CannotFindEmailClaim"));
}
var emailVerified = false;
var emailDomain = CoreHelpers.GetEmailDomain(email);
if (!string.IsNullOrWhiteSpace(emailDomain))
@@ -596,29 +644,29 @@ public class AccountController : Controller
//--------------------------------------------------
// Scenarios 2 and 3: We need to register a new user
//--------------------------------------------------
var user = new User
var newUser = new User
{
Name = name,
Email = email,
EmailVerified = emailVerified,
ApiKey = CoreHelpers.SecureRandomString(30)
};
await _registerUserCommand.RegisterUser(user);
await _registerUserCommand.RegisterUser(newUser);
// If the organization has 2fa policy enabled, make sure to default jit user 2fa to email
var twoFactorPolicy =
await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.TwoFactorAuthentication);
if (twoFactorPolicy != null && twoFactorPolicy.Enabled)
{
user.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
newUser.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
[TwoFactorProviderType.Email] = new TwoFactorProvider
{
MetaData = new Dictionary<string, object> { ["Email"] = user.Email.ToLowerInvariant() },
MetaData = new Dictionary<string, object> { ["Email"] = newUser.Email.ToLowerInvariant() },
Enabled = true
}
});
await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email);
await _userService.UpdateTwoFactorProviderAsync(newUser, TwoFactorProviderType.Email);
}
//-----------------------------------------------------------------
@@ -626,16 +674,16 @@ public class AccountController : Controller
// This means that an invitation was not sent for this user and we
// need to establish their invited status now.
//-----------------------------------------------------------------
if (orgUser == null)
if (possibleOrgUser == null)
{
orgUser = new OrganizationUser
possibleOrgUser = new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
UserId = newUser.Id,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Invited
};
await _organizationUserRepository.CreateAsync(orgUser);
await _organizationUserRepository.CreateAsync(possibleOrgUser);
}
//-----------------------------------------------------------------
@@ -645,14 +693,14 @@ public class AccountController : Controller
//-----------------------------------------------------------------
else
{
orgUser.UserId = user.Id;
await _organizationUserRepository.ReplaceAsync(orgUser);
possibleOrgUser.UserId = newUser.Id;
await _organizationUserRepository.ReplaceAsync(possibleOrgUser);
}
// Create the SsoUser record to link the user to the SSO provider.
await CreateSsoUserRecordAsync(providerUserId, user.Id, organization.Id, orgUser);
await CreateSsoUserRecordAsync(providerUserId, newUser.Id, organization.Id, possibleOrgUser);
return (user, organization, orgUser);
return (newUser, organization, possibleOrgUser);
}
/// <summary>
@@ -666,23 +714,31 @@ public class AccountController : Controller
/// <exception cref="Exception">Thrown if the organization cannot be resolved from provider;
/// the organization user cannot be found; or the organization user status is not allowed.</exception>
private async Task PreventOrgUserLoginIfStatusInvalidAsync(
Organization organization,
Organization? organization,
string provider,
OrganizationUser orgUser,
OrganizationUser? orgUser,
User user)
{
// Lazily get organization if not already known
organization ??= await TryGetOrganizationByProviderAsync(provider);
organization ??= await GetOrganizationByProviderAsync(provider);
// Lazily get the org user if not already known
orgUser ??= await TryGetOrganizationUserByUserAndOrgOrEmail(
orgUser ??= await GetOrganizationUserByUserAndOrgIdOrEmailAsync(
user,
organization.Id,
user.Email);
if (orgUser != null)
{
EnsureAcceptedOrConfirmedOrgUserStatus(orgUser.Status, organization.DisplayName());
// Invited is allowed at this point because we know the user is trying to accept an org invite.
EnforceAllowedOrgUserStatus(
orgUser.Status,
allowedStatuses: [
OrganizationUserStatusType.Invited,
OrganizationUserStatusType.Accepted,
OrganizationUserStatusType.Confirmed,
],
organization.DisplayName());
}
else
{
@@ -690,9 +746,9 @@ public class AccountController : Controller
}
}
private async Task<User> GetUserFromManualLinkingDataAsync(string userIdentifier)
private async Task<User?> GetUserFromManualLinkingDataAsync(string userIdentifier)
{
User user = null;
User? user = null;
var split = userIdentifier.Split(",");
if (split.Length < 2)
{
@@ -728,7 +784,7 @@ public class AccountController : Controller
/// </summary>
/// <param name="provider">Org id string from SSO scheme property</param>
/// <exception cref="Exception">Errors if the provider string is not a valid org id guid or if the org cannot be found by the id.</exception>
private async Task<Organization> TryGetOrganizationByProviderAsync(string provider)
private async Task<Organization> GetOrganizationByProviderAsync(string provider)
{
if (!Guid.TryParse(provider, out var organizationId))
{
@@ -755,12 +811,12 @@ public class AccountController : Controller
/// <param name="organizationId">Organization id from the provider data.</param>
/// <param name="email">Email to use as a fallback in case of an invited user not in the Org Users
/// table yet.</param>
private async Task<OrganizationUser> TryGetOrganizationUserByUserAndOrgOrEmail(
User user,
private async Task<OrganizationUser?> GetOrganizationUserByUserAndOrgIdOrEmailAsync(
User? user,
Guid organizationId,
string email)
string? email)
{
OrganizationUser orgUser = null;
OrganizationUser? orgUser = null;
// Try to find OrgUser via existing User Id.
// This covers any OrganizationUser state after they have accepted an invite.
@@ -772,44 +828,40 @@ public class AccountController : Controller
// If no Org User found by Existing User Id - search all the organization's users via email.
// This covers users who are Invited but haven't accepted their invite yet.
orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(organizationId, email);
if (email != null)
{
orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(organizationId, email);
}
return orgUser;
}
private void EnsureAcceptedOrConfirmedOrgUserStatus(
OrganizationUserStatusType status,
string organizationDisplayName)
private void EnforceAllowedOrgUserStatus(
OrganizationUserStatusType statusToCheckAgainst,
OrganizationUserStatusType[] allowedStatuses,
string organizationDisplayNameForLogging)
{
// The only permissible org user statuses allowed.
OrganizationUserStatusType[] allowedStatuses =
[OrganizationUserStatusType.Accepted, OrganizationUserStatusType.Confirmed];
// if this status is one of the allowed ones, just return
if (allowedStatuses.Contains(status))
if (allowedStatuses.Contains(statusToCheckAgainst))
{
return;
}
// otherwise throw the appropriate exception
switch (status)
switch (statusToCheckAgainst)
{
case OrganizationUserStatusType.Invited:
// Org User is invited must accept via email first
throw new Exception(
_i18nService.T("AcceptInviteBeforeUsingSSO", organizationDisplayName));
case OrganizationUserStatusType.Revoked:
// Revoked users may not be (auto)provisioned
throw new Exception(
_i18nService.T("OrganizationUserAccessRevoked", organizationDisplayName));
_i18nService.T("OrganizationUserAccessRevoked", organizationDisplayNameForLogging));
default:
// anything else is “unknown”
throw new Exception(
_i18nService.T("OrganizationUserUnknownStatus", organizationDisplayName));
_i18nService.T("OrganizationUserUnknownStatus", organizationDisplayNameForLogging));
}
}
private IActionResult InvalidJson(string errorMessageKey, Exception ex = null)
private IActionResult InvalidJson(string errorMessageKey, Exception? ex = null)
{
Response.StatusCode = ex == null ? 400 : 500;
return Json(new ErrorResponseModel(_i18nService.T(errorMessageKey))
@@ -820,7 +872,7 @@ public class AccountController : Controller
});
}
private string TryGetEmailAddressFromClaims(IEnumerable<Claim> claims, IEnumerable<string> additionalClaimTypes)
private string? TryGetEmailAddressFromClaims(IEnumerable<Claim> claims, IEnumerable<string> additionalClaimTypes)
{
var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value) && c.Value.Contains("@"));
@@ -842,6 +894,8 @@ public class AccountController : Controller
return null;
}
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
private string GetName(IEnumerable<Claim> claims, IEnumerable<string> additionalClaimTypes)
{
var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value));
@@ -865,6 +919,7 @@ public class AccountController : Controller
return null;
}
#nullable restore
private async Task CreateSsoUserRecordAsync(string providerUserId, Guid userId, Guid orgId,
OrganizationUser orgUser)
@@ -886,6 +941,8 @@ public class AccountController : Controller
await _ssoUserRepository.CreateAsync(ssoUser);
}
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
private void ProcessLoginCallback(AuthenticateResult externalResult,
List<Claim> localClaims, AuthenticationProperties localSignInProps)
{
@@ -936,12 +993,13 @@ public class AccountController : Controller
return (logoutId, logout?.PostLogoutRedirectUri, externalAuthenticationScheme);
}
#nullable restore
/**
* Tries to get a user's email from the claims and SSO configuration data or the provider user id if
* the claims email extraction returns null.
*/
private string TryGetEmailAddress(
private string? TryGetEmailAddress(
IEnumerable<Claim> claims,
SsoConfigurationData config,
string providerUserId)

View File

@@ -74,17 +74,6 @@ public class AccountControllerTest
return resolvedAuthService;
}
private static void InvokeEnsureOrgUserStatusAllowed(
AccountController controller,
OrganizationUserStatusType status)
{
var method = typeof(AccountController).GetMethod(
"EnsureAcceptedOrConfirmedOrgUserStatus",
BindingFlags.Instance | BindingFlags.NonPublic);
Assert.NotNull(method);
method.Invoke(controller, [status, "Org"]);
}
private static AuthenticateResult BuildSuccessfulExternalAuth(Guid orgId, string providerUserId, string email)
{
var claims = new[]
@@ -241,82 +230,6 @@ public class AccountControllerTest
return counts;
}
[Theory, BitAutoData]
public void EnsureOrgUserStatusAllowed_AllowsAcceptedAndConfirmed(
SutProvider<AccountController> sutProvider)
{
// Arrange
sutProvider.GetDependency<II18nService>()
.T(Arg.Any<string>(), Arg.Any<object?[]>())
.Returns(ci => (string)ci[0]!);
// Act
var ex1 = Record.Exception(() =>
InvokeEnsureOrgUserStatusAllowed(sutProvider.Sut, OrganizationUserStatusType.Accepted));
var ex2 = Record.Exception(() =>
InvokeEnsureOrgUserStatusAllowed(sutProvider.Sut, OrganizationUserStatusType.Confirmed));
// Assert
Assert.Null(ex1);
Assert.Null(ex2);
}
[Theory, BitAutoData]
public void EnsureOrgUserStatusAllowed_Invited_ThrowsAcceptInvite(
SutProvider<AccountController> sutProvider)
{
// Arrange
sutProvider.GetDependency<II18nService>()
.T(Arg.Any<string>(), Arg.Any<object?[]>())
.Returns(ci => (string)ci[0]!);
// Act
var ex = Assert.Throws<TargetInvocationException>(() =>
InvokeEnsureOrgUserStatusAllowed(sutProvider.Sut, OrganizationUserStatusType.Invited));
// Assert
Assert.IsType<Exception>(ex.InnerException);
Assert.Equal("AcceptInviteBeforeUsingSSO", ex.InnerException!.Message);
}
[Theory, BitAutoData]
public void EnsureOrgUserStatusAllowed_Revoked_ThrowsAccessRevoked(
SutProvider<AccountController> sutProvider)
{
// Arrange
sutProvider.GetDependency<II18nService>()
.T(Arg.Any<string>(), Arg.Any<object?[]>())
.Returns(ci => (string)ci[0]!);
// Act
var ex = Assert.Throws<TargetInvocationException>(() =>
InvokeEnsureOrgUserStatusAllowed(sutProvider.Sut, OrganizationUserStatusType.Revoked));
// Assert
Assert.IsType<Exception>(ex.InnerException);
Assert.Equal("OrganizationUserAccessRevoked", ex.InnerException!.Message);
}
[Theory, BitAutoData]
public void EnsureOrgUserStatusAllowed_UnknownStatus_ThrowsUnknown(
SutProvider<AccountController> sutProvider)
{
// Arrange
sutProvider.GetDependency<II18nService>()
.T(Arg.Any<string>(), Arg.Any<object?[]>())
.Returns(ci => (string)ci[0]!);
var unknown = (OrganizationUserStatusType)999;
// Act
var ex = Assert.Throws<TargetInvocationException>(() =>
InvokeEnsureOrgUserStatusAllowed(sutProvider.Sut, unknown));
// Assert
Assert.IsType<Exception>(ex.InnerException);
Assert.Equal("OrganizationUserUnknownStatus", ex.InnerException!.Message);
}
[Theory, BitAutoData]
public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_NoOrgUser_ThrowsCouldNotFindOrganizationUser(
SutProvider<AccountController> sutProvider)
@@ -357,7 +270,7 @@ public class AccountControllerTest
}
[Theory, BitAutoData]
public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserInvited_ThrowsAcceptInvite(
public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserInvited_AllowsLogin(
SutProvider<AccountController> sutProvider)
{
// Arrange
@@ -374,7 +287,7 @@ public class AccountControllerTest
};
var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!);
SetupHttpContextWithAuth(sutProvider, authResult);
var authService = SetupHttpContextWithAuth(sutProvider, authResult);
sutProvider.GetDependency<II18nService>()
.T(Arg.Any<string>(), Arg.Any<object?[]>())
@@ -392,9 +305,23 @@ public class AccountControllerTest
sutProvider.GetDependency<IIdentityServerInteractionService>()
.GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null);
// Act + Assert
var ex = await Assert.ThrowsAsync<Exception>(() => sutProvider.Sut.ExternalCallback());
Assert.Equal("AcceptInviteBeforeUsingSSO", ex.Message);
// Act
var result = await sutProvider.Sut.ExternalCallback();
// Assert
var redirect = Assert.IsType<RedirectResult>(result);
Assert.Equal("~/", redirect.Url);
await authService.Received().SignInAsync(
Arg.Any<HttpContext>(),
Arg.Any<string?>(),
Arg.Any<ClaimsPrincipal>(),
Arg.Any<AuthenticationProperties>());
await authService.Received().SignOutAsync(
Arg.Any<HttpContext>(),
AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
Arg.Any<AuthenticationProperties>());
}
[Theory, BitAutoData]
@@ -930,13 +857,13 @@ public class AccountControllerTest
}
[Theory, BitAutoData]
public async Task AutoProvisionUserAsync_WithExistingAcceptedUser_CreatesSsoLinkAndReturnsUser(
public async Task CreateUserAndOrgUserConditionallyAsync_WithExistingAcceptedUser_CreatesSsoLinkAndReturnsUser(
SutProvider<AccountController> sutProvider)
{
// Arrange
var orgId = Guid.NewGuid();
var providerUserId = "ext-456";
var email = "jit@example.com";
var providerUserId = "provider-user-id";
var email = "user@example.com";
var existingUser = new User { Id = Guid.NewGuid(), Email = email };
var organization = new Organization { Id = orgId, Name = "Org" };
var orgUser = new OrganizationUser
@@ -965,12 +892,12 @@ public class AccountControllerTest
var config = new SsoConfigurationData();
var method = typeof(AccountController).GetMethod(
"AutoProvisionUserAsync",
"CreateUserAndOrgUserConditionallyAsync",
BindingFlags.Instance | BindingFlags.NonPublic);
Assert.NotNull(method);
// Act
var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method!.Invoke(sutProvider.Sut, new object[]
var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method.Invoke(sutProvider.Sut, new object[]
{
orgId.ToString(),
providerUserId,
@@ -992,6 +919,61 @@ public class AccountControllerTest
EventType.OrganizationUser_FirstSsoLogin);
}
[Theory, BitAutoData]
public async Task CreateUserAndOrgUserConditionallyAsync_WithExistingInvitedUser_ThrowsAcceptInviteBeforeUsingSSO(
SutProvider<AccountController> sutProvider)
{
// Arrange
var orgId = Guid.NewGuid();
var providerUserId = "provider-user-id";
var email = "user@example.com";
var existingUser = new User { Id = Guid.NewGuid(), Email = email, UsesKeyConnector = false };
var organization = new Organization { Id = orgId, Name = "Org" };
var orgUser = new OrganizationUser
{
OrganizationId = orgId,
UserId = existingUser.Id,
Status = OrganizationUserStatusType.Invited,
Type = OrganizationUserType.User
};
// i18n returns the key so we can assert on message contents
sutProvider.GetDependency<II18nService>()
.T(Arg.Any<string>(), Arg.Any<object?[]>())
.Returns(ci => (string)ci[0]!);
// Arrange repository expectations for the flow
sutProvider.GetDependency<IUserRepository>().GetByEmailAsync(email).Returns(existingUser);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(orgId).Returns(organization);
sutProvider.GetDependency<IOrganizationUserRepository>().GetManyByUserAsync(existingUser.Id)
.Returns(new List<OrganizationUser> { orgUser });
var claims = new[]
{
new Claim(JwtClaimTypes.Email, email),
new Claim(JwtClaimTypes.Name, "Invited User")
} as IEnumerable<Claim>;
var config = new SsoConfigurationData();
var method = typeof(AccountController).GetMethod(
"CreateUserAndOrgUserConditionallyAsync",
BindingFlags.Instance | BindingFlags.NonPublic);
Assert.NotNull(method);
// Act + Assert
var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method.Invoke(sutProvider.Sut, new object[]
{
orgId.ToString(),
providerUserId,
claims,
null!,
config
})!;
var ex = await Assert.ThrowsAsync<Exception>(async () => await task);
Assert.Equal("AcceptInviteBeforeUsingSSO", ex.Message);
}
/// <summary>
/// PM-24579: Temporary comparison test to ensure the feature flag ON does not
/// regress lookup counts compared to OFF. When removing the flag, delete this

View File

@@ -12,6 +12,7 @@ using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Context;
@@ -41,8 +42,9 @@ public class PoliciesController : Controller
private readonly IDataProtectorTokenFactory<OrgUserInviteTokenable> _orgUserInviteTokenDataFactory;
private readonly IPolicyRepository _policyRepository;
private readonly IUserService _userService;
private readonly IFeatureService _featureService;
private readonly ISavePolicyCommand _savePolicyCommand;
private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand;
public PoliciesController(IPolicyRepository policyRepository,
IOrganizationUserRepository organizationUserRepository,
@@ -53,7 +55,9 @@ public class PoliciesController : Controller
IDataProtectorTokenFactory<OrgUserInviteTokenable> orgUserInviteTokenDataFactory,
IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery,
IOrganizationRepository organizationRepository,
ISavePolicyCommand savePolicyCommand)
IFeatureService featureService,
ISavePolicyCommand savePolicyCommand,
IVNextSavePolicyCommand vNextSavePolicyCommand)
{
_policyRepository = policyRepository;
_organizationUserRepository = organizationUserRepository;
@@ -65,7 +69,9 @@ public class PoliciesController : Controller
_organizationRepository = organizationRepository;
_orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory;
_organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery;
_featureService = featureService;
_savePolicyCommand = savePolicyCommand;
_vNextSavePolicyCommand = vNextSavePolicyCommand;
}
[HttpGet("{type}")]
@@ -221,7 +227,9 @@ public class PoliciesController : Controller
{
var savePolicyRequest = await model.ToSavePolicyModelAsync(orgId, _currentContext);
var policy = await _savePolicyCommand.VNextSaveAsync(savePolicyRequest);
var policy = _featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) ?
await _vNextSavePolicyCommand.SaveAsync(savePolicyRequest) :
await _savePolicyCommand.VNextSaveAsync(savePolicyRequest);
return new PolicyResponseModel(policy);
}

View File

@@ -5,11 +5,15 @@ using System.Net;
using Bit.Api.AdminConsole.Public.Models.Request;
using Bit.Api.AdminConsole.Public.Models.Response;
using Bit.Api.Models.Public.Response;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Context;
using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
@@ -22,18 +26,24 @@ public class PoliciesController : Controller
private readonly IPolicyRepository _policyRepository;
private readonly IPolicyService _policyService;
private readonly ICurrentContext _currentContext;
private readonly IFeatureService _featureService;
private readonly ISavePolicyCommand _savePolicyCommand;
private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand;
public PoliciesController(
IPolicyRepository policyRepository,
IPolicyService policyService,
ICurrentContext currentContext,
ISavePolicyCommand savePolicyCommand)
IFeatureService featureService,
ISavePolicyCommand savePolicyCommand,
IVNextSavePolicyCommand vNextSavePolicyCommand)
{
_policyRepository = policyRepository;
_policyService = policyService;
_currentContext = currentContext;
_featureService = featureService;
_savePolicyCommand = savePolicyCommand;
_vNextSavePolicyCommand = vNextSavePolicyCommand;
}
/// <summary>
@@ -87,8 +97,17 @@ public class PoliciesController : Controller
[ProducesResponseType((int)HttpStatusCode.NotFound)]
public async Task<IActionResult> Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model)
{
var policyUpdate = model.ToPolicyUpdate(_currentContext.OrganizationId!.Value, type);
var policy = await _savePolicyCommand.SaveAsync(policyUpdate);
Policy policy;
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor))
{
var savePolicyModel = model.ToSavePolicyModel(_currentContext.OrganizationId!.Value, type);
policy = await _vNextSavePolicyCommand.SaveAsync(savePolicyModel);
}
else
{
var policyUpdate = model.ToPolicyUpdate(_currentContext.OrganizationId!.Value, type);
policy = await _savePolicyCommand.SaveAsync(policyUpdate);
}
var response = new PolicyResponseModel(policy);
return new JsonResult(response);

View File

@@ -8,6 +8,8 @@ namespace Bit.Api.AdminConsole.Public.Models.Request;
public class PolicyUpdateRequestModel : PolicyBaseModel
{
public Dictionary<string, object>? Metadata { get; set; }
public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type)
{
var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type);
@@ -21,4 +23,22 @@ public class PolicyUpdateRequestModel : PolicyBaseModel
PerformedBy = new SystemUser(EventSystemUser.PublicApi)
};
}
public SavePolicyModel ToSavePolicyModel(Guid organizationId, PolicyType type)
{
var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type);
var policyUpdate = new PolicyUpdate
{
Type = type,
OrganizationId = organizationId,
Data = serializedData,
Enabled = Enabled.GetValueOrDefault()
};
var performedBy = new SystemUser(EventSystemUser.PublicApi);
var metadata = PolicyDataValidator.ValidateAndDeserializeMetadata(Metadata, type);
return new SavePolicyModel(policyUpdate, performedBy, metadata);
}
}

View File

@@ -2,18 +2,22 @@
#nullable disable
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Entities;
using Bit.Core.Models.Mail.UpdatedInvoiceIncoming;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Stripe;
using static Bit.Core.Billing.Constants.StripeConstants;
using Event = Stripe.Event;
namespace Bit.Billing.Services.Implementations;
@@ -29,7 +33,9 @@ public class UpcomingInvoiceHandler(
IStripeEventService stripeEventService,
IStripeEventUtilityService stripeEventUtilityService,
IUserRepository userRepository,
IValidateSponsorshipCommand validateSponsorshipCommand)
IValidateSponsorshipCommand validateSponsorshipCommand,
IMailer mailer,
IFeatureService featureService)
: IUpcomingInvoiceHandler
{
public async Task HandleAsync(Event parsedEvent)
@@ -37,7 +43,8 @@ public class UpcomingInvoiceHandler(
var invoice = await stripeEventService.GetInvoice(parsedEvent);
var customer =
await stripeFacade.GetCustomer(invoice.CustomerId, new CustomerGetOptions { Expand = ["subscriptions", "tax", "tax_ids"] });
await stripeFacade.GetCustomer(invoice.CustomerId,
new CustomerGetOptions { Expand = ["subscriptions", "tax", "tax_ids"] });
var subscription = customer.Subscriptions.FirstOrDefault();
@@ -68,7 +75,8 @@ public class UpcomingInvoiceHandler(
if (stripeEventUtilityService.IsSponsoredSubscription(subscription))
{
var sponsorshipIsValid = await validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value);
var sponsorshipIsValid =
await validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value);
if (!sponsorshipIsValid)
{
@@ -122,9 +130,17 @@ public class UpcomingInvoiceHandler(
}
}
var milestone2Feature = featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2);
if (milestone2Feature)
{
await UpdateSubscriptionItemPriceIdAsync(parsedEvent, subscription, user);
}
if (user.Premium)
{
await SendUpcomingInvoiceEmailsAsync(new List<string> { user.Email }, invoice);
await (milestone2Feature
? SendUpdatedUpcomingInvoiceEmailsAsync(new List<string> { user.Email })
: SendUpcomingInvoiceEmailsAsync(new List<string> { user.Email }, invoice));
}
}
else if (providerId.HasValue)
@@ -142,6 +158,39 @@ public class UpcomingInvoiceHandler(
}
}
private async Task UpdateSubscriptionItemPriceIdAsync(Event parsedEvent, Subscription subscription, User user)
{
var pricingItem =
subscription.Items.FirstOrDefault(i => i.Price.Id == Prices.PremiumAnnually);
if (pricingItem != null)
{
try
{
var plan = await pricingClient.GetAvailablePremiumPlan();
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
Items =
[
new SubscriptionItemOptions { Id = pricingItem.Id, Price = plan.Seat.StripePriceId }
],
Discounts =
[
new SubscriptionDiscountOptions { Coupon = CouponIDs.Milestone2SubscriptionDiscount }
]
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to update user's ({UserID}) subscription price id while processing event with ID {EventID}",
user.Id,
parsedEvent.Id);
}
}
}
private async Task SendUpcomingInvoiceEmailsAsync(IEnumerable<string> emails, Invoice invoice)
{
var validEmails = emails.Where(e => !string.IsNullOrEmpty(e));
@@ -159,7 +208,19 @@ public class UpcomingInvoiceHandler(
}
}
private async Task SendProviderUpcomingInvoiceEmailsAsync(IEnumerable<string> emails, Invoice invoice, Subscription subscription, Guid providerId)
private async Task SendUpdatedUpcomingInvoiceEmailsAsync(IEnumerable<string> emails)
{
var validEmails = emails.Where(e => !string.IsNullOrEmpty(e));
var updatedUpcomingEmail = new UpdatedInvoiceUpcomingMail
{
ToEmails = validEmails,
View = new UpdatedInvoiceUpcomingView()
};
await mailer.SendEmail(updatedUpcomingEmail);
}
private async Task SendProviderUpcomingInvoiceEmailsAsync(IEnumerable<string> emails, Invoice invoice,
Subscription subscription, Guid providerId)
{
var validEmails = emails.Where(e => !string.IsNullOrEmpty(e));
@@ -205,12 +266,12 @@ public class UpcomingInvoiceHandler(
organization.PlanType.GetProductTier() != ProductTierType.Families &&
customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates;
if (nonUSBusinessUse && customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
if (nonUSBusinessUse && customer.TaxExempt != TaxExempt.Reverse)
{
try
{
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse });
}
catch (Exception exception)
{
@@ -250,12 +311,12 @@ public class UpcomingInvoiceHandler(
string eventId)
{
if (customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates &&
customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
customer.TaxExempt != TaxExempt.Reverse)
{
try
{
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse });
}
catch (Exception exception)
{

View File

@@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -24,7 +25,9 @@ public class VerifyOrganizationDomainCommand(
IEventService eventService,
IGlobalSettings globalSettings,
ICurrentContext currentContext,
IFeatureService featureService,
ISavePolicyCommand savePolicyCommand,
IVNextSavePolicyCommand vNextSavePolicyCommand,
IMailService mailService,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
@@ -131,15 +134,26 @@ public class VerifyOrganizationDomainCommand(
await SendVerifiedDomainUserEmailAsync(domain);
}
private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser) =>
await savePolicyCommand.SaveAsync(
new PolicyUpdate
{
OrganizationId = organizationId,
Type = PolicyType.SingleOrg,
Enabled = true,
PerformedBy = actingUser
});
private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser)
{
var policyUpdate = new PolicyUpdate
{
OrganizationId = organizationId,
Type = PolicyType.SingleOrg,
Enabled = true,
PerformedBy = actingUser
};
if (featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor))
{
var savePolicyModel = new SavePolicyModel(policyUpdate, actingUser);
await vNextSavePolicyCommand.SaveAsync(savePolicyModel);
}
else
{
await savePolicyCommand.SaveAsync(policyUpdate);
}
}
private async Task SendVerifiedDomainUserEmailAsync(OrganizationDomain domain)
{

View File

@@ -5,4 +5,18 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
public record SavePolicyModel(PolicyUpdate PolicyUpdate, IActingUser? PerformedBy, IPolicyMetadataModel Metadata)
{
public SavePolicyModel(PolicyUpdate PolicyUpdate)
: this(PolicyUpdate, null, new EmptyMetadataModel())
{
}
public SavePolicyModel(PolicyUpdate PolicyUpdate, IActingUser performedBy)
: this(PolicyUpdate, performedBy, new EmptyMetadataModel())
{
}
public SavePolicyModel(PolicyUpdate PolicyUpdate, IPolicyMetadataModel metadata)
: this(PolicyUpdate, null, metadata)
{
}
}

View File

@@ -3,9 +3,11 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
@@ -24,7 +26,9 @@ public class SsoConfigService : ISsoConfigService
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IEventService _eventService;
private readonly IFeatureService _featureService;
private readonly ISavePolicyCommand _savePolicyCommand;
private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand;
public SsoConfigService(
ISsoConfigRepository ssoConfigRepository,
@@ -32,14 +36,18 @@ public class SsoConfigService : ISsoConfigService
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IEventService eventService,
ISavePolicyCommand savePolicyCommand)
IFeatureService featureService,
ISavePolicyCommand savePolicyCommand,
IVNextSavePolicyCommand vNextSavePolicyCommand)
{
_ssoConfigRepository = ssoConfigRepository;
_policyRepository = policyRepository;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_eventService = eventService;
_featureService = featureService;
_savePolicyCommand = savePolicyCommand;
_vNextSavePolicyCommand = vNextSavePolicyCommand;
}
public async Task SaveAsync(SsoConfig config, Organization organization)
@@ -67,13 +75,12 @@ public class SsoConfigService : ISsoConfigService
// Automatically enable account recovery, SSO required, and single org policies if trusted device encryption is selected
if (config.GetData().MemberDecryptionType == MemberDecryptionType.TrustedDeviceEncryption)
{
await _savePolicyCommand.SaveAsync(new()
var singleOrgPolicy = new PolicyUpdate
{
OrganizationId = config.OrganizationId,
Type = PolicyType.SingleOrg,
Enabled = true
});
};
var resetPasswordPolicy = new PolicyUpdate
{
@@ -82,14 +89,27 @@ public class SsoConfigService : ISsoConfigService
Enabled = true,
};
resetPasswordPolicy.SetDataModel(new ResetPasswordDataModel { AutoEnrollEnabled = true });
await _savePolicyCommand.SaveAsync(resetPasswordPolicy);
await _savePolicyCommand.SaveAsync(new()
var requireSsoPolicy = new PolicyUpdate
{
OrganizationId = config.OrganizationId,
Type = PolicyType.RequireSso,
Enabled = true
});
};
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor))
{
var performedBy = new SystemUser(EventSystemUser.Unknown);
await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(singleOrgPolicy, performedBy));
await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(resetPasswordPolicy, performedBy));
await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(requireSsoPolicy, performedBy));
}
else
{
await _savePolicyCommand.SaveAsync(singleOrgPolicy);
await _savePolicyCommand.SaveAsync(resetPasswordPolicy);
await _savePolicyCommand.SaveAsync(requireSsoPolicy);
}
}
await LogEventsAsync(config, oldConfig);

View File

@@ -22,6 +22,7 @@ public static class StripeConstants
{
public const string LegacyMSPDiscount = "msp-discount-35";
public const string SecretsManagerStandalone = "sm-standalone";
public const string Milestone2SubscriptionDiscount = "cm3nHfO1";
public static class MSPDiscounts
{

View File

@@ -138,11 +138,16 @@ public static class FeatureFlagKeys
public const string PolicyRequirements = "pm-14439-policy-requirements";
public const string ScimInviteUserOptimization = "pm-16811-optimize-invite-user-flow-to-fail-fast";
public const string EventBasedOrganizationIntegrations = "event-based-organization-integrations";
public const string SeparateCustomRolePermissions = "pm-19917-separate-custom-role-permissions";
public const string CreateDefaultLocation = "pm-19467-create-default-location";
public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users";
public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache";
public const string AccountRecoveryCommand = "pm-25581-prevent-provider-account-recovery";
public const string PolicyValidatorsRefactor = "pm-26423-refactor-policy-side-effects";
/* Architecture */
public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1";
public const string DesktopMigrationMilestone2 = "desktop-ui-migration-milestone-2";
public const string DesktopMigrationMilestone3 = "desktop-ui-migration-milestone-3";
/* Auth Team */
public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence";

View File

@@ -0,0 +1,27 @@
<mjml>
<mj-head>
<mj-include path="../components/head.mjml" />
</mj-head>
<mj-body background-color="#f6f6f6">
<mj-include path="../components/logo.mjml" />
<mj-wrapper
background-color="#fff"
border="1px solid #e9e9e9"
css-class="border-fix"
padding="0"
>
<mj-section>
<mj-column>
<mj-text>
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc.
</mj-text>
</mj-column>
</mj-section>
</mj-wrapper>
<mj-include path="../components/footer.mjml" />
</mj-body>
</mjml>

View File

@@ -0,0 +1,10 @@
using Bit.Core.Platform.Mail.Mailer;
namespace Bit.Core.Models.Mail.UpdatedInvoiceIncoming;
public class UpdatedInvoiceUpcomingView : BaseMailView;
public class UpdatedInvoiceUpcomingMail : BaseMail<UpdatedInvoiceUpcomingView>
{
public override string Subject { get => "Your Subscription Will Renew Soon"; }
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,3 @@
{{#>BasicTextLayout}}
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc.
{{/BasicTextLayout}}

View File

@@ -641,7 +641,7 @@ public class StripePaymentService : IPaymentService
}
var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer", "discounts", "test_clock"] });
new SubscriptionGetOptions { Expand = ["customer.discount.coupon.applies_to", "discounts.coupon.applies_to", "test_clock"] });
subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(subscription);

View File

@@ -35,4 +35,10 @@ public interface ISecurityTaskRepository : IRepository<SecurityTask, Guid>
/// <param name="organizationId">The id of the organization</param>
/// <returns>A collection of security task metrics</returns>
Task<SecurityTaskMetrics> GetTaskMetricsAsync(Guid organizationId);
/// <summary>
/// Marks all tasks associated with the respective ciphers as complete.
/// </summary>
/// <param name="cipherIds">Collection of cipher IDs</param>
Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds);
}

View File

@@ -33,6 +33,7 @@ public class CipherService : ICipherService
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ICollectionCipherRepository _collectionCipherRepository;
private readonly ISecurityTaskRepository _securityTaskRepository;
private readonly IPushNotificationService _pushService;
private readonly IAttachmentStorageService _attachmentStorageService;
private readonly IEventService _eventService;
@@ -53,6 +54,7 @@ public class CipherService : ICipherService
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionCipherRepository collectionCipherRepository,
ISecurityTaskRepository securityTaskRepository,
IPushNotificationService pushService,
IAttachmentStorageService attachmentStorageService,
IEventService eventService,
@@ -71,6 +73,7 @@ public class CipherService : ICipherService
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_collectionCipherRepository = collectionCipherRepository;
_securityTaskRepository = securityTaskRepository;
_pushService = pushService;
_attachmentStorageService = attachmentStorageService;
_eventService = eventService;
@@ -724,6 +727,7 @@ public class CipherService : ICipherService
cipherDetails.ArchivedDate = null;
}
await _securityTaskRepository.MarkAsCompleteByCipherIds([cipherDetails.Id]);
await _cipherRepository.UpsertAsync(cipherDetails);
await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted);
@@ -750,6 +754,8 @@ public class CipherService : ICipherService
await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId);
}
await _securityTaskRepository.MarkAsCompleteByCipherIds(deletingCiphers.Select(c => c.Id));
var events = deletingCiphers.Select(c =>
new Tuple<Cipher, EventType, DateTime?>(c, EventType.Cipher_SoftDeleted, null));
foreach (var eventsBatch in events.Chunk(100))

View File

@@ -85,4 +85,19 @@ public class SecurityTaskRepository : Repository<SecurityTask, Guid>, ISecurityT
return tasksList;
}
/// <inheritdoc />
public async Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds)
{
if (!cipherIds.Any())
{
return;
}
await using var connection = new SqlConnection(ConnectionString);
await connection.ExecuteAsync(
$"[{Schema}].[SecurityTask_MarkCompleteByCipherIds]",
new { CipherIds = cipherIds.ToGuidIdArrayTVP() },
commandType: CommandType.StoredProcedure);
}
}

View File

@@ -96,4 +96,24 @@ public class SecurityTaskRepository : Repository<Core.Vault.Entities.SecurityTas
return metrics ?? new Core.Vault.Entities.SecurityTaskMetrics(0, 0);
}
/// <inheritdoc />
public async Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds)
{
if (!cipherIds.Any())
{
return;
}
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var cipherIdsList = cipherIds.ToList();
await dbContext.SecurityTasks
.Where(st => st.CipherId.HasValue && cipherIdsList.Contains(st.CipherId.Value) && st.Status != SecurityTaskStatus.Completed)
.ExecuteUpdateAsync(st => st
.SetProperty(s => s.Status, SecurityTaskStatus.Completed)
.SetProperty(s => s.RevisionDate, DateTime.UtcNow));
}
}

View File

@@ -0,0 +1,15 @@
CREATE PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds]
@CipherIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
UPDATE
[dbo].[SecurityTask]
SET
[Status] = 1, -- completed
[RevisionDate] = SYSUTCDATETIME()
WHERE
[CipherId] IN (SELECT [Id] FROM @CipherIds)
AND [Status] <> 1 -- Not already completed
END

View File

@@ -0,0 +1,87 @@
using Bit.Api.AdminConsole.Public.Controllers;
using Bit.Api.AdminConsole.Public.Models.Request;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Context;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Public.Controllers;
[ControllerCustomize(typeof(PoliciesController))]
[SutProviderCustomize]
public class PoliciesControllerTests
{
[Theory]
[BitAutoData]
public async Task Put_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand(
Guid organizationId,
PolicyType policyType,
PolicyUpdateRequestModel model,
Policy policy,
SutProvider<PoliciesController> sutProvider)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.OrganizationId.Returns(organizationId);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
sutProvider.GetDependency<IVNextSavePolicyCommand>()
.SaveAsync(Arg.Any<SavePolicyModel>())
.Returns(policy);
// Act
await sutProvider.Sut.Put(policyType, model);
// Assert
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.OrganizationId == organizationId &&
m.PolicyUpdate.Type == policyType &&
m.PolicyUpdate.Enabled == model.Enabled.GetValueOrDefault() &&
m.PerformedBy is SystemUser));
}
[Theory]
[BitAutoData]
public async Task Put_WhenPolicyValidatorsRefactorDisabled_UsesLegacySavePolicyCommand(
Guid organizationId,
PolicyType policyType,
PolicyUpdateRequestModel model,
Policy policy,
SutProvider<PoliciesController> sutProvider)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.OrganizationId.Returns(organizationId);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(false);
sutProvider.GetDependency<ISavePolicyCommand>()
.SaveAsync(Arg.Any<PolicyUpdate>())
.Returns(policy);
// Act
await sutProvider.Sut.Put(policyType, model);
// Assert
await sutProvider.GetDependency<ISavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<PolicyUpdate>(p =>
p.OrganizationId == organizationId &&
p.Type == policyType &&
p.Enabled == model.Enabled));
}
}

View File

@@ -1,10 +1,15 @@
using System.Security.Claims;
using System.Text.Json;
using Bit.Api.AdminConsole.Controllers;
using Bit.Api.AdminConsole.Models.Request;
using Bit.Api.AdminConsole.Models.Response.Organizations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Context;
@@ -455,4 +460,98 @@ public class PoliciesControllerTests
Assert.Equal(enabledPolicy.Type, expectedPolicy.Type);
Assert.Equal(enabledPolicy.Enabled, expectedPolicy.Enabled);
}
[Theory]
[BitAutoData]
public async Task PutVNext_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand(
SutProvider<PoliciesController> sutProvider, Guid orgId,
SavePolicyRequest model, Policy policy, Guid userId)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.UserId
.Returns(userId);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(orgId)
.Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
sutProvider.GetDependency<IVNextSavePolicyCommand>()
.SaveAsync(Arg.Any<SavePolicyModel>())
.Returns(policy);
// Act
var result = await sutProvider.Sut.PutVNext(orgId, model);
// Assert
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(
m => m.PolicyUpdate.OrganizationId == orgId &&
m.PolicyUpdate.Type == model.Policy.Type &&
m.PolicyUpdate.Enabled == model.Policy.Enabled &&
m.PerformedBy.UserId == userId &&
m.PerformedBy.IsOrganizationOwnerOrProvider == true));
await sutProvider.GetDependency<ISavePolicyCommand>()
.DidNotReceiveWithAnyArgs()
.VNextSaveAsync(default);
Assert.NotNull(result);
Assert.Equal(policy.Id, result.Id);
Assert.Equal(policy.Type, result.Type);
}
[Theory]
[BitAutoData]
public async Task PutVNext_WhenPolicyValidatorsRefactorDisabled_UsesSavePolicyCommand(
SutProvider<PoliciesController> sutProvider, Guid orgId,
SavePolicyRequest model, Policy policy, Guid userId)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.UserId
.Returns(userId);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(orgId)
.Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(false);
sutProvider.GetDependency<ISavePolicyCommand>()
.VNextSaveAsync(Arg.Any<SavePolicyModel>())
.Returns(policy);
// Act
var result = await sutProvider.Sut.PutVNext(orgId, model);
// Assert
await sutProvider.GetDependency<ISavePolicyCommand>()
.Received(1)
.VNextSaveAsync(Arg.Is<SavePolicyModel>(
m => m.PolicyUpdate.OrganizationId == orgId &&
m.PolicyUpdate.Type == model.Policy.Type &&
m.PolicyUpdate.Enabled == model.Policy.Enabled &&
m.PerformedBy.UserId == userId &&
m.PerformedBy.IsOrganizationOwnerOrProvider == true));
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.DidNotReceiveWithAnyArgs()
.SaveAsync(default);
Assert.NotNull(result);
Assert.Equal(policy.Id, result.Id);
Assert.Equal(policy.Type, result.Type);
}
}

View File

@@ -0,0 +1,947 @@
using Bit.Billing.Services;
using Bit.Billing.Services.Implementations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models.StaticStore.Plans;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Pricing.Premium;
using Bit.Core.Entities;
using Bit.Core.Models.Mail.UpdatedInvoiceIncoming;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Constants.StripeConstants;
using Address = Stripe.Address;
using Event = Stripe.Event;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
namespace Bit.Billing.Test.Services;
public class UpcomingInvoiceHandlerTests
{
private readonly IGetPaymentMethodQuery _getPaymentMethodQuery;
private readonly ILogger<StripeEventProcessor> _logger;
private readonly IMailService _mailService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IPricingClient _pricingClient;
private readonly IProviderRepository _providerRepository;
private readonly IStripeFacade _stripeFacade;
private readonly IStripeEventService _stripeEventService;
private readonly IStripeEventUtilityService _stripeEventUtilityService;
private readonly IUserRepository _userRepository;
private readonly IValidateSponsorshipCommand _validateSponsorshipCommand;
private readonly IMailer _mailer;
private readonly IFeatureService _featureService;
private readonly UpcomingInvoiceHandler _sut;
private readonly Guid _userId = Guid.NewGuid();
private readonly Guid _organizationId = Guid.NewGuid();
private readonly Guid _providerId = Guid.NewGuid();
public UpcomingInvoiceHandlerTests()
{
_getPaymentMethodQuery = Substitute.For<IGetPaymentMethodQuery>();
_logger = Substitute.For<ILogger<StripeEventProcessor>>();
_mailService = Substitute.For<IMailService>();
_organizationRepository = Substitute.For<IOrganizationRepository>();
_pricingClient = Substitute.For<IPricingClient>();
_providerRepository = Substitute.For<IProviderRepository>();
_stripeFacade = Substitute.For<IStripeFacade>();
_stripeEventService = Substitute.For<IStripeEventService>();
_stripeEventUtilityService = Substitute.For<IStripeEventUtilityService>();
_userRepository = Substitute.For<IUserRepository>();
_validateSponsorshipCommand = Substitute.For<IValidateSponsorshipCommand>();
_mailer = Substitute.For<IMailer>();
_featureService = Substitute.For<IFeatureService>();
_sut = new UpcomingInvoiceHandler(
_getPaymentMethodQuery,
_logger,
_mailService,
_organizationRepository,
_pricingClient,
_providerRepository,
_stripeFacade,
_stripeEventService,
_stripeEventUtilityService,
_userRepository,
_validateSponsorshipCommand,
_mailer,
_featureService);
}
[Fact]
public async Task HandleAsync_WhenNullSubscription_DoesNothing()
{
// Arrange
var parsedEvent = new Event();
var invoice = new Invoice { CustomerId = "cus_123" };
var customer = new Customer { Id = "cus_123", Subscriptions = new StripeList<Subscription> { Data = [] } };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _stripeFacade.DidNotReceive()
.UpdateCustomer(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>());
}
[Fact]
public async Task HandleAsync_WhenValidUser_SendsEmail()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var customerId = "cus_123";
var invoice = new Invoice
{
CustomerId = customerId,
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }
}
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = customerId },
Metadata = new Dictionary<string, string>()
};
var user = new User { Id = _userId, Email = "user@example.com", Premium = true };
var plan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually },
Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal }
};
var customer = new Customer
{
Id = customerId,
Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported },
Subscriptions = new StripeList<Subscription> { Data = [subscription] }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(customerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(null, _userId, null));
_userRepository.GetByIdAsync(_userId).Returns(user);
_pricingClient.GetAvailablePremiumPlan().Returns(plan);
// If milestone 2 is disabled, the default email is sent
_featureService
.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2)
.Returns(false);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _userRepository.Received(1).GetByIdAsync(_userId);
await _mailService.Received(1).SendInvoiceUpcoming(
Arg.Is<IEnumerable<string>>(emails => emails.Contains("user@example.com")),
Arg.Is<decimal>(amount => amount == invoice.AmountDue / 100M),
Arg.Is<DateTime>(dueDate => dueDate == invoice.NextPaymentAttempt.Value),
Arg.Is<List<string>>(items => items.Count == invoice.Lines.Data.Count),
Arg.Is<bool>(b => b == true));
}
[Fact]
public async Task
HandleAsync_WhenUserValid_AndMilestone2Enabled_UpdatesPriceId_AndSendsUpdatedInvoiceUpcomingEmail()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var customerId = "cus_123";
var priceSubscriptionId = "sub-1";
var priceId = "price-id-2";
var invoice = new Invoice
{
CustomerId = customerId,
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }
}
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer
{
Id = customerId,
Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }
},
Metadata = new Dictionary<string, string>()
};
var user = new User { Id = _userId, Email = "user@example.com", Premium = true };
var plan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Purchasable { Price = 10M, StripePriceId = priceId },
Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal }
};
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(customerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(null, _userId, null));
_userRepository.GetByIdAsync(_userId).Returns(user);
_pricingClient.GetAvailablePremiumPlan().Returns(plan);
_stripeFacade.UpdateSubscription(
subscription.Id,
Arg.Any<SubscriptionUpdateOptions>())
.Returns(subscription);
// If milestone 2 is true, the updated invoice email is sent
_featureService
.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2)
.Returns(true);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _userRepository.Received(1).GetByIdAsync(_userId);
await _pricingClient.Received(1).GetAvailablePremiumPlan();
await _stripeFacade.Received(1).UpdateSubscription(
Arg.Is("sub_123"),
Arg.Is<SubscriptionUpdateOptions>(o =>
o.Items[0].Id == priceSubscriptionId &&
o.Items[0].Price == priceId));
// Verify the updated invoice email was sent
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
email.ToEmails.Contains("user@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
}
[Fact]
public async Task HandleAsync_WhenOrganizationHasSponsorship_SendsEmail()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>(),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>(),
LatestInvoiceId = "inv_latest"
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Address = new Address { Country = "US" }
};
var organization = new Organization
{
Id = _organizationId,
BillingEmail = "org@example.com",
PlanType = PlanType.EnterpriseAnnually
};
var plan = new FamiliesPlan();
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
_organizationRepository
.GetByIdAsync(_organizationId)
.Returns(organization);
_pricingClient
.GetPlanOrThrow(organization.PlanType)
.Returns(plan);
_stripeEventUtilityService
.IsSponsoredSubscription(subscription)
.Returns(true);
// Configure that this is a sponsored subscription
_stripeEventUtilityService
.IsSponsoredSubscription(subscription)
.Returns(true);
_validateSponsorshipCommand
.ValidateSponsorshipAsync(_organizationId)
.Returns(true);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _organizationRepository.Received(1).GetByIdAsync(_organizationId);
await _validateSponsorshipCommand.Received(1).ValidateSponsorshipAsync(_organizationId);
await _mailService.Received(1).SendInvoiceUpcoming(
Arg.Is<IEnumerable<string>>(emails => emails.Contains("org@example.com")),
Arg.Is<decimal>(amount => amount == invoice.AmountDue / 100M),
Arg.Is<DateTime>(dueDate => dueDate == invoice.NextPaymentAttempt.Value),
Arg.Is<List<string>>(items => items.Count == invoice.Lines.Data.Count),
Arg.Is<bool>(b => b == true));
}
[Fact]
public async Task
HandleAsync_WhenOrganizationHasSponsorship_ButInvalidSponsorship_RetrievesUpdatedInvoice_SendsEmail()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>
{
Data =
[new SubscriptionItem { Price = new Price { Id = "2021-family-for-enterprise-annually" } }]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>(),
LatestInvoiceId = "inv_latest"
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Address = new Address { Country = "US" }
};
var organization = new Organization
{
Id = _organizationId,
BillingEmail = "org@example.com",
PlanType = PlanType.EnterpriseAnnually
};
var plan = new FamiliesPlan();
var paymentMethod = new Card { Last4 = "4242", Brand = "visa" };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
_organizationRepository
.GetByIdAsync(_organizationId)
.Returns(organization);
_pricingClient
.GetPlanOrThrow(organization.PlanType)
.Returns(plan);
// Configure that this is not a sponsored subscription
_stripeEventUtilityService
.IsSponsoredSubscription(subscription)
.Returns(true);
// Validate sponsorship should return false
_validateSponsorshipCommand
.ValidateSponsorshipAsync(_organizationId)
.Returns(false);
_stripeFacade
.GetInvoice(subscription.LatestInvoiceId)
.Returns(invoice);
_getPaymentMethodQuery.Run(organization).Returns(MaskedPaymentMethod.From(paymentMethod));
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _organizationRepository.Received(1).GetByIdAsync(_organizationId);
_stripeEventUtilityService.Received(1).IsSponsoredSubscription(subscription);
await _validateSponsorshipCommand.Received(1).ValidateSponsorshipAsync(_organizationId);
await _stripeFacade.Received(1).GetInvoice(Arg.Is("inv_latest"));
await _mailService.Received(1).SendInvoiceUpcoming(
Arg.Is<IEnumerable<string>>(emails => emails.Contains("org@example.com")),
Arg.Is<decimal>(amount => amount == invoice.AmountDue / 100M),
Arg.Is<DateTime>(dueDate => dueDate == invoice.NextPaymentAttempt.Value),
Arg.Is<List<string>>(items => items.Count == invoice.Lines.Data.Count),
Arg.Is<bool>(b => b == true));
}
[Fact]
public async Task HandleAsync_WhenValidOrganization_SendsEmail()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>
{
Data =
[new SubscriptionItem { Price = new Price { Id = "enterprise-annually" } }]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>(),
LatestInvoiceId = "inv_latest"
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Address = new Address { Country = "US" }
};
var organization = new Organization
{
Id = _organizationId,
BillingEmail = "org@example.com",
PlanType = PlanType.EnterpriseAnnually
};
var plan = new FamiliesPlan();
var paymentMethod = new Card { Last4 = "4242", Brand = "visa" };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
_organizationRepository
.GetByIdAsync(_organizationId)
.Returns(organization);
_pricingClient
.GetPlanOrThrow(organization.PlanType)
.Returns(plan);
_stripeEventUtilityService
.IsSponsoredSubscription(subscription)
.Returns(false);
_stripeFacade
.GetInvoice(subscription.LatestInvoiceId)
.Returns(invoice);
_getPaymentMethodQuery.Run(organization).Returns(MaskedPaymentMethod.From(paymentMethod));
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _organizationRepository.Received(1).GetByIdAsync(_organizationId);
_stripeEventUtilityService.Received(1).IsSponsoredSubscription(subscription);
// Should not validate sponsorship for non-sponsored subscription
await _validateSponsorshipCommand.DidNotReceive().ValidateSponsorshipAsync(Arg.Any<Guid>());
await _mailService.Received(1).SendInvoiceUpcoming(
Arg.Is<IEnumerable<string>>(emails => emails.Contains("org@example.com")),
Arg.Is<decimal>(amount => amount == invoice.AmountDue / 100M),
Arg.Is<DateTime>(dueDate => dueDate == invoice.NextPaymentAttempt.Value),
Arg.Is<List<string>>(items => items.Count == invoice.Lines.Data.Count),
Arg.Is<bool>(b => b == true));
}
[Fact]
public async Task HandleAsync_WhenValidProviderSubscription_SendsEmail()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>(),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>(),
CollectionMethod = "charge_automatically"
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Address = new Address { Country = "UK" },
TaxExempt = TaxExempt.None
};
var provider = new Provider { Id = _providerId, BillingEmail = "provider@example.com" };
var paymentMethod = new Card { Last4 = "4242", Brand = "visa" };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>()).Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(null, null, _providerId));
_providerRepository.GetByIdAsync(_providerId).Returns(provider);
_getPaymentMethodQuery.Run(provider).Returns(MaskedPaymentMethod.From(paymentMethod));
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _providerRepository.Received(2).GetByIdAsync(_providerId);
// Verify tax exempt was set to reverse for non-US providers
await _stripeFacade.Received(1).UpdateCustomer(
Arg.Is("cus_123"),
Arg.Is<CustomerUpdateOptions>(o => o.TaxExempt == TaxExempt.Reverse));
// Verify automatic tax was enabled
await _stripeFacade.Received(1).UpdateSubscription(
Arg.Is("sub_123"),
Arg.Is<SubscriptionUpdateOptions>(o => o.AutomaticTax.Enabled == true));
// Verify provider invoice email was sent
await _mailService.Received(1).SendProviderInvoiceUpcoming(
Arg.Is<IEnumerable<string>>(e => e.Contains("provider@example.com")),
Arg.Is<decimal>(amount => amount == invoice.AmountDue / 100M),
Arg.Is<DateTime>(dueDate => dueDate == invoice.NextPaymentAttempt.Value),
Arg.Is<List<string>>(items => items.Count == invoice.Lines.Data.Count),
Arg.Is<string>(s => s == subscription.CollectionMethod),
Arg.Is<bool>(b => b == true),
Arg.Is<string>(s => s == $"{paymentMethod.Brand} ending in {paymentMethod.Last4}"));
}
[Fact]
public async Task HandleAsync_WhenUpdateSubscriptionItemPriceIdFails_LogsErrorAndSendsEmail()
{
// Arrange
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var customerId = "cus_123";
var priceSubscriptionId = "sub-1";
var priceId = "price-id-2";
var invoice = new Invoice
{
CustomerId = customerId,
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }
}
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Customer = new Customer
{
Id = customerId,
Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }
},
Metadata = new Dictionary<string, string>()
};
var user = new User { Id = _userId, Email = "user@example.com", Premium = true };
var plan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new Purchasable { Price = 10M, StripePriceId = priceId },
Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal }
};
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>()).Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(null, _userId, null));
_userRepository.GetByIdAsync(_userId).Returns(user);
_featureService
.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2)
.Returns(true);
_pricingClient.GetAvailablePremiumPlan().Returns(plan);
// Setup exception when updating subscription
_stripeFacade
.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.ThrowsAsync(new Exception());
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
_logger.Received(1).Log(
LogLevel.Error,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString()
.Contains(
$"Failed to update user's ({user.Id}) subscription price id while processing event with ID {parsedEvent.Id}")),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
// Verify that email was still sent despite the exception
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
email.ToEmails.Contains("user@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
}
[Fact]
public async Task HandleAsync_WhenOrganizationNotFound_DoesNothing()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>(),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>(),
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
// Organization not found
_organizationRepository.GetByIdAsync(_organizationId).Returns((Organization)null);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _organizationRepository.Received(1).GetByIdAsync(_organizationId);
// Verify no emails were sent
await _mailService.DidNotReceive().SendInvoiceUpcoming(
Arg.Any<IEnumerable<string>>(),
Arg.Any<decimal>(),
Arg.Any<DateTime>(),
Arg.Any<List<string>>(),
Arg.Any<bool>());
}
[Fact]
public async Task HandleAsync_WhenZeroAmountInvoice_DoesNothing()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 0, // Zero amount due
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Free Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>(),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>()
};
var user = new User { Id = _userId, Email = "user@example.com", Premium = true };
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(null, _userId, null));
_userRepository.GetByIdAsync(_userId).Returns(user);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _userRepository.Received(1).GetByIdAsync(_userId);
// Should not
await _mailService.DidNotReceive().SendInvoiceUpcoming(
Arg.Any<IEnumerable<string>>(),
Arg.Any<decimal>(),
Arg.Any<DateTime>(),
Arg.Any<List<string>>(),
Arg.Any<bool>());
}
[Fact]
public async Task HandleAsync_WhenUserNotFound_DoesNothing()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>(),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>()
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(null, _userId, null));
// User not found
_userRepository.GetByIdAsync(_userId).Returns((User)null);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _userRepository.Received(1).GetByIdAsync(_userId);
// Verify no emails were sent
await _mailService.DidNotReceive().SendInvoiceUpcoming(
Arg.Any<IEnumerable<string>>(),
Arg.Any<decimal>(),
Arg.Any<DateTime>(),
Arg.Any<List<string>>(),
Arg.Any<bool>());
await _mailer.DidNotReceive().SendEmail(Arg.Any<UpdatedInvoiceUpcomingMail>());
}
[Fact]
public async Task HandleAsync_WhenProviderNotFound_DoesNothing()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var invoice = new Invoice
{
CustomerId = "cus_123",
AmountDue = 10000,
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
}
};
var subscription = new Subscription
{
Id = "sub_123",
CustomerId = "cus_123",
Items = new StripeList<SubscriptionItem>(),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>()
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade
.GetCustomer(invoice.CustomerId, Arg.Any<CustomerGetOptions>())
.Returns(customer);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(null, null, _providerId));
// Provider not found
_providerRepository.GetByIdAsync(_providerId).Returns((Provider)null);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
await _providerRepository.Received(1).GetByIdAsync(_providerId);
// Verify no provider emails were sent
await _mailService.DidNotReceive().SendProviderInvoiceUpcoming(
Arg.Any<IEnumerable<string>>(),
Arg.Any<decimal>(),
Arg.Any<DateTime>(),
Arg.Any<List<string>>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<string>());
}
}

View File

@@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -191,6 +192,37 @@ public class VerifyOrganizationDomainCommandTests
x.PerformedBy.UserId == userId));
}
[Theory, BitAutoData]
public async Task UserVerifyOrganizationDomainAsync_WhenPolicyValidatorsRefactorFlagEnabled_UsesVNextSavePolicyCommand(
OrganizationDomain domain, Guid userId, SutProvider<VerifyOrganizationDomainCommand> sutProvider)
{
sutProvider.GetDependency<IOrganizationDomainRepository>()
.GetClaimedDomainsByDomainNameAsync(domain.DomainName)
.Returns([]);
sutProvider.GetDependency<IDnsResolverService>()
.ResolveAsync(domain.DomainName, domain.Txt)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>()
.UserId.Returns(userId);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
_ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain);
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.SingleOrg &&
m.PolicyUpdate.OrganizationId == domain.OrganizationId &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is StandardUser &&
m.PerformedBy.UserId == userId));
}
[Theory, BitAutoData]
public async Task UserVerifyOrganizationDomainAsync_WhenDomainIsNotVerified_ThenSingleOrgPolicyShouldNotBeEnabled(
OrganizationDomain domain, SutProvider<VerifyOrganizationDomainCommand> sutProvider)

View File

@@ -92,7 +92,7 @@ public class FreeFamiliesForEnterprisePolicyValidatorTests
.GetManyBySponsoringOrganizationAsync(policyUpdate.OrganizationId)
.Returns(organizationSponsorships);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);
@@ -120,7 +120,7 @@ public class FreeFamiliesForEnterprisePolicyValidatorTests
.GetManyBySponsoringOrganizationAsync(policyUpdate.OrganizationId)
.Returns(organizationSponsorships);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);

View File

@@ -32,7 +32,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(false);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -58,7 +58,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -84,7 +84,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -110,7 +110,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -199,7 +199,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -238,7 +238,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, metadata);
var policyRequest = new SavePolicyModel(policyUpdate, metadata);
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -286,7 +286,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(false);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -312,7 +312,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -338,7 +338,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -364,7 +364,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -404,7 +404,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -436,7 +436,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, metadata);
var policyRequest = new SavePolicyModel(policyUpdate, metadata);
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);

View File

@@ -88,7 +88,7 @@ public class RequireSsoPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase);
@@ -109,7 +109,7 @@ public class RequireSsoPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Trusted device encryption is on", result, StringComparison.OrdinalIgnoreCase);
@@ -129,7 +129,7 @@ public class RequireSsoPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.True(string.IsNullOrEmpty(result));

View File

@@ -94,7 +94,7 @@ public class ResetPasswordPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Trusted device encryption is on and requires this policy.", result, StringComparison.OrdinalIgnoreCase);
@@ -118,7 +118,7 @@ public class ResetPasswordPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.True(string.IsNullOrEmpty(result));

View File

@@ -162,7 +162,7 @@ public class SingleOrgPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase);
@@ -186,7 +186,7 @@ public class SingleOrgPolicyValidatorTests
.HasVerifiedDomainsAsync(policyUpdate.OrganizationId)
.Returns(false);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.True(string.IsNullOrEmpty(result));
@@ -256,7 +256,7 @@ public class SingleOrgPolicyValidatorTests
.RevokeNonCompliantOrganizationUsersAsync(Arg.Any<RevokeOrganizationUsersRequest>())
.Returns(new CommandResult());
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);

View File

@@ -169,7 +169,7 @@ public class TwoFactorAuthenticationPolicyValidatorTests
(orgUserDetailUserWithout2Fa, false),
});
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy));
@@ -228,7 +228,7 @@ public class TwoFactorAuthenticationPolicyValidatorTests
.RevokeNonCompliantOrganizationUsersAsync(Arg.Any<RevokeOrganizationUsersRequest>())
.Returns(new CommandResult());
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
// Act
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);

View File

@@ -288,7 +288,7 @@ public class SavePolicyCommandTests
{
// Arrange
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
@@ -332,7 +332,7 @@ public class SavePolicyCommandTests
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type)

View File

@@ -33,7 +33,7 @@ public class VNextSavePolicyCommandTests
fakePolicyValidationEvent
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var newPolicy = new Policy
{
@@ -77,7 +77,7 @@ public class VNextSavePolicyCommandTests
fakePolicyValidationEvent
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
@@ -117,7 +117,7 @@ public class VNextSavePolicyCommandTests
{
// Arrange
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IApplicationCacheService>()
.GetOrganizationAbilityAsync(policyUpdate.OrganizationId)
@@ -137,7 +137,7 @@ public class VNextSavePolicyCommandTests
{
// Arrange
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IApplicationCacheService>()
.GetOrganizationAbilityAsync(policyUpdate.OrganizationId)
@@ -167,7 +167,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var requireSsoPolicy = new Policy
{
@@ -202,7 +202,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var requireSsoPolicy = new Policy
{
@@ -237,7 +237,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var requireSsoPolicy = new Policy
{
@@ -271,7 +271,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
@@ -302,7 +302,7 @@ public class VNextSavePolicyCommandTests
new FakeVaultTimeoutDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
@@ -331,7 +331,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
@@ -356,7 +356,7 @@ public class VNextSavePolicyCommandTests
fakePolicyValidationEvent
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var singleOrgPolicy = new Policy
{

View File

@@ -1,8 +1,10 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
@@ -12,6 +14,7 @@ using Bit.Core.Auth.Services;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
@@ -364,4 +367,54 @@ public class SsoConfigServiceTests
await sutProvider.GetDependency<ISsoConfigRepository>().ReceivedWithAnyArgs()
.UpsertAsync(default);
}
[Theory, BitAutoData]
public async Task SaveAsync_Tde_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand(
SutProvider<SsoConfigService> sutProvider, Organization organization)
{
var ssoConfig = new SsoConfig
{
Id = default,
Data = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption,
}.Serialize(),
Enabled = true,
OrganizationId = organization.Id,
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
await sutProvider.Sut.SaveAsync(ssoConfig, organization);
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.SingleOrg &&
m.PolicyUpdate.OrganizationId == organization.Id &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is SystemUser));
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.ResetPassword &&
m.PolicyUpdate.GetDataModel<ResetPasswordDataModel>().AutoEnrollEnabled &&
m.PolicyUpdate.OrganizationId == organization.Id &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is SystemUser));
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.RequireSso &&
m.PolicyUpdate.OrganizationId == organization.Id &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is SystemUser));
await sutProvider.GetDependency<ISsoConfigRepository>().ReceivedWithAnyArgs()
.UpsertAsync(default);
}
}

View File

@@ -2286,6 +2286,63 @@ public class CipherServiceTests
.PushSyncCiphersAsync(deletingUserId);
}
[Theory]
[BitAutoData]
public async Task SoftDeleteAsync_CallsMarkAsCompleteByCipherIds(
Guid deletingUserId, CipherDetails cipherDetails, SutProvider<CipherService> sutProvider)
{
cipherDetails.UserId = deletingUserId;
cipherDetails.OrganizationId = null;
cipherDetails.DeletedDate = null;
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(deletingUserId)
.Returns(new User
{
Id = deletingUserId,
});
await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId);
await sutProvider.GetDependency<ISecurityTaskRepository>()
.Received(1)
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 1 && ids.First() == cipherDetails.Id));
}
[Theory]
[BitAutoData]
public async Task SoftDeleteManyAsync_CallsMarkAsCompleteByCipherIds(
Guid deletingUserId, List<CipherDetails> ciphers, SutProvider<CipherService> sutProvider)
{
var cipherIds = ciphers.Select(c => c.Id).ToArray();
foreach (var cipher in ciphers)
{
cipher.UserId = deletingUserId;
cipher.OrganizationId = null;
cipher.Edit = true;
cipher.DeletedDate = null;
}
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(deletingUserId)
.Returns(new User
{
Id = deletingUserId,
});
sutProvider.GetDependency<ICipherRepository>()
.GetManyByUserIdAsync(deletingUserId)
.Returns(ciphers);
await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, null, false);
await sutProvider.GetDependency<ISecurityTaskRepository>()
.Received(1)
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id))));
}
private async Task AssertNoActionsAsync(SutProvider<CipherService> sutProvider)
{
await sutProvider.GetDependency<ICipherRepository>().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default);

View File

@@ -345,4 +345,110 @@ public class SecurityTaskRepositoryTests
Assert.Equal(0, metrics.CompletedTasks);
Assert.Equal(0, metrics.TotalTasks);
}
[DatabaseTheory, DatabaseData]
public async Task MarkAsCompleteByCipherIds_MarksPendingTasksAsCompleted(
IOrganizationRepository organizationRepository,
ICipherRepository cipherRepository,
ISecurityTaskRepository securityTaskRepository)
{
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
PlanType = PlanType.EnterpriseAnnually,
Plan = "Test Plan",
BillingEmail = "billing@email.com"
});
var cipher1 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var cipher2 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var task1 = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher1.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
var task2 = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher2.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id, cipher2.Id]);
var updatedTask1 = await securityTaskRepository.GetByIdAsync(task1.Id);
var updatedTask2 = await securityTaskRepository.GetByIdAsync(task2.Id);
Assert.Equal(SecurityTaskStatus.Completed, updatedTask1.Status);
Assert.Equal(SecurityTaskStatus.Completed, updatedTask2.Status);
}
[DatabaseTheory, DatabaseData]
public async Task MarkAsCompleteByCipherIds_OnlyUpdatesSpecifiedCiphers(
IOrganizationRepository organizationRepository,
ICipherRepository cipherRepository,
ISecurityTaskRepository securityTaskRepository)
{
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
PlanType = PlanType.EnterpriseAnnually,
Plan = "Test Plan",
BillingEmail = "billing@email.com"
});
var cipher1 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var cipher2 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var taskToUpdate = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher1.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
var taskToKeep = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher2.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id]);
var updatedTask = await securityTaskRepository.GetByIdAsync(taskToUpdate.Id);
var unchangedTask = await securityTaskRepository.GetByIdAsync(taskToKeep.Id);
Assert.Equal(SecurityTaskStatus.Completed, updatedTask.Status);
Assert.Equal(SecurityTaskStatus.Pending, unchangedTask.Status);
}
}

View File

@@ -0,0 +1,15 @@
CREATE OR ALTER PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds]
@CipherIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
UPDATE
[dbo].[SecurityTask]
SET
[Status] = 1, -- Completed
[RevisionDate] = SYSUTCDATETIME()
WHERE
[CipherId] IN (SELECT [Id] FROM @CipherIds)
AND [Status] <> 1 -- Not already completed
END