1
0
mirror of https://github.com/bitwarden/server synced 2025-12-20 18:23:44 +00:00

[PM-21411] Refactor interface for determining premium status and features (#6688)

* Removed 2FA user interface from premium method signatures

* Added some more comments for clarity and small touchups.

* Add PremiumAccessCacheCheck feature flag to Constants.cs

* Add IPremiumAccessQuery interface and PremiumAccessQuery implementation for checking user premium access status

* Add unit tests for PremiumAccessQuery to validate user premium access logic

* Add XML documentation to Premium in OrganizationUserUserDetails and User classes

* Add PremiumAccessQueries to UserServiceCollectionExtensions

* Refactor TwoFactorIsEnabledQuery to incorporate PremiumAccessQuery and feature flag for premium access checks. Enhanced user premium status retrieval logic and improved handling of user details based on feature flag state.

* Mark methods in IUserRepository and IUserService as obsolete, directing users to new methods in IPremiumAccessQuery for premium access checks.

* Rename CanAccessPremiumBulkAsync to CanAccessPremiumAsync in IPremiumAccessQuery

* Update TwoFactorIsEnabledQuery to use CanAccessPremiumAsync for premium status checks

* Refactor TwoFactorIsEnabledQuery to introduce VNextAsync methods for improved premium access checks and user detail handling. Removed obsolete feature service dependency and enhanced test coverage for new functionality.

* Refactor IPremiumAccessQuery and PremiumAccessQuery to remove the overloaded CanAccessPremiumAsync method. Update related methods to streamline premium access checks using the User object directly. Enhance test coverage by removing obsolete tests and ensuring proper functionality with the new method signatures.

* Add new sync static method to determine if TwoFactor is enabled

* Enhance XML documentation for Premium property in OrganizationUserUserDetails and User classes to clarify its usage and limitations regarding personal and organizational premium access.

* Refactor IPremiumAccessQuery and PremiumAccessQuery to replace User parameter with Guid for user ID in CanAccessPremiumAsync methods. Update related methods and tests to streamline premium access checks and improve clarity in method signatures.

* Update feature flag references in IUserRepository and IUserService to use 'PremiumAccessQuery' instead of 'PremiumAccessCacheCheck'. Adjust related XML documentation for clarity on premium access methods.

* Rename IPremiumAccessQuery to IHasPremiumAccessQuery and move to Billing owned folder

* Remove unnecessary whitespace from IHasPremiumAccessQuery interface.

* Refactor HasPremiumAccessQuery to throw NotFoundException for null users

* Add NotFoundException handling in HasPremiumAccessQuery for mismatched user counts

* Refactor TwoFactorIsEnabledQuery to optimize premium access checks and improve two-factor provider handling. Introduced bulk fetching of premium status for users with only premium providers and streamlined the logic for determining if two-factor authentication is enabled.

* Refactor TwoFactorIsEnabledQueryTests to enhance clarity and optimize test scenarios. Consolidated test cases for two-factor authentication, improved naming conventions, and ensured premium access checks are only performed when necessary.

* Add UserPremiumAccess model to represent user premium access status from personal subscriptions and memberships

* Add User_ReadPremiumAccessByIds stored procedure and UserPremiumAccessView view to enhance premium access retrieval. Updated Organization table index to include UsersGetPremium for optimized queries.

* Add SQL migration script

* Add premium access retrieval methods to IUserRepository and implementations in UserRepository classes. Introduced GetPremiumAccessByIdsAsync and GetPremiumAccessAsync methods to fetch premium status for multiple users and a single user, respectively. Updated using directives to include necessary models.

* Refactor HasPremiumAccessQuery and IHasPremiumAccessQuery to streamline premium access checks. Updated method names for clarity and improved documentation. Adjusted test cases to reflect changes in user premium access retrieval logic.

* Update IUserRepository to reflect new method names for premium access retrieval. Changed obsolete method messages to point to GetPremiumAccessByIdsAsync and GetPremiumAccessAsync. Added internal use notes for IHasPremiumAccessQuery. Improved documentation for clarity.

* Refactor TwoFactorIsEnabledQuery to utilize IFeatureService for premium access checks.

* Enhance EF UserRepository to improve premium access retrieval by including related organization data.

* Add unit tests for premium access retrieval in UserRepositoryTests.

* Optimize HasPremiumAccessQuery to eliminate duplicate user IDs before checking premium access. Updated logic to ensure accurate comparison of premium users against distinct user IDs.

* Refactor TwoFactorIsEnabledQuery to improve handling of users without two-factor providers. Added early exit for users lacking providers and streamlined premium status checks for enabled two-factor authentication.

* Update HasPremiumAccessQueryTests to use simplified exception handling and improve test clarity

* Replaced fully qualified exception references with simplified ones.
* Refactored test setup to use individual user variables for better readability.
* Ensured assertions reflect the updated user variable structure.

* Enhance TwoFactorIsEnabledQuery to throw NotFoundException for non-existent users

* Updated TwoFactorIsEnabledQuery to throw NotFoundException when a user is not found instead of returning false.
* Added a new unit test to verify that the NotFoundException is thrown when a user is not found while premium access query is enabled.

* Move premium access query to Billing owned ServiceCollectionExtensions

* Refactor IUserService to enhance premium access checks

* Updated CanAccessPremium and HasPremiumFromOrganization methods to clarify usage with the new premium access query.
* Integrated IHasPremiumAccessQuery into UserService for improved premium access handling based on feature flag.
* Adjusted method documentation to reflect changes in premium access logic.

* Update IUserRepository to clarify usage of premium access methods

* Modified Obsolete attribute messages for GetManyWithCalculatedPremiumAsync and GetCalculatedPremiumAsync to indicate that callers should use the new methods when the 'PremiumAccessQuery' feature flag is enabled.
* Enhanced documentation to improve clarity regarding premium access handling.

* Update IUserRepository and IUserService to clarify deprecation of premium access methods

* Modified Obsolete attribute messages for GetManyWithCalculatedPremiumAsync and GetCalculatedPremiumAsync in IUserRepository to indicate these methods will be removed in a future version.
* Updated Obsolete attribute message for HasPremiumFromOrganization in IUserService to reflect the same deprecation notice.

* Refactor TwoFactorIsEnabledQuery to streamline user ID retrieval

* Consolidated user ID retrieval logic to avoid redundancy.
* Ensured consistent handling of user ID checks for premium access queries.
* Improved code readability by reducing duplicate code blocks.

* Rename migration script to fix the date

* Update migration script to create the index with DROP_EXISTING = ON

* Refactor UserPremiumAccessView to use LEFT JOINs and GROUP BY for improved performance and clarity

* Update HasPremiumAccessQueryTests to return null for GetPremiumAccessAsync instead of throwing NotFoundException

* Add unit tests for premium access scenarios in UserRepositoryTests

- Implement tests for GetPremiumAccessAsync to cover various user and organization premium access combinations.
- Validate behavior when users belong to multiple organizations, including cases with and without premium access.
- Update email generation for user creation to ensure uniqueness without specific prefixes.
- Enhance assertions to verify expected premium access results across different test cases.

* Bump date on migration script

* Update OrganizationEntityTypeConfiguration to include UsersGetPremium in index properties

* Add migration scripts for OrganizationUsersGetPremiumIndex across MySQL, PostgreSQL, and SQLite

- Introduced new migration files to create the OrganizationUsersGetPremiumIndex.
- Updated the DatabaseContextModelSnapshot to include UsersGetPremium in index properties for all database types.
- Ensured consistency in index creation across different database implementations.

---------

Co-authored-by: Todd Martin <tmartin@bitwarden.com>
Co-authored-by: Patrick Pimentel <ppimentel@bitwarden.com>
This commit is contained in:
Rui Tomé
2025-12-16 10:31:56 +00:00
committed by GitHub
parent e646b91a50
commit f7c615cc01
30 changed files with 11701 additions and 21 deletions

View File

@@ -20,6 +20,12 @@ public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser, I
public string Email { get; set; }
public string AvatarColor { get; set; }
public string TwoFactorProviders { get; set; }
/// <summary>
/// Indicates whether the user has a personal premium subscription.
/// Does not include premium access from organizations -
/// do not use this to check whether the user can access premium features.
/// Null when the organization user is in Invited status (UserId is null).
/// </summary>
public bool? Premium { get; set; }
public OrganizationUserStatusType Status { get; set; }
public OrganizationUserType Type { get; set; }

View File

@@ -4,16 +4,37 @@
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
namespace Bit.Core.Auth.UserFeatures.TwoFactorAuth;
public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFactorIsEnabledQuery
public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
{
private readonly IUserRepository _userRepository = userRepository;
private readonly IUserRepository _userRepository;
private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery;
private readonly IFeatureService _featureService;
public TwoFactorIsEnabledQuery(
IUserRepository userRepository,
IHasPremiumAccessQuery hasPremiumAccessQuery,
IFeatureService featureService)
{
_userRepository = userRepository;
_hasPremiumAccessQuery = hasPremiumAccessQuery;
_featureService = featureService;
}
public async Task<IEnumerable<(Guid userId, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync(IEnumerable<Guid> userIds)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return await TwoFactorIsEnabledVNextAsync(userIds);
}
var result = new List<(Guid userId, bool hasTwoFactor)>();
if (userIds == null || !userIds.Any())
{
@@ -36,6 +57,11 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto
public async Task<IEnumerable<(T user, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync<T>(IEnumerable<T> users) where T : ITwoFactorProvidersUser
{
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return await TwoFactorIsEnabledVNextAsync(users);
}
var userIds = users
.Select(u => u.GetUserId())
.Where(u => u.HasValue)
@@ -71,13 +97,134 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto
return false;
}
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
var userEntity = user as User ?? await _userRepository.GetByIdAsync(userId.Value);
if (userEntity == null)
{
throw new NotFoundException();
}
return await TwoFactorIsEnabledVNextAsync(userEntity);
}
return await TwoFactorEnabledAsync(
user.GetTwoFactorProviders(),
async () =>
{
var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value);
return calcUser?.HasPremiumAccess ?? false;
});
user.GetTwoFactorProviders(),
async () =>
{
var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value);
return calcUser?.HasPremiumAccess ?? false;
});
}
private async Task<IEnumerable<(Guid userId, bool twoFactorIsEnabled)>> TwoFactorIsEnabledVNextAsync(IEnumerable<Guid> userIds)
{
var result = new List<(Guid userId, bool hasTwoFactor)>();
if (userIds == null || !userIds.Any())
{
return result;
}
var users = await _userRepository.GetManyAsync([.. userIds]);
// Get enabled providers for each user
var usersTwoFactorProvidersMap = users.ToDictionary(u => u.Id, GetEnabledTwoFactorProviders);
// Bulk fetch premium status only for users who need it (those with only premium providers)
var userIdsNeedingPremium = usersTwoFactorProvidersMap
.Where(kvp => kvp.Value.Any() && kvp.Value.All(TwoFactorProvider.RequiresPremium))
.Select(kvp => kvp.Key)
.ToList();
var premiumStatusMap = userIdsNeedingPremium.Count > 0
? await _hasPremiumAccessQuery.HasPremiumAccessAsync(userIdsNeedingPremium)
: new Dictionary<Guid, bool>();
foreach (var user in users)
{
var userTwoFactorProviders = usersTwoFactorProvidersMap[user.Id];
if (!userTwoFactorProviders.Any())
{
result.Add((user.Id, false));
continue;
}
// User has providers. If they're in the premium check map, verify premium status
var twoFactorIsEnabled = !premiumStatusMap.TryGetValue(user.Id, out var hasPremium) || hasPremium;
result.Add((user.Id, twoFactorIsEnabled));
}
return result;
}
private async Task<IEnumerable<(T user, bool twoFactorIsEnabled)>> TwoFactorIsEnabledVNextAsync<T>(IEnumerable<T> users)
where T : ITwoFactorProvidersUser
{
var userIds = users
.Select(u => u.GetUserId())
.Where(u => u.HasValue)
.Select(u => u.Value)
.ToList();
var twoFactorResults = await TwoFactorIsEnabledVNextAsync(userIds);
var result = new List<(T user, bool twoFactorIsEnabled)>();
foreach (var user in users)
{
var userId = user.GetUserId();
if (userId.HasValue)
{
var hasTwoFactor = twoFactorResults.FirstOrDefault(res => res.userId == userId.Value).twoFactorIsEnabled;
result.Add((user, hasTwoFactor));
}
else
{
result.Add((user, false));
}
}
return result;
}
private async Task<bool> TwoFactorIsEnabledVNextAsync(User user)
{
var enabledProviders = GetEnabledTwoFactorProviders(user);
if (!enabledProviders.Any())
{
return false;
}
// If all providers require premium, check if user has premium access
if (enabledProviders.All(TwoFactorProvider.RequiresPremium))
{
return await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id);
}
// User has at least one non-premium provider
return true;
}
/// <summary>
/// Gets all enabled two-factor provider types for a user.
/// </summary>
/// <param name="user">user with two factor providers</param>
/// <returns>list of enabled provider types</returns>
private static IList<TwoFactorProviderType> GetEnabledTwoFactorProviders(User user)
{
var providers = user.GetTwoFactorProviders();
if (providers == null || providers.Count == 0)
{
return Array.Empty<TwoFactorProviderType>();
}
// TODO: PM-21210: In practice we don't save disabled providers to the database, worth looking into.
return (from provider in providers
where provider.Value?.Enabled ?? false
select provider.Key).ToList();
}
/// <summary>

