1
0
mirror of https://github.com/bitwarden/server synced 2025-12-24 12:13:17 +00:00

Merge branch 'main' into km/move-models

This commit is contained in:
Bernd Schoolmann
2025-12-02 15:48:00 +01:00
committed by GitHub
106 changed files with 14956 additions and 394 deletions

View File

@@ -48,6 +48,7 @@ public class ProfileOrganizationResponseModelTests
UsersGetPremium = organization.UsersGetPremium,
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UsePhishingBlocker = organization.UsePhishingBlocker,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,

View File

@@ -45,6 +45,7 @@ public class ProfileProviderOrganizationResponseModelTests
UsersGetPremium = organization.UsersGetPremium,
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UsePhishingBlocker = organization.UsePhishingBlocker,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,

View File

@@ -0,0 +1,292 @@
using System.Net;
using System.Reflection;
using Bit.Api.Dirt.Controllers;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Mvc;
using NSubstitute;
using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
namespace Bit.Api.Test.Dirt;
[ControllerCustomize(typeof(HibpController))]
[SutProviderCustomize]
public class HibpControllerTests : IDisposable
{
private readonly HttpClient _originalHttpClient;
private readonly FieldInfo _httpClientField;
public HibpControllerTests()
{
// Store original HttpClient for restoration
_httpClientField = typeof(HibpController).GetField("_httpClient", BindingFlags.Static | BindingFlags.NonPublic);
_originalHttpClient = (HttpClient)_httpClientField?.GetValue(null);
}
public void Dispose()
{
// Restore original HttpClient after tests
_httpClientField?.SetValue(null, _originalHttpClient);
}
[Theory, BitAutoData]
public async Task Get_WithMissingApiKey_ThrowsBadRequestException(
SutProvider<HibpController> sutProvider,
string username)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = null;
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.Get(username));
Assert.Equal("HaveIBeenPwned API key not set.", exception.Message);
}
[Theory, BitAutoData]
public async Task Get_WithValidApiKeyAndNoBreaches_Returns200WithEmptyArray(
SutProvider<HibpController> sutProvider,
string username,
Guid userId)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = "test-api-key";
var user = new User { Id = userId };
sutProvider.GetDependency<IUserService>()
.GetProperUserId(Arg.Any<System.Security.Claims.ClaimsPrincipal>())
.Returns(userId);
// Mock HttpClient to return 404 (no breaches found)
var mockHttpClient = CreateMockHttpClient(HttpStatusCode.NotFound, "");
_httpClientField.SetValue(null, mockHttpClient);
// Act
var result = await sutProvider.Sut.Get(username);
// Assert
var contentResult = Assert.IsType<ContentResult>(result);
Assert.Equal("[]", contentResult.Content);
Assert.Equal("application/json", contentResult.ContentType);
}
[Theory, BitAutoData]
public async Task Get_WithValidApiKeyAndBreachesFound_Returns200WithBreachData(
SutProvider<HibpController> sutProvider,
string username,
Guid userId)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = "test-api-key";
sutProvider.GetDependency<IUserService>()
.GetProperUserId(Arg.Any<System.Security.Claims.ClaimsPrincipal>())
.Returns(userId);
var breachData = "[{\"Name\":\"Adobe\",\"Title\":\"Adobe\",\"Domain\":\"adobe.com\"}]";
var mockHttpClient = CreateMockHttpClient(HttpStatusCode.OK, breachData);
_httpClientField.SetValue(null, mockHttpClient);
// Act
var result = await sutProvider.Sut.Get(username);
// Assert
var contentResult = Assert.IsType<ContentResult>(result);
Assert.Equal(breachData, contentResult.Content);
Assert.Equal("application/json", contentResult.ContentType);
}
[Theory, BitAutoData]
public async Task Get_WithRateLimiting_RetriesWithDelay(
SutProvider<HibpController> sutProvider,
string username,
Guid userId)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = "test-api-key";
sutProvider.GetDependency<IUserService>()
.GetProperUserId(Arg.Any<System.Security.Claims.ClaimsPrincipal>())
.Returns(userId);
// First response is rate limited, second is success
var requestCount = 0;
var mockHandler = new MockHttpMessageHandler((request, cancellationToken) =>
{
requestCount++;
if (requestCount == 1)
{
var response = new HttpResponseMessage(HttpStatusCode.TooManyRequests);
response.Headers.Add("retry-after", "1");
return Task.FromResult(response);
}
else
{
return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound)
{
Content = new StringContent("")
});
}
});
var mockHttpClient = new HttpClient(mockHandler);
_httpClientField.SetValue(null, mockHttpClient);
// Act
var result = await sutProvider.Sut.Get(username);
// Assert
Assert.Equal(2, requestCount); // Verify retry happened
var contentResult = Assert.IsType<ContentResult>(result);
Assert.Equal("[]", contentResult.Content);
}
[Theory, BitAutoData]
public async Task Get_WithServerError_ThrowsBadRequestException(
SutProvider<HibpController> sutProvider,
string username,
Guid userId)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = "test-api-key";
sutProvider.GetDependency<IUserService>()
.GetProperUserId(Arg.Any<System.Security.Claims.ClaimsPrincipal>())
.Returns(userId);
var mockHttpClient = CreateMockHttpClient(HttpStatusCode.InternalServerError, "");
_httpClientField.SetValue(null, mockHttpClient);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.Get(username));
Assert.Contains("Request failed. Status code:", exception.Message);
}
[Theory, BitAutoData]
public async Task Get_WithBadRequest_ThrowsBadRequestException(
SutProvider<HibpController> sutProvider,
string username,
Guid userId)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = "test-api-key";
sutProvider.GetDependency<IUserService>()
.GetProperUserId(Arg.Any<System.Security.Claims.ClaimsPrincipal>())
.Returns(userId);
var mockHttpClient = CreateMockHttpClient(HttpStatusCode.BadRequest, "");
_httpClientField.SetValue(null, mockHttpClient);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.Get(username));
Assert.Contains("Request failed. Status code:", exception.Message);
}
[Theory, BitAutoData]
public async Task Get_EncodesUsernameCorrectly(
SutProvider<HibpController> sutProvider,
Guid userId)
{
// Arrange
var usernameWithSpecialChars = "test+user@example.com";
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = "test-api-key";
sutProvider.GetDependency<IUserService>()
.GetProperUserId(Arg.Any<System.Security.Claims.ClaimsPrincipal>())
.Returns(userId);
string capturedUrl = null;
var mockHandler = new MockHttpMessageHandler((request, cancellationToken) =>
{
capturedUrl = request.RequestUri.ToString();
return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound)
{
Content = new StringContent("")
});
});
var mockHttpClient = new HttpClient(mockHandler);
_httpClientField.SetValue(null, mockHttpClient);
// Act
await sutProvider.Sut.Get(usernameWithSpecialChars);
// Assert
Assert.NotNull(capturedUrl);
// Username should be URL encoded (+ becomes %2B, @ becomes %40)
Assert.Contains("test%2Buser%40example.com", capturedUrl);
}
[Theory, BitAutoData]
public async Task SendAsync_IncludesRequiredHeaders(
SutProvider<HibpController> sutProvider,
string username,
Guid userId)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().HibpApiKey = "test-api-key";
sutProvider.GetDependency<GlobalSettings>().SelfHosted = false;
sutProvider.GetDependency<IUserService>()
.GetProperUserId(Arg.Any<System.Security.Claims.ClaimsPrincipal>())
.Returns(userId);
HttpRequestMessage capturedRequest = null;
var mockHandler = new MockHttpMessageHandler((request, cancellationToken) =>
{
capturedRequest = request;
return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound)
{
Content = new StringContent("")
});
});
var mockHttpClient = new HttpClient(mockHandler);
_httpClientField.SetValue(null, mockHttpClient);
// Act
await sutProvider.Sut.Get(username);
// Assert
Assert.NotNull(capturedRequest);
Assert.True(capturedRequest.Headers.Contains("hibp-api-key"));
Assert.True(capturedRequest.Headers.Contains("hibp-client-id"));
Assert.True(capturedRequest.Headers.Contains("User-Agent"));
Assert.Equal("Bitwarden", capturedRequest.Headers.GetValues("User-Agent").First());
}
/// <summary>
/// Helper to create a mock HttpClient that returns a specific status code and content
/// </summary>
private HttpClient CreateMockHttpClient(HttpStatusCode statusCode, string content)
{
var mockHandler = new MockHttpMessageHandler((request, cancellationToken) =>
{
return Task.FromResult(new HttpResponseMessage(statusCode)
{
Content = new StringContent(content)
});
});
return new HttpClient(mockHandler);
}
}
/// <summary>
/// Mock HttpMessageHandler for testing HttpClient behavior
/// </summary>
public class MockHttpMessageHandler : HttpMessageHandler
{
private readonly Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> _sendAsync;
public MockHttpMessageHandler(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> sendAsync)
{
_sendAsync = sendAsync;
}
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
return _sendAsync(request, cancellationToken);
}
}