View File

@@ -6,6 +6,7 @@ using Bit.Core.Billing.Organizations.Queries;
using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Payment;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
@@ -31,6 +32,7 @@ public static class ServiceCollectionExtensions
services.AddPaymentOperations();
services.AddOrganizationLicenseCommandsQueries();
services.AddPremiumCommands();
services.AddPremiumQueries();
services.AddTransient<IGetOrganizationMetadataQuery, GetOrganizationMetadataQuery>();
services.AddTransient<IGetOrganizationWarningsQuery, GetOrganizationWarningsQuery>();
services.AddTransient<IRestartSubscriptionCommand, RestartSubscriptionCommand>();
@@ -50,4 +52,9 @@ public static class ServiceCollectionExtensions
services.AddScoped<ICreatePremiumSelfHostedSubscriptionCommand, CreatePremiumSelfHostedSubscriptionCommand>();
services.AddTransient<IPreviewPremiumTaxCommand, PreviewPremiumTaxCommand>();
}
private static void AddPremiumQueries(this IServiceCollection services)
{
services.AddScoped<IHasPremiumAccessQuery, HasPremiumAccessQuery>();
}
}

View File

@@ -0,0 +1,29 @@
namespace Bit.Core.Billing.Premium.Models;
/// <summary>
/// Represents user premium access status from personal subscriptions and organization memberships.
/// </summary>
public class UserPremiumAccess
{
/// <summary>
/// The unique identifier for the user.
/// </summary>
public Guid Id { get; set; }
/// <summary>
/// Indicates whether the user has a personal premium subscription.
/// This does NOT include premium access from organizations.
/// </summary>
public bool PersonalPremium { get; set; }
/// <summary>
/// Indicates whether the user has premium access through any organization membership.
/// This is true if the user is a member of at least one enabled organization that grants premium access to users.
/// </summary>
public bool OrganizationPremium { get; set; }
/// <summary>
/// Indicates whether the user has premium access from any source (personal subscription or organization).
/// </summary>
public bool HasPremiumAccess => PersonalPremium || OrganizationPremium;
}