View File

@@ -1,4 +1,5 @@
using Bit.Billing.Services;
using System.Globalization;
using Bit.Billing.Services;
using Bit.Billing.Services.Implementations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
@@ -10,7 +11,8 @@ using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Pricing.Premium;
using Bit.Core.Entities;
using Bit.Core.Models.Mail.UpdatedInvoiceIncoming;
using Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal;
using Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Repositories;
@@ -117,7 +119,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -126,10 +128,7 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }
}
Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = customerId },
@@ -199,7 +198,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -208,10 +207,7 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }
}
Data = [new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer
@@ -233,7 +229,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
Subscriptions = new StripeList<Subscription> { Data = [subscription] }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
@@ -272,11 +268,12 @@ public class UpcomingInvoiceHandlerTests
o.Discounts[0].Coupon == CouponIDs.Milestone2SubscriptionDiscount &&
o.ProrationBehavior == "none"));
// Verify the updated invoice email was sent
// Verify the updated invoice email was sent with correct price
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
Arg.Is<Families2020RenewalMail>(email =>
email.ToEmails.Contains("user@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
email.Subject == "Your Bitwarden Families renewal is updating" &&
email.View.MonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US"))));
}
[Fact]
@@ -291,7 +288,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -307,7 +304,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
var organization = new Organization
@@ -375,7 +372,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -395,7 +392,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
var organization = new Organization
@@ -469,7 +466,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -489,7 +486,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
var organization = new Organization
@@ -560,7 +557,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -576,7 +573,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "UK" },
TaxExempt = TaxExempt.None
};
@@ -622,9 +619,8 @@ public class UpcomingInvoiceHandlerTests
}
[Fact]
public async Task HandleAsync_WhenUpdateSubscriptionItemPriceIdFails_LogsErrorAndSendsEmail()
public async Task HandleAsync_WhenUpdateSubscriptionItemPriceIdFails_LogsErrorAndSendsTraditionalEmail()
{
// Arrange
// Arrange
var parsedEvent = new Event { Id = "evt_123" };
var customerId = "cus_123";
@@ -637,7 +633,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -646,10 +642,7 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }
}
Data = [new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Customer = new Customer
@@ -671,7 +664,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
Subscriptions = new StripeList<Subscription> { Data = [subscription] }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
@@ -708,11 +701,16 @@ public class UpcomingInvoiceHandlerTests
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
// Verify that email was still sent despite the exception
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
email.ToEmails.Contains("user@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
// Verify that traditional email was sent when update fails
await _mailService.Received(1).SendInvoiceUpcoming(
Arg.Is<IEnumerable<string>>(emails => emails.Contains("user@example.com")),
Arg.Is<decimal>(amount => amount == invoice.AmountDue / 100M),
Arg.Is<DateTime>(dueDate => dueDate == invoice.NextPaymentAttempt.Value),
Arg.Is<List<string>>(items => items.Count == invoice.Lines.Data.Count),
Arg.Is<bool>(b => b == true));
// Verify renewal email was NOT sent
await _mailer.DidNotReceive().SendEmail(Arg.Any<Families2020RenewalMail>());
}
[Fact]
@@ -727,7 +725,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -737,12 +735,12 @@ public class UpcomingInvoiceHandlerTests
Items = new StripeList<SubscriptionItem>(),
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false },
Customer = new Customer { Id = "cus_123" },
Metadata = new Dictionary<string, string>(),
Metadata = new Dictionary<string, string>()
};
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
Subscriptions = new StripeList<Subscription> { Data = [subscription] }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
@@ -784,7 +782,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Free Item" } }
Data = [new() { Description = "Free Item" }]
}
};
var subscription = new Subscription
@@ -800,7 +798,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
Subscriptions = new StripeList<Subscription> { Data = [subscription] }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
@@ -841,7 +839,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -856,7 +854,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
Subscriptions = new StripeList<Subscription> { Data = [subscription] }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
@@ -885,7 +883,7 @@ public class UpcomingInvoiceHandlerTests
Arg.Any<List<string>>(),
Arg.Any<bool>());
await _mailer.DidNotReceive().SendEmail(Arg.Any<UpdatedInvoiceUpcomingMail>());
await _mailer.DidNotReceive().SendEmail(Arg.Any<Families2020RenewalMail>());
}
[Fact]
@@ -900,7 +898,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
var subscription = new Subscription
@@ -915,7 +913,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = "cus_123",
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } }
Subscriptions = new StripeList<Subscription> { Data = [subscription] }
};
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
@@ -964,7 +962,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -977,8 +975,8 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
@@ -989,7 +987,7 @@ public class UpcomingInvoiceHandlerTests
Id = premiumAccessItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePremiumAccessPlanId }
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -998,7 +996,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1009,8 +1007,11 @@ public class UpcomingInvoiceHandlerTests
PlanType = PlanType.FamiliesAnnually2019
};
var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade.GetCustomer(customerId, Arg.Any<CustomerGetOptions>()).Returns(customer);
_stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
@@ -1036,6 +1037,8 @@ public class UpcomingInvoiceHandlerTests
o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount &&
o.ProrationBehavior == ProrationBehavior.None));
await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount);
await _organizationRepository.Received(1).ReplaceAsync(
Arg.Is<Organization>(org =>
org.Id == _organizationId &&
@@ -1045,9 +1048,13 @@ public class UpcomingInvoiceHandlerTests
org.Seats == familiesPlan.PasswordManager.BaseSeats));
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
Arg.Is<Families2019RenewalMail>(email =>
email.ToEmails.Contains("org@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
email.Subject == "Your Bitwarden Families renewal is updating" &&
email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) &&
email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountAmount == $"{coupon.PercentOff}%"
));
}
[Fact]
@@ -1066,7 +1073,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1079,14 +1086,14 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId }
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1095,7 +1102,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1156,7 +1163,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1168,14 +1175,14 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId }
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1184,7 +1191,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1232,7 +1239,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1244,14 +1251,10 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new()
{
Id = "si_pm_123",
Price = new Price { Id = familiesPlan.PasswordManager.StripePlanId }
}
}
Data =
[
new() { Id = "si_pm_123", Price = new Price { Id = familiesPlan.PasswordManager.StripePlanId } }
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1260,7 +1263,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1307,7 +1310,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1319,14 +1322,10 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new()
{
Id = "si_different_item",
Price = new Price { Id = "different-price-id" }
}
}
Data =
[
new() { Id = "si_different_item", Price = new Price { Id = "different-price-id" } }
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1335,7 +1334,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1378,7 +1377,7 @@ public class UpcomingInvoiceHandlerTests
}
[Fact]
public async Task HandleAsync_WhenMilestone3Enabled_AndUpdateFails_LogsError()
public async Task HandleAsync_WhenMilestone3Enabled_AndUpdateFails_LogsErrorAndSendsTraditionalEmail()
{
// Arrange
var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" };
@@ -1393,7 +1392,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1406,14 +1405,14 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId }
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1422,7 +1421,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1463,11 +1462,16 @@ public class UpcomingInvoiceHandlerTests
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
// Should still attempt to send email despite the failure
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
email.ToEmails.Contains("org@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
// Should send traditional email when update fails
await _mailService.Received(1).SendInvoiceUpcoming(
Arg.Is<IEnumerable<string>>(emails => emails.Contains("org@example.com")),
Arg.Is<decimal>(amount => amount == invoice.AmountDue / 100M),
Arg.Is<DateTime>(dueDate => dueDate == invoice.NextPaymentAttempt.Value),
Arg.Is<List<string>>(items => items.Count == invoice.Lines.Data.Count),
Arg.Is<bool>(b => b == true));
// Verify renewal email was NOT sent
await _mailer.DidNotReceive().SendEmail(Arg.Any<Families2020RenewalMail>());
}
[Fact]
@@ -1487,7 +1491,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1500,20 +1504,21 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId }
},
new()
{
Id = seatAddOnItemId,
Price = new Price { Id = "personal-org-seat-annually" },
Quantity = 3
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1522,7 +1527,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1533,8 +1538,11 @@ public class UpcomingInvoiceHandlerTests
PlanType = PlanType.FamiliesAnnually2019
};
var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade.GetCustomer(customerId, Arg.Any<CustomerGetOptions>()).Returns(customer);
_stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
@@ -1560,6 +1568,8 @@ public class UpcomingInvoiceHandlerTests
o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount &&
o.ProrationBehavior == ProrationBehavior.None));
await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount);
await _organizationRepository.Received(1).ReplaceAsync(
Arg.Is<Organization>(org =>
org.Id == _organizationId &&
@@ -1569,9 +1579,13 @@ public class UpcomingInvoiceHandlerTests
org.Seats == familiesPlan.PasswordManager.BaseSeats));
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
Arg.Is<Families2019RenewalMail>(email =>
email.ToEmails.Contains("org@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
email.Subject == "Your Bitwarden Families renewal is updating" &&
email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) &&
email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountAmount == $"{coupon.PercentOff}%"
));
}
[Fact]
@@ -1591,7 +1605,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1604,20 +1618,21 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId }
},
new()
{
Id = seatAddOnItemId,
Price = new Price { Id = "personal-org-seat-annually" },
Quantity = 1
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1626,7 +1641,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1637,8 +1652,11 @@ public class UpcomingInvoiceHandlerTests
PlanType = PlanType.FamiliesAnnually2019
};
var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade.GetCustomer(customerId, Arg.Any<CustomerGetOptions>()).Returns(customer);
_stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
@@ -1664,6 +1682,8 @@ public class UpcomingInvoiceHandlerTests
o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount &&
o.ProrationBehavior == ProrationBehavior.None));
await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount);
await _organizationRepository.Received(1).ReplaceAsync(
Arg.Is<Organization>(org =>
org.Id == _organizationId &&
@@ -1673,9 +1693,13 @@ public class UpcomingInvoiceHandlerTests
org.Seats == familiesPlan.PasswordManager.BaseSeats));
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
Arg.Is<Families2019RenewalMail>(email =>
email.ToEmails.Contains("org@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
email.Subject == "Your Bitwarden Families renewal is updating" &&
email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) &&
email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountAmount == $"{coupon.PercentOff}%"
));
}
[Fact]
@@ -1696,7 +1720,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1709,25 +1733,27 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId }
},
new()
{
Id = premiumAccessItemId,
Price = new Price { Id = families2019Plan.PasswordManager.StripePremiumAccessPlanId }
},
new()
{
Id = seatAddOnItemId,
Price = new Price { Id = "personal-org-seat-annually" },
Quantity = 2
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1736,7 +1762,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1747,8 +1773,11 @@ public class UpcomingInvoiceHandlerTests
PlanType = PlanType.FamiliesAnnually2019
};
var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount };
_stripeEventService.GetInvoice(parsedEvent).Returns(invoice);
_stripeFacade.GetCustomer(customerId, Arg.Any<CustomerGetOptions>()).Returns(customer);
_stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon);
_stripeEventUtilityService
.GetIdsFromMetadata(subscription.Metadata)
.Returns(new Tuple<Guid?, Guid?, Guid?>(_organizationId, null, null));
@@ -1776,6 +1805,8 @@ public class UpcomingInvoiceHandlerTests
o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount &&
o.ProrationBehavior == ProrationBehavior.None));
await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount);
await _organizationRepository.Received(1).ReplaceAsync(
Arg.Is<Organization>(org =>
org.Id == _organizationId &&
@@ -1785,9 +1816,13 @@ public class UpcomingInvoiceHandlerTests
org.Seats == familiesPlan.PasswordManager.BaseSeats));
await _mailer.Received(1).SendEmail(
Arg.Is<UpdatedInvoiceUpcomingMail>(email =>
Arg.Is<Families2019RenewalMail>(email =>
email.ToEmails.Contains("org@example.com") &&
email.Subject == "Your Subscription Will Renew Soon"));
email.Subject == "Your Bitwarden Families renewal is updating" &&
email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) &&
email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) &&
email.View.DiscountAmount == $"{coupon.PercentOff}%"
));
}
[Fact]
@@ -1806,7 +1841,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1819,14 +1854,14 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2025Plan.PasswordManager.StripePlanId }
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1835,7 +1870,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};
@@ -1877,6 +1912,12 @@ public class UpcomingInvoiceHandlerTests
org.Plan == familiesPlan.Name &&
org.UsersGetPremium == familiesPlan.UsersGetPremium &&
org.Seats == familiesPlan.PasswordManager.BaseSeats));
await _mailer.Received(1).SendEmail(
Arg.Is<Families2020RenewalMail>(email =>
email.ToEmails.Contains("org@example.com") &&
email.Subject == "Your Bitwarden Families renewal is updating" &&
email.View.MonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US"))));
}
[Fact]
@@ -1895,7 +1936,7 @@ public class UpcomingInvoiceHandlerTests
NextPaymentAttempt = DateTime.UtcNow.AddDays(7),
Lines = new StripeList<InvoiceLineItem>
{
Data = new List<InvoiceLineItem> { new() { Description = "Test Item" } }
Data = [new() { Description = "Test Item" }]
}
};
@@ -1907,14 +1948,14 @@ public class UpcomingInvoiceHandlerTests
CustomerId = customerId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
Data =
[
new()
{
Id = passwordManagerItemId,
Price = new Price { Id = families2025Plan.PasswordManager.StripePlanId }
}
}
]
},
AutomaticTax = new SubscriptionAutomaticTax { Enabled = true },
Metadata = new Dictionary<string, string>()
@@ -1923,7 +1964,7 @@ public class UpcomingInvoiceHandlerTests
var customer = new Customer
{
Id = customerId,
Subscriptions = new StripeList<Subscription> { Data = new List<Subscription> { subscription } },
Subscriptions = new StripeList<Subscription> { Data = [subscription] },
Address = new Address { Country = "US" }
};

View File

@@ -11,11 +11,11 @@
<ItemGroup>
<PackageReference Include="coverlet.collector" Version="6.0.4" />
<PackageReference Include="MartinCostello.Logging.XUnit" Version="0.5.1" />
<PackageReference Include="MartinCostello.Logging.XUnit" Version="0.7.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
<PackageReference Include="Rnwood.SmtpServer" Version="3.1.0-ci0868" />
<PackageReference Include="xunit" Version="2.9.3" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="3.1.5" />
</ItemGroup>
<ItemGroup>

View File

@@ -6,8 +6,11 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
using Bit.Core.Test.AdminConsole.AutoFixture;
using Bit.Test.Common.AutoFixture;
@@ -95,7 +98,8 @@ public class SavePolicyCommandTests
Substitute.For<IPolicyRepository>(),
[new FakeSingleOrgPolicyValidator(), new FakeSingleOrgPolicyValidator()],
Substitute.For<TimeProvider>(),
Substitute.For<IPostSavePolicySideEffect>()));
Substitute.For<IPostSavePolicySideEffect>(),
Substitute.For<IPushNotificationService>()));
Assert.Contains("Duplicate PolicyValidator for SingleOrg policy", exception.Message);
}
@@ -360,6 +364,103 @@ public class SavePolicyCommandTests
.ExecuteSideEffectsAsync(default!, default!, default!);
}
[Theory, BitAutoData]
public async Task VNextSaveAsync_SendsPushNotification(
[PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg, false)] Policy currentPolicy)
{
// Arrange
var fakePolicyValidator = new FakeSingleOrgPolicyValidator();
fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns("");
var sutProvider = SutProviderFactory([fakePolicyValidator]);
var savePolicyModel = new SavePolicyModel(policyUpdate);
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type)
.Returns(currentPolicy);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
.GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns([currentPolicy]);
// Act
var result = await sutProvider.Sut.VNextSaveAsync(savePolicyModel);
// Assert
await sutProvider.GetDependency<IPushNotificationService>().Received(1)
.PushAsync(Arg.Is<PushNotification<SyncPolicyPushNotification>>(p =>
p.Type == PushType.PolicyChanged &&
p.Target == NotificationTarget.Organization &&
p.TargetId == policyUpdate.OrganizationId &&
p.ExcludeCurrentContext == false &&
p.Payload.OrganizationId == policyUpdate.OrganizationId &&
p.Payload.Policy.Id == result.Id &&
p.Payload.Policy.Type == policyUpdate.Type &&
p.Payload.Policy.Enabled == policyUpdate.Enabled &&
p.Payload.Policy.Data == policyUpdate.Data));
}
[Theory, BitAutoData]
public async Task SaveAsync_SendsPushNotification([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate)
{
var fakePolicyValidator = new FakeSingleOrgPolicyValidator();
fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns("");
var sutProvider = SutProviderFactory([fakePolicyValidator]);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]);
var result = await sutProvider.Sut.SaveAsync(policyUpdate);
await sutProvider.GetDependency<IPushNotificationService>().Received(1)
.PushAsync(Arg.Is<PushNotification<SyncPolicyPushNotification>>(p =>
p.Type == PushType.PolicyChanged &&
p.Target == NotificationTarget.Organization &&
p.TargetId == policyUpdate.OrganizationId &&
p.ExcludeCurrentContext == false &&
p.Payload.OrganizationId == policyUpdate.OrganizationId &&
p.Payload.Policy.Id == result.Id &&
p.Payload.Policy.Type == policyUpdate.Type &&
p.Payload.Policy.Enabled == policyUpdate.Enabled &&
p.Payload.Policy.Data == policyUpdate.Data));
}
[Theory, BitAutoData]
public async Task SaveAsync_ExistingPolicy_SendsPushNotificationWithUpdatedPolicy(
[PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg, false)] Policy currentPolicy)
{
var fakePolicyValidator = new FakeSingleOrgPolicyValidator();
fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns("");
var sutProvider = SutProviderFactory([fakePolicyValidator]);
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type)
.Returns(currentPolicy);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
.GetManyByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns([currentPolicy]);
var result = await sutProvider.Sut.SaveAsync(policyUpdate);
await sutProvider.GetDependency<IPushNotificationService>().Received(1)
.PushAsync(Arg.Is<PushNotification<SyncPolicyPushNotification>>(p =>
p.Type == PushType.PolicyChanged &&
p.Target == NotificationTarget.Organization &&
p.TargetId == policyUpdate.OrganizationId &&
p.ExcludeCurrentContext == false &&
p.Payload.OrganizationId == policyUpdate.OrganizationId &&
p.Payload.Policy.Id == result.Id &&
p.Payload.Policy.Type == policyUpdate.Type &&
p.Payload.Policy.Enabled == policyUpdate.Enabled &&
p.Payload.Policy.Data == policyUpdate.Data));
}
/// <summary>
/// Returns a new SutProvider with the PolicyValidators registered in the Sut.
/// </summary>