View File

@@ -0,0 +1,49 @@
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
namespace Bit.Core.Billing.Premium.Queries;
public class HasPremiumAccessQuery : IHasPremiumAccessQuery
{
private readonly IUserRepository _userRepository;
public HasPremiumAccessQuery(IUserRepository userRepository)
{
_userRepository = userRepository;
}
public async Task<bool> HasPremiumAccessAsync(Guid userId)
{
var user = await _userRepository.GetPremiumAccessAsync(userId);
if (user == null)
{
throw new NotFoundException();
}
return user.HasPremiumAccess;
}
public async Task<Dictionary<Guid, bool>> HasPremiumAccessAsync(IEnumerable<Guid> userIds)
{
var distinctUserIds = userIds.Distinct().ToList();
var usersWithPremium = await _userRepository.GetPremiumAccessByIdsAsync(distinctUserIds);
if (usersWithPremium.Count() != distinctUserIds.Count)
{
throw new NotFoundException();
}
return usersWithPremium.ToDictionary(u => u.Id, u => u.HasPremiumAccess);
}
public async Task<bool> HasPremiumFromOrganizationAsync(Guid userId)
{
var user = await _userRepository.GetPremiumAccessAsync(userId);
if (user == null)
{
throw new NotFoundException();
}
return user.OrganizationPremium;
}
}

View File

@@ -0,0 +1,30 @@
namespace Bit.Core.Billing.Premium.Queries;
/// <summary>
/// Centralized query for checking if users have premium access through personal subscriptions or organizations.
/// Note: Different from User.Premium which only checks personal subscriptions.
/// </summary>
public interface IHasPremiumAccessQuery
{
/// <summary>
/// Checks if a user has premium access (personal or organization).
/// </summary>
/// <param name="userId">The user ID to check</param>
/// <returns>True if user can access premium features</returns>
Task<bool> HasPremiumAccessAsync(Guid userId);
/// <summary>
/// Checks premium access for multiple users.
/// </summary>
/// <param name="userIds">The user IDs to check</param>
/// <returns>Dictionary mapping user IDs to their premium access status</returns>
Task<Dictionary<Guid, bool>> HasPremiumAccessAsync(IEnumerable<Guid> userIds);
/// <summary>
/// Checks if a user belongs to any organization that grants premium (enabled org with UsersGetPremium).
/// Returns true regardless of personal subscription. Useful for UI decisions like showing subscription options.
/// </summary>
/// <param name="userId">The user ID to check</param>
/// <returns>True if user is in any organization that grants premium</returns>
Task<bool> HasPremiumFromOrganizationAsync(Guid userId);
}

View File

@@ -143,6 +143,7 @@ public static class FeatureFlagKeys
public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration";
public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud";
public const string BulkRevokeUsersV2 = "pm-28456-bulk-revoke-users-v2";
public const string PremiumAccessQuery = "pm-21411-premium-access-query";
/* Architecture */
public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1";

View File

@@ -69,6 +69,11 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
/// The security state is a signed object attesting to the version of the user's account.
/// </summary>
public string? SecurityState { get; set; }
/// <summary>
/// Indicates whether the user has a personal premium subscription.
/// Does not include premium access from organizations -
/// do not use this to check whether the user can access premium features.
/// </summary>
public bool Premium { get; set; }
public DateTime? PremiumExpirationDate { get; set; }
public DateTime? RenewalReminderDate { get; set; }

View File

@@ -1,4 +1,5 @@
using Bit.Core.Entities;
using Bit.Core.Billing.Premium.Models;
using Bit.Core.Entities;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Models.Data;
@@ -24,6 +25,7 @@ public interface IUserRepository : IRepository<User, Guid>
/// Retrieves the data for the requested user IDs and includes an additional property indicating
/// whether the user has premium access directly or through an organization.
/// </summary>
[Obsolete("Use GetPremiumAccessByIdsAsync instead. This method will be removed in a future version.")]
Task<IEnumerable<UserWithCalculatedPremium>> GetManyWithCalculatedPremiumAsync(IEnumerable<Guid> ids);
/// <summary>
/// Retrieves the data for the requested user ID and includes additional property indicating
@@ -34,8 +36,23 @@ public interface IUserRepository : IRepository<User, Guid>
/// </summary>
/// <param name="userId">The user ID to retrieve data for.</param>
/// <returns>User data with calculated premium access; null if nothing is found</returns>
[Obsolete("Use GetPremiumAccessAsync instead. This method will be removed in a future version.")]
Task<UserWithCalculatedPremium?> GetCalculatedPremiumAsync(Guid userId);
/// <summary>
/// Retrieves premium access status for multiple users.
/// For internal use - consumers should use IHasPremiumAccessQuery instead.
/// </summary>
/// <param name="ids">The user IDs to check</param>
/// <returns>Collection of UserPremiumAccess objects containing premium status information</returns>
Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids);
/// <summary>
/// Retrieves premium access status for a single user.
/// For internal use - consumers should use IHasPremiumAccessQuery instead.
/// </summary>
/// <param name="userId">The user ID to check</param>
/// <returns>UserPremiumAccess object containing premium status information, or null if user not found</returns>
Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId);
/// <summary>
/// Sets a new user key and updates all encrypted data.
/// <para>Warning: Any user key encrypted data not included will be lost.</para>
/// </summary>

View File

@@ -60,7 +60,7 @@ public interface IUserService
/// <summary>
/// Checks if the user has access to premium features, either through a personal subscription or through an organization.
///
/// This is the preferred way to definitively know if a user has access to premium features.
/// This is the preferred way to definitively know if a user has access to premium features when you already have the User object.
/// </summary>
/// <param name="user">user being acted on</param>
/// <returns>true if they can access premium; false otherwise.</returns>
@@ -74,6 +74,7 @@ public interface IUserService
/// </summary>
/// <param name="user">user being acted on</param>
/// <returns>true if they can access premium because of organization membership; false otherwise.</returns>
[Obsolete("Use IHasPremiumAccessQuery.HasPremiumFromOrganizationAsync instead. This method will be removed in a future version.")]
Task<bool> HasPremiumFromOrganization(User user);
Task<string> GenerateSignInTokenAsync(User user, string purpose);

View File

@@ -17,6 +17,7 @@ using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
@@ -73,6 +74,7 @@ public class UserService : UserManager<User>, IUserService
private readonly IDistributedCache _distributedCache;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IPricingClient _pricingClient;
private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery;
public UserService(
IUserRepository userRepository,
@@ -108,7 +110,8 @@ public class UserService : UserManager<User>, IUserService
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
IDistributedCache distributedCache,
IPolicyRequirementQuery policyRequirementQuery,
IPricingClient pricingClient)
IPricingClient pricingClient,
IHasPremiumAccessQuery hasPremiumAccessQuery)
: base(
store,
optionsAccessor,
@@ -149,6 +152,7 @@ public class UserService : UserManager<User>, IUserService
_distributedCache = distributedCache;
_policyRequirementQuery = policyRequirementQuery;
_pricingClient = pricingClient;
_hasPremiumAccessQuery = hasPremiumAccessQuery;
}
public Guid? GetProperUserId(ClaimsPrincipal principal)
@@ -1112,6 +1116,11 @@ public class UserService : UserManager<User>, IUserService
return false;
}
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return user.Premium || await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value);
}
return user.Premium || await HasPremiumFromOrganization(user);
}
@@ -1123,6 +1132,11 @@ public class UserService : UserManager<User>, IUserService
return false;
}
if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery))
{
return await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value);
}
// orgUsers in the Invited status are not associated with a userId yet, so this will get
// orgUsers in Accepted and Confirmed states only
var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value);