View File

@@ -0,0 +1,275 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Sso;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Auth.UserFeatures.Sso;
[SutProviderCustomize]
public class UserSsoOrganizationIdentifierQueryTests
{
[Theory, BitAutoData]
public async Task GetSsoOrganizationIdentifierAsync_UserHasSingleConfirmedOrganization_ReturnsIdentifier(
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId,
Organization organization,
OrganizationUser organizationUser)
{
// Arrange
organizationUser.UserId = userId;
organizationUser.OrganizationId = organization.Id;
organizationUser.Status = OrganizationUserStatusType.Confirmed;
organization.Identifier = "test-org-identifier";
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns([organizationUser]);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Equal("test-org-identifier", result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.GetByIdAsync(organization.Id);
}
[Theory, BitAutoData]
public async Task GetSsoOrganizationIdentifierAsync_UserHasNoOrganizations_ReturnsNull(
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId)
{
// Arrange
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns(Array.Empty<OrganizationUser>());
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Null(result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.DidNotReceive()
.GetByIdAsync(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task GetSsoOrganizationIdentifierAsync_UserHasMultipleConfirmedOrganizations_ReturnsNull(
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId,
OrganizationUser organizationUser1,
OrganizationUser organizationUser2)
{
// Arrange
organizationUser1.UserId = userId;
organizationUser1.Status = OrganizationUserStatusType.Confirmed;
organizationUser2.UserId = userId;
organizationUser2.Status = OrganizationUserStatusType.Confirmed;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns([organizationUser1, organizationUser2]);
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Null(result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.DidNotReceive()
.GetByIdAsync(Arg.Any<Guid>());
}
[Theory]
[BitAutoData(OrganizationUserStatusType.Invited)]
[BitAutoData(OrganizationUserStatusType.Accepted)]
[BitAutoData(OrganizationUserStatusType.Revoked)]
public async Task GetSsoOrganizationIdentifierAsync_UserHasOnlyInvitedOrganization_ReturnsNull(
OrganizationUserStatusType status,
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId,
OrganizationUser organizationUser)
{
// Arrange
organizationUser.UserId = userId;
organizationUser.Status = status;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns([organizationUser]);
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Null(result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.DidNotReceive()
.GetByIdAsync(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task GetSsoOrganizationIdentifierAsync_UserHasMixedStatusOrganizations_OnlyOneConfirmed_ReturnsIdentifier(
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId,
Organization organization,
OrganizationUser confirmedOrgUser,
OrganizationUser invitedOrgUser,
OrganizationUser revokedOrgUser)
{
// Arrange
confirmedOrgUser.UserId = userId;
confirmedOrgUser.OrganizationId = organization.Id;
confirmedOrgUser.Status = OrganizationUserStatusType.Confirmed;
invitedOrgUser.UserId = userId;
invitedOrgUser.Status = OrganizationUserStatusType.Invited;
revokedOrgUser.UserId = userId;
revokedOrgUser.Status = OrganizationUserStatusType.Revoked;
organization.Identifier = "mixed-status-org";
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns(new[] { confirmedOrgUser, invitedOrgUser, revokedOrgUser });
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Equal("mixed-status-org", result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.GetByIdAsync(organization.Id);
}
[Theory, BitAutoData]
public async Task GetSsoOrganizationIdentifierAsync_OrganizationNotFound_ReturnsNull(
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId,
OrganizationUser organizationUser)
{
// Arrange
organizationUser.UserId = userId;
organizationUser.Status = OrganizationUserStatusType.Confirmed;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns([organizationUser]);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organizationUser.OrganizationId)
.Returns((Organization)null);
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Null(result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.GetByIdAsync(organizationUser.OrganizationId);
}
[Theory, BitAutoData]
public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsNull_ReturnsNull(
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId,
Organization organization,
OrganizationUser organizationUser)
{
// Arrange
organizationUser.UserId = userId;
organizationUser.OrganizationId = organization.Id;
organizationUser.Status = OrganizationUserStatusType.Confirmed;
organization.Identifier = null;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns(new[] { organizationUser });
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Null(result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.GetByIdAsync(organization.Id);
}
[Theory, BitAutoData]
public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsEmpty_ReturnsEmpty(
SutProvider<UserSsoOrganizationIdentifierQuery> sutProvider,
Guid userId,
Organization organization,
OrganizationUser organizationUser)
{
// Arrange
organizationUser.UserId = userId;
organizationUser.OrganizationId = organization.Id;
organizationUser.Status = OrganizationUserStatusType.Confirmed;
organization.Identifier = string.Empty;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(userId)
.Returns(new[] { organizationUser });
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
// Act
var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId);
// Assert
Assert.Equal(string.Empty, result);
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByUserAsync(userId);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.GetByIdAsync(organization.Id);
}
}

View File

@@ -213,7 +213,8 @@ If you believe you need to change the version for a valid reason, please discuss
LimitCollectionDeletion = true,
AllowAdminAccessToAllCollectionItems = true,
UseOrganizationDomains = true,
UseAdminSponsoredFamilies = false
UseAdminSponsoredFamilies = false,
UsePhishingBlocker = false,
};
}

View File

@@ -88,7 +88,7 @@ public class UpdateOrganizationLicenseCommandTests
"Hash", "Signature", "SignatureBytes", "InstallationId", "Expires",
"ExpirationWithoutGracePeriod", "Token", "LimitCollectionCreationDeletion",
"LimitCollectionCreation", "LimitCollectionDeletion", "AllowAdminAccessToAllCollectionItems",
"UseOrganizationDomains", "UseAdminSponsoredFamilies", "UseAutomaticUserConfirmation") &&
"UseOrganizationDomains", "UseAdminSponsoredFamilies", "UseAutomaticUserConfirmation", "UsePhishingBlocker") &&
// Same property but different name, use explicit mapping
org.ExpirationDate == license.Expires));
}

View File

@@ -107,30 +107,6 @@ public class CurrentContextTests
Assert.Equal(deviceType, sutProvider.Sut.DeviceType);
}
[Theory, BitAutoData]
public async Task BuildAsync_HttpContext_SetsCloudflareFlags(
SutProvider<CurrentContext> sutProvider)
{
var httpContext = new DefaultHttpContext();
var globalSettings = new Core.Settings.GlobalSettings();
sutProvider.Sut.BotScore = null;
// Arrange
var botScore = 85;
httpContext.Request.Headers["X-Cf-Bot-Score"] = botScore.ToString();
httpContext.Request.Headers["X-Cf-Worked-Proxied"] = "1";
httpContext.Request.Headers["X-Cf-Is-Bot"] = "1";
httpContext.Request.Headers["X-Cf-Maybe-Bot"] = "1";
// Act
await sutProvider.Sut.BuildAsync(httpContext, globalSettings);
// Assert
Assert.True(sutProvider.Sut.CloudflareWorkerProxied);
Assert.True(sutProvider.Sut.IsBot);
Assert.True(sutProvider.Sut.MaybeBot);
Assert.Equal(botScore, sutProvider.Sut.BotScore);
}
[Theory, BitAutoData]
public async Task BuildAsync_HttpContext_SetsClientVersion(
SutProvider<CurrentContext> sutProvider)

View File

@@ -74,7 +74,7 @@ public class SendGridMailDeliveryServiceTests : IDisposable
Assert.Equal(mailMessage.HtmlContent, msg.HtmlContent);
Assert.Equal(mailMessage.TextContent, msg.PlainTextContent);
Assert.Contains("type:Cateogry", msg.Categories);
Assert.Contains("type:Category", msg.Categories);
Assert.Contains(msg.Categories, x => x.StartsWith("env:"));
Assert.Contains(msg.Categories, x => x.StartsWith("sender:"));

View File

@@ -44,14 +44,17 @@ internal class CustomValidatorRequestContextCustomization : ICustomization
/// <see cref="CustomValidatorRequestContext.TwoFactorRecoveryRequested"/>, and
/// <see cref="CustomValidatorRequestContext.SsoRequired" /> should initialize false,
/// and are made truthy in context upon evaluation of a request. Do not allow AutoFixture to eagerly make these
/// truthy; that is the responsibility of the <see cref="Bit.Identity.IdentityServer.RequestValidators.BaseRequestValidator{T}" />
/// truthy; that is the responsibility of the <see cref="Bit.Identity.IdentityServer.RequestValidators.BaseRequestValidator{T}" />.
/// ValidationErrorResult and CustomResponse should also be null initially; they are hydrated during the validation process.
/// </summary>
public void Customize(IFixture fixture)
{
fixture.Customize<CustomValidatorRequestContext>(composer => composer
.With(o => o.RememberMeRequested, false)
.With(o => o.TwoFactorRecoveryRequested, false)
.With(o => o.SsoRequired, false));
.With(o => o.SsoRequired, false)
.With(o => o.ValidationErrorResult, () => null)
.With(o => o.CustomResponse, () => null));
}
}