View File

@@ -1,6 +1,7 @@
using System.Data;
using System.Text.Json;
using Bit.Core;
using Bit.Core.Billing.Premium.Models;
using Bit.Core.Entities;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
@@ -381,6 +382,25 @@ public class UserRepository : Repository<User, Guid>, IUserRepository
return result.SingleOrDefault();
}
public async Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids)
{
using (var connection = new SqlConnection(ReadOnlyConnectionString))
{
var results = await connection.QueryAsync<UserPremiumAccess>(
$"[{Schema}].[{Table}_ReadPremiumAccessByIds]",
new { Ids = ids.ToGuidIdArrayTVP() },
commandType: CommandType.StoredProcedure);
return results.ToList();
}
}
public async Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId)
{
var result = await GetPremiumAccessByIdsAsync([userId]);
return result.SingleOrDefault();
}
private async Task ProtectDataAndSaveAsync(User user, Func<Task> saveTask)
{
if (user == null)

View File

@@ -18,7 +18,7 @@ public class OrganizationEntityTypeConfiguration : IEntityTypeConfiguration<Orga
NpgsqlIndexBuilderExtensions.IncludeProperties(
builder.HasIndex(o => new { o.Id, o.Enabled }),
o => o.UseTotp);
o => new { o.UseTotp, o.UsersGetPremium });
builder.ToTable(nameof(Organization));
}

View File