View File

@@ -21,6 +21,7 @@ using Bit.Identity.IdentityServer;
using Bit.Identity.IdentityServer.RequestValidators;
using Bit.Identity.Test.Wrappers;
using Bit.Test.Common.AutoFixture.Attributes;
using Duende.IdentityModel;
using Duende.IdentityServer.Validation;
using Microsoft.AspNetCore.Identity;
using Microsoft.Extensions.Logging;
@@ -42,6 +43,7 @@ public class BaseRequestValidatorTests
private readonly IEventService _eventService;
private readonly IDeviceValidator _deviceValidator;
private readonly ITwoFactorAuthenticationValidator _twoFactorAuthenticationValidator;
private readonly ISsoRequestValidator _ssoRequestValidator;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly FakeLogger<BaseRequestValidatorTests> _logger;
private readonly ICurrentContext _currentContext;
@@ -65,6 +67,7 @@ public class BaseRequestValidatorTests
_eventService = Substitute.For<IEventService>();
_deviceValidator = Substitute.For<IDeviceValidator>();
_twoFactorAuthenticationValidator = Substitute.For<ITwoFactorAuthenticationValidator>();
_ssoRequestValidator = Substitute.For<ISsoRequestValidator>();
_organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
_logger = new FakeLogger<BaseRequestValidatorTests>();
_currentContext = Substitute.For<ICurrentContext>();
@@ -85,6 +88,7 @@ public class BaseRequestValidatorTests
_eventService,
_deviceValidator,
_twoFactorAuthenticationValidator,
_ssoRequestValidator,
_organizationUserRepository,
_logger,
_currentContext,
@@ -151,6 +155,7 @@ public class BaseRequestValidatorTests
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -162,9 +167,9 @@ public class BaseRequestValidatorTests
// 4 -> set up device validator to fail
requestContext.KnownDevice = false;
tokenRequest.GrantType = "password";
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(false));
// 5 -> not legacy user
@@ -192,6 +197,7 @@ public class BaseRequestValidatorTests
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -203,12 +209,13 @@ public class BaseRequestValidatorTests
// 4 -> set up device validator to pass
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 5 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -236,6 +243,7 @@ public class BaseRequestValidatorTests
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -262,12 +270,13 @@ public class BaseRequestValidatorTests
// 4 -> set up device validator to pass
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 5 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -299,6 +308,7 @@ public class BaseRequestValidatorTests
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -319,10 +329,19 @@ public class BaseRequestValidatorTests
// 2 -> will result to false with no extra configuration
// 3 -> set two factor to be required
requestContext.User.TwoFactorProviders = "{\"1\":{\"Enabled\":true,\"MetaData\":{\"Email\":\"user@test.dev\"}}}";
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
_twoFactorAuthenticationValidator
.BuildTwoFactorResultAsync(requestContext.User, null)
.Returns(Task.FromResult(new Dictionary<string, object>
{
{ "TwoFactorProviders", new[] { "0", "1" } },
{ "TwoFactorProviders2", new Dictionary<string, object>{{"Email", null}} }
}));
// Act
await _sut.ValidateAsync(context);
@@ -330,7 +349,10 @@ public class BaseRequestValidatorTests
Assert.True(context.GrantResult.IsError);
// Assert that the auth request was NOT consumed
await _authRequestRepository.DidNotReceive().ReplaceAsync(Arg.Any<AuthRequest>());
await _authRequestRepository.DidNotReceive().ReplaceAsync(authRequest);
// Assert that the error is for 2fa
Assert.Equal("Two-factor authentication required.", context.GrantResult.ErrorDescription);
}
[Theory]
@@ -420,6 +442,7 @@ public class BaseRequestValidatorTests
{ "TwoFactorProviders", new[] { "0", "1" } },
{ "TwoFactorProviders2", new Dictionary<string, object>() }
};
_twoFactorAuthenticationValidator
.BuildTwoFactorResultAsync(user, null)
.Returns(Task.FromResult(twoFactorResultDict));
@@ -428,6 +451,8 @@ public class BaseRequestValidatorTests
await _sut.ValidateAsync(context);
// Assert
Assert.Equal("Two-factor authentication required.", context.GrantResult.ErrorDescription);
// Verify that the failed 2FA email was NOT sent for remember token expiration
await _mailService.DidNotReceive()
.SendFailedTwoFactorAttemptEmailAsync(Arg.Any<string>(), Arg.Any<TwoFactorProviderType>(),
@@ -1243,6 +1268,343 @@ public class BaseRequestValidatorTests
}
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is DISABLED, the legacy SSO validation path is used.
/// This validates the deprecated RequireSsoLoginAsync method is called and SSO requirement
/// is checked using the old PolicyService.AnyPoliciesApplicableToUserAsync approach.
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_UsesLegacySsoValidation(
bool recoveryCodeFeatureEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled);
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
// SSO is required via legacy path
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
// Verify legacy path was used
await _policyService.Received(1).AnyPoliciesApplicableToUserAsync(
requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
// Verify new SsoRequestValidator was NOT called
await _ssoRequestValidator.DidNotReceive().ValidateAsync(
Arg.Any<User>(), Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>());
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED, the new ISsoRequestValidator is used
/// instead of the legacy RequireSsoLoginAsync method.
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_UsesNewSsoRequestValidator(
bool recoveryCodeFeatureEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled);
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
// Configure SsoRequestValidator to indicate SSO is required
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false)); // false = SSO required
// Set up the ValidationErrorResult that SsoRequestValidator would set
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "SSO authentication is required."
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("SSO authentication is required.") }
};
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
// Verify new SsoRequestValidator was called
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
requestContext);
// Verify legacy path was NOT used
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and SSO is NOT required,
/// authentication continues successfully through the new validation path.
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_SsoNotRequired_SuccessfulLogin(
bool recoveryCodeFeatureEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled);
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
tokenRequest.ClientId = "web";
// SsoRequestValidator returns true (SSO not required)
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
// No 2FA required
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// Device validation passes
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// User is not legacy
_userService.IsLegacyUser(Arg.Any<string>()).Returns(false);
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
"test-private-key",
"test-public-key"
)
});
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.False(context.GrantResult.IsError);
await _eventService.Received(1).LogUserEventAsync(requestContext.User.Id, EventType.User_LoggedIn);
// Verify new validator was used
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
requestContext);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and SSO validation returns a custom response
/// (e.g., with organization identifier), that custom response is properly propagated to the result.
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_PropagatesCustomResponse(
bool recoveryCodeFeatureEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled);
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
_sut.isValid = true;
tokenRequest.GrantType = OidcConstants.GrantTypes.Password;
// SsoRequestValidator sets custom response with organization identifier
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "SSO authentication is required."
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("SSO authentication is required.") },
{ "SsoOrganizationIdentifier", "test-org-identifier" }
};
var context = CreateContext(tokenRequest, requestContext, grantResult);
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
Assert.NotNull(context.GrantResult.CustomResponse);
Assert.Contains("SsoOrganizationIdentifier", context.CustomValidatorRequestContext.CustomResponse);
Assert.Equal("test-org-identifier", context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is DISABLED and a user with 2FA recovery completes recovery,
/// but SSO is required, the legacy error message is returned (without the recovery-specific message).
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_RecoveryWithSso_LegacyMessage(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(true);
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
// Recovery code scenario
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code";
// 2FA with recovery
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code")
.Returns(Task.FromResult(true));
// SSO is required (legacy check)
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
// Legacy behavior: recovery-specific message IS shown even without RedirectOnSsoRequired
Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message);
// But legacy validation path was used
await _policyService.Received(1).AnyPoliciesApplicableToUserAsync(
requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
}
/// <summary>
/// Tests that when RedirectOnSsoRequired is ENABLED and recovery code is used for SSO-required user,
/// the SsoRequestValidator provides the recovery-specific error message.
/// </summary>
[Theory]
[BitAutoData]
public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_RecoveryWithSso_NewValidatorMessage(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(true);
_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
// Recovery code scenario
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code";
// 2FA with recovery
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code")
.Returns(Task.FromResult(true));
// SsoRequestValidator handles the recovery + SSO scenario
requestContext.TwoFactorRecoveryRequested = true;
requestContext.ValidationErrorResult = new ValidationResult
{
IsError = true,
Error = "sso_required",
ErrorDescription = "Two-factor recovery has been performed. SSO authentication is required."
};
requestContext.CustomResponse = new Dictionary<string, object>
{
{ "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") }
};
_ssoRequestValidator.ValidateAsync(
Arg.Any<User>(),
Arg.Any<ValidatedTokenRequest>(),
Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse["ErrorModel"];
Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message);
// Verify new validator was used
await _ssoRequestValidator.Received(1).ValidateAsync(
requestContext.User,
tokenRequest,
Arg.Is<CustomValidatorRequestContext>(ctx => ctx.TwoFactorRecoveryRequested));
// Verify legacy path was NOT used
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
}
private BaseRequestValidationContextFake CreateContext(
ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,

View File

@@ -0,0 +1,469 @@
using Bit.Core;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Auth.Sso;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Identity.IdentityServer;
using Bit.Identity.IdentityServer.Enums;
using Bit.Identity.IdentityServer.RequestValidationConstants;
using Bit.Identity.IdentityServer.RequestValidators;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Duende.IdentityModel;
using Duende.IdentityServer.Validation;
using NSubstitute;
using Xunit;
using AuthFixtures = Bit.Identity.Test.AutoFixture;
namespace Bit.Identity.Test.IdentityServer;
[SutProviderCustomize]
public class SsoRequestValidatorTests
{
[Theory]
[BitAutoData(OidcConstants.GrantTypes.AuthorizationCode)]
[BitAutoData(OidcConstants.GrantTypes.ClientCredentials)]
public async void ValidateAsync_GrantTypeIgnoresSsoRequirement_ReturnsTrue(
string grantType,
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = grantType;
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.True(result);
Assert.False(context.SsoRequired);
Assert.Null(context.ValidationErrorResult);
Assert.Null(context.CustomResponse);
// Should not check policies since grant type allows bypass
await sutProvider.GetDependency<IPolicyService>().DidNotReceive()
.AnyPoliciesApplicableToUserAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
await sutProvider.GetDependency<IPolicyRequirementQuery>().DidNotReceive()
.GetAsync<RequireSsoPolicyRequirement>(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoNotRequired_RequirementPolicyFeatureFlagEnabled_ReturnsTrue(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = false };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.True(result);
Assert.False(context.SsoRequired);
Assert.Null(context.ValidationErrorResult);
Assert.Null(context.CustomResponse);
// Should use the new policy requirement query when feature flag is enabled
await sutProvider.GetDependency<IPolicyRequirementQuery>().Received(1).GetAsync<RequireSsoPolicyRequirement>(user.Id);
await sutProvider.GetDependency<IPolicyService>().DidNotReceive()
.AnyPoliciesApplicableToUserAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>(), Arg.Any<OrganizationUserStatusType>());
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoNotRequired_RequirementPolicyFeatureFlagDisabled_ReturnsTrue(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false);
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
user.Id,
PolicyType.RequireSso,
OrganizationUserStatusType.Confirmed)
.Returns(false);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.True(result);
Assert.False(context.SsoRequired);
Assert.Null(context.ValidationErrorResult);
Assert.Null(context.CustomResponse);
// Should use the legacy policy service when feature flag is disabled
await sutProvider.GetDependency<IPolicyService>().Received(1).AnyPoliciesApplicableToUserAsync(
user.Id,
PolicyType.RequireSso,
OrganizationUserStatusType.Confirmed);
await sutProvider.GetDependency<IPolicyRequirementQuery>().DidNotReceive()
.GetAsync<RequireSsoPolicyRequirement>(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoRequired_RequirementPolicyFeatureFlagEnabled_ReturnsFalse(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns((string)null);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.ValidationErrorResult);
Assert.True(context.ValidationErrorResult.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error);
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription);
Assert.NotNull(context.CustomResponse);
Assert.True(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.ErrorModel));
Assert.False(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier));
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoRequired_RequirementPolicyFeatureFlagDisabled_ReturnsFalse(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false);
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
user.Id,
PolicyType.RequireSso,
OrganizationUserStatusType.Confirmed)
.Returns(true);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns((string)null);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.ValidationErrorResult);
Assert.True(context.ValidationErrorResult.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error);
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription);
Assert.NotNull(context.CustomResponse);
Assert.True(context.CustomResponse.ContainsKey("ErrorModel"));
Assert.False(context.CustomResponse.ContainsKey("SsoOrganizationIdentifier"));
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoRequired_TwoFactorRecoveryRequested_ReturnsFalse_WithSpecialMessage(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
context.TwoFactorRecoveryRequested = true;
context.TwoFactorRequired = true;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns((string)null);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.ValidationErrorResult);
Assert.True(context.ValidationErrorResult.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error);
Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.",
context.ValidationErrorResult.ErrorDescription);
Assert.NotNull(context.CustomResponse);
Assert.True(context.CustomResponse.ContainsKey("ErrorModel"));
Assert.False(context.CustomResponse.ContainsKey("SsoOrganizationIdentifier"));
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoRequired_TwoFactorRequiredButNotRecovery_ReturnsFalse_WithStandardMessage(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns((string)null);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.ValidationErrorResult);
Assert.True(context.ValidationErrorResult.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error);
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription);
Assert.NotNull(context.CustomResponse);
Assert.True(context.CustomResponse.ContainsKey("ErrorModel"));
Assert.False(context.CustomResponse.ContainsKey("SsoOrganizationIdentifier"));
}
[Theory]
[BitAutoData(OidcConstants.GrantTypes.Password)]
[BitAutoData(OidcConstants.GrantTypes.RefreshToken)]
[BitAutoData(CustomGrantTypes.WebAuthn)]
public async void ValidateAsync_VariousGrantTypes_SsoRequired_ReturnsFalse(
string grantType,
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = grantType;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns((string)null);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.ValidationErrorResult);
Assert.True(context.ValidationErrorResult.IsError);
Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error);
Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription);
Assert.NotNull(context.CustomResponse);
}
[Theory, BitAutoData]
public async void ValidateAsync_ContextSsoRequiredUpdated_RegardlessOfInitialValue(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
context.SsoRequired = true; // Start with true to ensure it gets updated
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = false };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.True(result);
Assert.False(context.SsoRequired); // Should be updated to false
Assert.Null(context.ValidationErrorResult);
Assert.Null(context.CustomResponse);
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoRequired_WithOrganizationIdentifier_IncludesIdentifierInResponse(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
const string orgIdentifier = "test-organization";
request.GrantType = OidcConstants.GrantTypes.Password;
context.User = user;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns(orgIdentifier);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.CustomResponse);
Assert.True(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier));
Assert.Equal(orgIdentifier, context.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier]);
await sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.Received(1)
.GetSsoOrganizationIdentifierAsync(user.Id);
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoRequired_NoOrganizationIdentifier_DoesNotIncludeIdentifierInResponse(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
context.User = user;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns((string)null);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.CustomResponse);
Assert.False(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier));
await sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.Received(1)
.GetSsoOrganizationIdentifierAsync(user.Id);
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoRequired_EmptyOrganizationIdentifier_DoesNotIncludeIdentifierInResponse(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
context.User = user;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = true };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.GetSsoOrganizationIdentifierAsync(user.Id)
.Returns(string.Empty);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.False(result);
Assert.True(context.SsoRequired);
Assert.NotNull(context.CustomResponse);
Assert.False(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier));
await sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.Received(1)
.GetSsoOrganizationIdentifierAsync(user.Id);
}
[Theory, BitAutoData]
public async void ValidateAsync_SsoNotRequired_DoesNotCallOrganizationIdentifierQuery(
User user,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request,
SutProvider<SsoRequestValidator> sutProvider)
{
// Arrange
request.GrantType = OidcConstants.GrantTypes.Password;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var requirement = new RequireSsoPolicyRequirement { SsoRequired = false };
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<RequireSsoPolicyRequirement>(user.Id)
.Returns(requirement);
// Act
var result = await sutProvider.Sut.ValidateAsync(user, request, context);
// Assert
Assert.True(result);
Assert.False(context.SsoRequired);
await sutProvider.GetDependency<IUserSsoOrganizationIdentifierQuery>()
.DidNotReceive()
.GetSsoOrganizationIdentifierAsync(Arg.Any<Guid>());
}
}

View File

@@ -32,7 +32,7 @@ public class TwoFactorAuthenticationValidatorTests
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IDataProtectorTokenFactory<SsoEmail2faSessionTokenable> _ssoEmail2faSessionTokenable;
private readonly ITwoFactorIsEnabledQuery _twoFactorenabledQuery;
private readonly ITwoFactorIsEnabledQuery _twoFactorEnabledQuery;
private readonly ICurrentContext _currentContext;
private readonly TwoFactorAuthenticationValidator _sut;
@@ -45,7 +45,7 @@ public class TwoFactorAuthenticationValidatorTests
_organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
_organizationRepository = Substitute.For<IOrganizationRepository>();
_ssoEmail2faSessionTokenable = Substitute.For<IDataProtectorTokenFactory<SsoEmail2faSessionTokenable>>();
_twoFactorenabledQuery = Substitute.For<ITwoFactorIsEnabledQuery>();
_twoFactorEnabledQuery = Substitute.For<ITwoFactorIsEnabledQuery>();
_currentContext = Substitute.For<ICurrentContext>();
_sut = new TwoFactorAuthenticationValidator(
@@ -56,7 +56,7 @@ public class TwoFactorAuthenticationValidatorTests
_organizationUserRepository,
_organizationRepository,
_ssoEmail2faSessionTokenable,
_twoFactorenabledQuery,
_twoFactorEnabledQuery,
_currentContext);
}