@@ -1,4 +1,5 @@
using AutoMapper;
using Bit.Core.Billing.Premium.Models;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Models.Data;
@@ -350,6 +351,36 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
return result.FirstOrDefault();
}
public async Task<IEnumerable<UserPremiumAccess>> GetPremiumAccessByIdsAsync(IEnumerable<Guid> ids)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = await dbContext.Users
.Where(x => ids.Contains(x.Id))
.Include(u => u.OrganizationUsers)
.ThenInclude(ou => ou.Organization)
.ToListAsync();
return users.Select(user => new UserPremiumAccess
{
Id = user.Id,
PersonalPremium = user.Premium,
OrganizationPremium = user.OrganizationUsers
.Any(ou => ou.Organization != null &&
ou.Organization.Enabled == true &&
ou.Organization.UsersGetPremium == true)
}).ToList();
}
}
public async Task<UserPremiumAccess?> GetPremiumAccessAsync(Guid userId)
{
var result = await GetPremiumAccessByIdsAsync([userId]);
return result.FirstOrDefault();
}
public override async Task DeleteAsync(Core.Entities.User user)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@@ -0,0 +1,15 @@
CREATE PROCEDURE [dbo].[User_ReadPremiumAccessByIds]
@Ids [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
SELECT
UPA.[Id],
UPA.[PersonalPremium],
UPA.[OrganizationPremium]
FROM
[dbo].[UserPremiumAccessView] UPA
WHERE
UPA.[Id] IN (SELECT [Id] FROM @Ids)
END

View File

@@ -69,7 +69,7 @@ CREATE TABLE [dbo].[Organization] (
GO
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
INCLUDE ([UseTotp]);
INCLUDE ([UseTotp], [UsersGetPremium]);
GO
CREATE UNIQUE NONCLUSTERED INDEX [IX_Organization_Identifier]

View File

@@ -0,0 +1,21 @@
CREATE VIEW [dbo].[UserPremiumAccessView]
AS
SELECT
U.[Id],
U.[Premium] AS [PersonalPremium],
CAST(
MAX(CASE
WHEN O.[Id] IS NOT NULL THEN 1
ELSE 0
END) AS BIT
) AS [OrganizationPremium]
FROM
[dbo].[User] U
LEFT JOIN
[dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id]
LEFT JOIN
[dbo].[Organization] O ON O.[Id] = OU.[OrganizationId]
AND O.[UsersGetPremium] = 1
AND O.[Enabled] = 1
GROUP BY
U.[Id], U.[Premium];

View File

@@ -1,10 +1,13 @@
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -404,6 +407,277 @@ public class TwoFactorIsEnabledQueryTests
.GetCalculatedPremiumAsync(default);
}
[Theory]
[BitAutoData((IEnumerable<Guid>)null)]
[BitAutoData([])]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsEmpty(
IEnumerable<Guid> userIds,
SutProvider<TwoFactorIsEnabledQuery> sutProvider)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
// Assert
Assert.Empty(result);
}
[Theory]
[BitAutoData(TwoFactorProviderType.Duo)]
[BitAutoData(TwoFactorProviderType.YubiKey)]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithMixedScenarios_ReturnsCorrectResults(
TwoFactorProviderType premiumProviderType,
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user1,
User user2,
User user3)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
var users = new List<User> { user1, user2, user3 };
var userIds = users.Select(u => u.Id).ToList();
// User 1: Non-premium provider → 2FA enabled
user1.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } }
});
// User 2: Premium provider + has premium → 2FA enabled
user2.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
// User 3: Premium provider + no premium → 2FA disabled
user3.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
var premiumStatus = new Dictionary<Guid, bool>
{
{ user2.Id, true },
{ user3.Id, false }
};
sutProvider.GetDependency<IUserRepository>()
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
.Returns(users);
sutProvider.GetDependency<IHasPremiumAccessQuery>()
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)))
.Returns(premiumStatus);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
// Assert
Assert.Contains(result, res => res.userId == user1.Id && res.twoFactorIsEnabled == true); // Non-premium provider
Assert.Contains(result, res => res.userId == user2.Id && res.twoFactorIsEnabled == true); // Premium + has premium
Assert.Contains(result, res => res.userId == user3.Id && res.twoFactorIsEnabled == false); // Premium + no premium
}
[Theory]
[BitAutoData(TwoFactorProviderType.Duo)]
[BitAutoData(TwoFactorProviderType.YubiKey)]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_OnlyChecksPremiumAccessForUsersWhoNeedIt(
TwoFactorProviderType premiumProviderType,
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user1,
User user2,
User user3)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
var users = new List<User> { user1, user2, user3 };
var userIds = users.Select(u => u.Id).ToList();
// User 1: Has non-premium provider - should NOT trigger premium check
user1.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } }
});
// User 2 & 3: Have only premium providers - SHOULD trigger premium check
user2.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
user3.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ premiumProviderType, new TwoFactorProvider { Enabled = true } }
});
var premiumStatus = new Dictionary<Guid, bool>
{
{ user2.Id, true },
{ user3.Id, false }
};
sutProvider.GetDependency<IUserRepository>()
.GetManyAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
.Returns(users);
sutProvider.GetDependency<IHasPremiumAccessQuery>()
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)))
.Returns(premiumStatus);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds);
// Assert - Verify optimization: premium checked ONLY for users 2 and 3 (not user 1)
await sutProvider.GetDependency<IHasPremiumAccessQuery>()
.Received(1)
.HasPremiumAccessAsync(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id)));
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsAllTwoFactorDisabled(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
List<OrganizationUserUserDetails> users)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
foreach (var user in users)
{
user.UserId = null;
}
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(users);
// Assert
foreach (var user in users)
{
Assert.Contains(result, res => res.user.Equals(user) && res.twoFactorIsEnabled == false);
}
// No UserIds were supplied so no calls to the UserRepository should have been made
await sutProvider.GetDependency<IUserRepository>()
.DidNotReceiveWithAnyArgs()
.GetManyAsync(default);
}
[Theory]
[BitAutoData(TwoFactorProviderType.Authenticator, true)] // Non-premium provider
[BitAutoData(TwoFactorProviderType.Duo, true)] // Premium provider with premium access
[BitAutoData(TwoFactorProviderType.YubiKey, false)] // Premium provider without premium access
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_SingleUser_VariousScenarios(
TwoFactorProviderType providerType,
bool hasPremiumAccess,
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
user.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ providerType, new TwoFactorProvider { Enabled = true } }
});
sutProvider.GetDependency<IHasPremiumAccessQuery>()
.HasPremiumAccessAsync(user.Id)
.Returns(hasPremiumAccess);
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
// Assert
var requiresPremium = TwoFactorProvider.RequiresPremium(providerType);
var expectedResult = !requiresPremium || hasPremiumAccess;
Assert.Equal(expectedResult, result);
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoEnabledProviders_ReturnsFalse(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
user.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>
{
{ TwoFactorProviderType.Email, new TwoFactorProvider { Enabled = false } }
});
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
// Assert
Assert.False(result);
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNullProviders_ReturnsFalse(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
User user)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
user.TwoFactorProviders = null;
// Act
var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user);
// Assert
Assert.False(result);
}
[Theory]
[BitAutoData]
public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_UserNotFound_ThrowsNotFoundException(
SutProvider<TwoFactorIsEnabledQuery> sutProvider,
Guid userId)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)
.Returns(true);
var testUser = new TestTwoFactorProviderUser
{
Id = userId,
TwoFactorProviders = null
};
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(userId)
.Returns((User)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(
async () => await sutProvider.Sut.TwoFactorIsEnabledAsync(testUser));
}
private class TestTwoFactorProviderUser : ITwoFactorProvidersUser
{
public Guid? Id { get; set; }
@@ -418,10 +692,5 @@ public class TwoFactorIsEnabledQueryTests
{
return Id;
}
public bool GetPremium()
{
return Premium;
}
}
}

View File