View File

@@ -54,6 +54,7 @@ IBaseRequestValidatorTestWrapper
IEventService eventService,
IDeviceValidator deviceValidator,
ITwoFactorAuthenticationValidator twoFactorAuthenticationValidator,
ISsoRequestValidator ssoRequestValidator,
IOrganizationUserRepository organizationUserRepository,
ILogger logger,
ICurrentContext currentContext,
@@ -73,6 +74,7 @@ IBaseRequestValidatorTestWrapper
eventService,
deviceValidator,
twoFactorAuthenticationValidator,
ssoRequestValidator,
organizationUserRepository,
logger,
currentContext,
@@ -132,12 +134,17 @@ IBaseRequestValidatorTestWrapper
protected override void SetTwoFactorResult(
BaseRequestValidationContextFake context,
Dictionary<string, object> customResponse)
{ }
{
context.GrantResult = new GrantValidationResult(
TokenRequestErrors.InvalidGrant, "Two-factor authentication required.", customResponse);
}
protected override void SetValidationErrorResult(
BaseRequestValidationContextFake context,
CustomValidatorRequestContext requestContext)
{ }
{
context.GrantResult.IsError = true;
}
protected override Task<bool> ValidateContextAsync(
BaseRequestValidationContextFake context,

View File

@@ -93,7 +93,8 @@ public static class OrganizationTestHelpers
UseOrganizationDomains = true,
UseAdminSponsoredFamilies = true,
SyncSeats = false,
UseAutomaticUserConfirmation = true
UseAutomaticUserConfirmation = true,
UsePhishingBlocker = true,
});
}