@@ -0,0 +1,234 @@
using Bit.Core.Billing.Premium.Models;
using Bit.Core.Billing.Premium.Queries;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Billing.Premium.Queries;
[SutProviderCustomize]
public class HasPremiumAccessQueryTests
{
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserHasPersonalPremium_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = true;
user.OrganizationPremium = false;
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
// Assert
Assert.True(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumButHasOrgPremium_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false;
user.OrganizationPremium = true; // Has org premium
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
// Assert
Assert.True(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumAndNoOrgPremium_ReturnsFalse(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false;
user.OrganizationPremium = false;
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id);
// Assert
Assert.False(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_WhenUserNotFound_ThrowsNotFoundException(
Guid userId,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(userId)
.Returns((UserPremiumAccess?)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(
() => sutProvider.Sut.HasPremiumAccessAsync(userId));
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasNoOrganizations_ReturnsFalse(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false;
user.OrganizationPremium = false; // No premium from anywhere
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.False(result);
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasPremiumFromOrg_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = false; // No personal premium
user.OrganizationPremium = true; // But has premium from org
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.True(result);
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasOnlyPersonalPremium_ReturnsFalse(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = true; // Has personal premium
user.OrganizationPremium = false; // Not in any org that grants premium
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.False(result); // Should return false because user is not in an org that grants premium
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserHasBothPersonalAndOrgPremium_ReturnsTrue(
UserPremiumAccess user,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user.PersonalPremium = true; // Has personal premium
user.OrganizationPremium = true; // Also in an org that grants premium
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(user.Id)
.Returns(user);
// Act
var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id);
// Assert
Assert.True(result); // Should return true because user IS in an org that grants premium (regardless of personal premium)
}
[Theory, BitAutoData]
public async Task HasPremiumFromOrganizationAsync_WhenUserNotFound_ThrowsNotFoundException(
Guid userId,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessAsync(userId)
.Returns((UserPremiumAccess?)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(
() => sutProvider.Sut.HasPremiumFromOrganizationAsync(userId));
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_Bulk_WhenEmptyList_ReturnsEmptyDictionary(
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
var userIds = new List<Guid>();
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessByIdsAsync(userIds)
.Returns(new List<UserPremiumAccess>());
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds);
// Assert
Assert.Empty(result);
}
[Theory, BitAutoData]
public async Task HasPremiumAccessAsync_Bulk_ReturnsCorrectStatus(
UserPremiumAccess user1,
UserPremiumAccess user2,
UserPremiumAccess user3,
SutProvider<HasPremiumAccessQuery> sutProvider)
{
// Arrange
user1.PersonalPremium = true;
user1.OrganizationPremium = false;
user2.PersonalPremium = false;
user2.OrganizationPremium = false;
user3.PersonalPremium = false;
user3.OrganizationPremium = true;
var users = new List<UserPremiumAccess> { user1, user2, user3 };
var userIds = users.Select(u => u.Id).ToList();
sutProvider.GetDependency<IUserRepository>()
.GetPremiumAccessByIdsAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.SequenceEqual(userIds)))
.Returns(users);
// Act
var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds);
// Assert
Assert.Equal(3, result.Count);
Assert.True(result[user1.Id]); // Personal premium
Assert.False(result[user2.Id]); // No premium
Assert.True(result[user3.Id]); // Organization premium
}
}

View File

@@ -179,4 +179,325 @@ public class UserRepositoryTests
Assert.Equal(CollectionType.SharedCollection, updatedCollection2.Type);
Assert.Equal(user2.Email, updatedCollection2.DefaultUserCollectionEmail);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithPersonalPremium_ReturnsCorrectAccess(
IUserRepository userRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "Premium User",
Email = $"premium+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.True(result.PersonalPremium);
Assert.False(result.OrganizationPremium);
Assert.True(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithOrganizationPremium_ReturnsCorrectAccess(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "Org User",
Email = $"org+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.PersonalPremium);
Assert.True(result.OrganizationPremium);
Assert.True(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithDisabledOrganization_ReturnsNoOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User",
Email = $"user+{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
organization.Enabled = false;
await organizationRepository.ReplaceAsync(organization);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.OrganizationPremium);
Assert.False(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithOrganizationUsersGetPremiumFalse_ReturnsNoOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
organization.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(organization);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.OrganizationPremium);
Assert.False(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithMultipleOrganizations_OneProvidesPremium_ReturnsOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User With Premium Org",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var orgWithPremium = await organizationRepository.CreateTestOrganizationAsync();
await organizationUserRepository.CreateTestOrganizationUserAsync(orgWithPremium, user);
var orgNoPremium = await organizationRepository.CreateTestOrganizationAsync();
orgNoPremium.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(orgNoPremium);
await organizationUserRepository.CreateTestOrganizationUserAsync(orgNoPremium, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.PersonalPremium);
Assert.True(result.OrganizationPremium);
Assert.True(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_WithMultipleOrganizations_NoneProvidePremium_ReturnsNoOrganizationPremium(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var user = await userRepository.CreateAsync(new User
{
Name = "User With No Premium Orgs",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var disabledOrg = await organizationRepository.CreateTestOrganizationAsync();
disabledOrg.Enabled = false;
await organizationRepository.ReplaceAsync(disabledOrg);
await organizationUserRepository.CreateTestOrganizationUserAsync(disabledOrg, user);
var orgNoPremium = await organizationRepository.CreateTestOrganizationAsync();
orgNoPremium.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(orgNoPremium);
await organizationUserRepository.CreateTestOrganizationUserAsync(orgNoPremium, user);
// Act
var result = await userRepository.GetPremiumAccessAsync(user.Id);
// Assert
Assert.NotNull(result);
Assert.False(result.PersonalPremium);
Assert.False(result.OrganizationPremium);
Assert.False(result.HasPremiumAccess);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessAsync_NonExistentUser_ReturnsNull(
IUserRepository userRepository)
{
// Act
var result = await userRepository.GetPremiumAccessAsync(Guid.NewGuid());
// Assert
Assert.Null(result);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessByIdsAsync_MultipleUsers_ReturnsCorrectAccessForEach(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository)
{
// Arrange
var personalPremiumUser = await userRepository.CreateAsync(new User
{
Name = "Personal Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
var orgPremiumUser = await userRepository.CreateAsync(new User
{
Name = "Org Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var bothPremiumUser = await userRepository.CreateAsync(new User
{
Name = "Both Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
var noPremiumUser = await userRepository.CreateAsync(new User
{
Name = "No Premium",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var multiOrgUser = await userRepository.CreateAsync(new User
{
Name = "Multi Org User",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = false
});
var personalPremiumWithDisabledOrg = await userRepository.CreateAsync(new User
{
Name = "Personal Premium With Disabled Org",
Email = $"{Guid.NewGuid()}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
Premium = true
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, orgPremiumUser);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, bothPremiumUser);
await organizationUserRepository.CreateTestOrganizationUserAsync(organization, multiOrgUser);
var orgWithoutPremium = await organizationRepository.CreateTestOrganizationAsync();
orgWithoutPremium.UsersGetPremium = false;
await organizationRepository.ReplaceAsync(orgWithoutPremium);
await organizationUserRepository.CreateTestOrganizationUserAsync(orgWithoutPremium, multiOrgUser);
var disabledOrg = await organizationRepository.CreateTestOrganizationAsync();
disabledOrg.Enabled = false;
await organizationRepository.ReplaceAsync(disabledOrg);
await organizationUserRepository.CreateTestOrganizationUserAsync(disabledOrg, personalPremiumWithDisabledOrg);
// Act
var results = await userRepository.GetPremiumAccessByIdsAsync([
personalPremiumUser.Id,
orgPremiumUser.Id,
bothPremiumUser.Id,
noPremiumUser.Id,
multiOrgUser.Id,
personalPremiumWithDisabledOrg.Id
]);
var resultsList = results.ToList();
// Assert
Assert.Equal(6, resultsList.Count);
var personalResult = resultsList.First(r => r.Id == personalPremiumUser.Id);
Assert.True(personalResult.PersonalPremium);
Assert.False(personalResult.OrganizationPremium);
var orgResult = resultsList.First(r => r.Id == orgPremiumUser.Id);
Assert.False(orgResult.PersonalPremium);
Assert.True(orgResult.OrganizationPremium);
var bothResult = resultsList.First(r => r.Id == bothPremiumUser.Id);
Assert.True(bothResult.PersonalPremium);
Assert.True(bothResult.OrganizationPremium);
var noneResult = resultsList.First(r => r.Id == noPremiumUser.Id);
Assert.False(noneResult.PersonalPremium);
Assert.False(noneResult.OrganizationPremium);
var multiResult = resultsList.First(r => r.Id == multiOrgUser.Id);
Assert.False(multiResult.PersonalPremium);
Assert.True(multiResult.OrganizationPremium);
var personalWithDisabledOrgResult = resultsList.First(r => r.Id == personalPremiumWithDisabledOrg.Id);
Assert.True(personalWithDisabledOrgResult.PersonalPremium);
Assert.False(personalWithDisabledOrgResult.OrganizationPremium);
}
[Theory, DatabaseData]
public async Task GetPremiumAccessByIdsAsync_EmptyList_ReturnsEmptyResult(
IUserRepository userRepository)
{
// Act
var results = await userRepository.GetPremiumAccessByIdsAsync([]);
// Assert
Assert.Empty(results);
}
}

View File

@@ -0,0 +1,60 @@
-- Add UsersGetPremium to IX_Organization_Enabled index to support premium access queries
IF EXISTS (
SELECT * FROM sys.indexes
WHERE name = 'IX_Organization_Enabled'
AND object_id = OBJECT_ID('[dbo].[Organization]')
)
BEGIN
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
INCLUDE ([UseTotp], [UsersGetPremium])
WITH (DROP_EXISTING = ON);
END
ELSE
BEGIN
CREATE NONCLUSTERED INDEX [IX_Organization_Enabled]
ON [dbo].[Organization]([Id] ASC, [Enabled] ASC)
INCLUDE ([UseTotp], [UsersGetPremium]);
END
GO
CREATE OR ALTER VIEW [dbo].[UserPremiumAccessView]
AS
SELECT
U.[Id],
U.[Premium] AS [PersonalPremium],
CAST(
MAX(CASE
WHEN O.[Id] IS NOT NULL THEN 1
ELSE 0
END) AS BIT
) AS [OrganizationPremium]
FROM
[dbo].[User] U
LEFT JOIN
[dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id]
LEFT JOIN
[dbo].[Organization] O ON O.[Id] = OU.[OrganizationId]
AND O.[UsersGetPremium] = 1
AND O.[Enabled] = 1
GROUP BY
U.[Id], U.[Premium];
GO
CREATE OR ALTER PROCEDURE [dbo].[User_ReadPremiumAccessByIds]
@Ids [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
SELECT
UPA.[Id],
UPA.[PersonalPremium],
UPA.[OrganizationPremium]
FROM
[dbo].[UserPremiumAccessView] UPA
WHERE
UPA.[Id] IN (SELECT [Id] FROM @Ids)
END
GO

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace Bit.MySqlMigrations.Migrations;
/// <inheritdoc />
public partial class OrganizationUsersGetPremiumIndex : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
}
}

View File

@@ -274,7 +274,7 @@ namespace Bit.MySqlMigrations.Migrations
b.HasKey("Id");
b.HasIndex("Id", "Enabled")
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp" });
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
b.ToTable("Organization", (string)null);
});

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,37 @@
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace Bit.PostgresMigrations.Migrations;
/// <inheritdoc />
public partial class OrganizationUsersGetPremiumIndex : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization");
migrationBuilder.CreateIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization",
columns: new[] { "Id", "Enabled" })
.Annotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization");
migrationBuilder.CreateIndex(
name: "IX_Organization_Id_Enabled",
table: "Organization",
columns: new[] { "Id", "Enabled" })
.Annotation("Npgsql:IndexInclude", new[] { "UseTotp" });
}
}

View File

@@ -277,7 +277,7 @@ namespace Bit.PostgresMigrations.Migrations
b.HasIndex("Id", "Enabled");
NpgsqlIndexBuilderExtensions.IncludeProperties(b.HasIndex("Id", "Enabled"), new[] { "UseTotp" });
NpgsqlIndexBuilderExtensions.IncludeProperties(b.HasIndex("Id", "Enabled"), new[] { "UseTotp", "UsersGetPremium" });
b.ToTable("Organization", (string)null);
});

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace Bit.SqliteMigrations.Migrations;
/// <inheritdoc />
public partial class OrganizationUsersGetPremiumIndex : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
}
}

View File

@@ -269,7 +269,7 @@ namespace Bit.SqliteMigrations.Migrations
b.HasKey("Id");
b.HasIndex("Id", "Enabled")
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp" });
.HasAnnotation("Npgsql:IndexInclude", new[] { "UseTotp", "UsersGetPremium" });
b.ToTable("Organization", (string)null);
});