View File

@@ -673,7 +673,8 @@ public class OrganizationUserRepositoryTests
LimitItemDeletion = false,
AllowAdminAccessToAllCollectionItems = false,
UseRiskInsights = false,
UseAdminSponsoredFamilies = false
UseAdminSponsoredFamilies = false,
UsePhishingBlocker = false,
});
var organizationDomain = new OrganizationDomain

View File

@@ -225,6 +225,30 @@ public class HubHelpersTest
.Group(Arg.Any<string>());
}
[Theory]
[BitAutoData]
public async Task SendNotificationToHubAsync_PolicyChanged_SentToOrganizationGroup(
SutProvider<HubHelpers> sutProvider,
SyncPolicyPushNotification notification,
string contextId,
CancellationToken cancellationToken)
{
var json = ToNotificationJson(notification, PushType.PolicyChanged, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
await sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(1)
.Group($"Organization_{notification.OrganizationId}")
.Received(1)
.SendCoreAsync("ReceiveMessage", Arg.Is<object?[]>(objects =>
objects.Length == 1 && AssertSyncPolicyPushNotification(notification, objects[0],
PushType.PolicyChanged, contextId)),
cancellationToken);
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
private static string ToNotificationJson(object payload, PushType type, string contextId)
{
var notification = new PushNotificationData<object>(type, payload, contextId);
@@ -247,4 +271,20 @@ public class HubHelpersTest
expected.ClientType == pushNotificationData.Payload.ClientType &&
expected.RevisionDate == pushNotificationData.Payload.RevisionDate;
}
private static bool AssertSyncPolicyPushNotification(SyncPolicyPushNotification expected, object? actual,
PushType type, string contextId)
{
if (actual is not PushNotificationData<SyncPolicyPushNotification> pushNotificationData)
{
return false;
}
return pushNotificationData.Type == type &&
pushNotificationData.ContextId == contextId &&
expected.OrganizationId == pushNotificationData.Payload.OrganizationId &&
expected.Policy.Id == pushNotificationData.Payload.Policy.Id &&
expected.Policy.Type == pushNotificationData.Payload.Policy.Type &&
expected.Policy.Enabled == pushNotificationData.Payload.Policy.Enabled;
}
}