mirror of
https://github.com/bitwarden/server
synced 2026-01-04 17:43:53 +00:00
Merge branch 'main' into SM-1571-DisableSMAdsForUsers
This commit is contained in:
@@ -129,11 +129,15 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable
|
||||
/// </summary>
|
||||
public bool SyncSeats { get; set; }
|
||||
|
||||
/// If set to true, user accounts created within the organization are automatically confirmed without requiring additional verification steps.
|
||||
/// </summary>
|
||||
public bool UseAutomaticUserConfirmation { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// If set to true, disables Secrets Manager ads for users in the organization
|
||||
/// </summary>
|
||||
public bool UseDisableSMAdsForUsers { get; set; }
|
||||
|
||||
|
||||
public void SetNewId()
|
||||
{
|
||||
if (Id == default(Guid))
|
||||
|
||||
@@ -28,6 +28,7 @@ public class OrganizationAbility
|
||||
UseRiskInsights = organization.UseRiskInsights;
|
||||
UseOrganizationDomains = organization.UseOrganizationDomains;
|
||||
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
|
||||
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation;
|
||||
UseDisableSMAdsForUsers = organization.UseDisableSMAdsForUsers;
|
||||
}
|
||||
|
||||
@@ -50,5 +51,6 @@ public class OrganizationAbility
|
||||
public bool UseRiskInsights { get; set; }
|
||||
public bool UseOrganizationDomains { get; set; }
|
||||
public bool UseAdminSponsoredFamilies { get; set; }
|
||||
public bool UseAutomaticUserConfirmation { get; set; }
|
||||
public bool UseDisableSMAdsForUsers { get; set; }
|
||||
}
|
||||
|
||||
@@ -66,4 +66,5 @@ public class OrganizationUserOrganizationDetails
|
||||
public bool UseOrganizationDomains { get; set; }
|
||||
public bool UseAdminSponsoredFamilies { get; set; }
|
||||
public bool? IsAdminInitiated { get; set; }
|
||||
public bool UseAutomaticUserConfirmation { get; set; }
|
||||
}
|
||||
|
||||
@@ -51,4 +51,5 @@ public class ProviderUserOrganizationDetails
|
||||
public bool UseOrganizationDomains { get; set; }
|
||||
public bool UseAdminSponsoredFamilies { get; set; }
|
||||
public ProviderType ProviderType { get; set; }
|
||||
public bool UseAutomaticUserConfirmation { get; set; }
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque
|
||||
AuthRequestExpiresAfter = _globalSettings.PasswordlessAuth.AdminRequestExpiration
|
||||
}
|
||||
);
|
||||
processor.Process((Exception e) => _logger.LogError(e.Message));
|
||||
processor.Process((Exception e) => _logger.LogError("Error processing organization auth request: {Message}", e.Message));
|
||||
await processor.Save((IEnumerable<OrganizationAdminAuthRequest> authRequests) => _authRequestRepository.UpdateManyAsync(authRequests));
|
||||
await processor.SendPushNotifications((ar) => _pushNotificationService.PushAuthRequestResponseAsync(ar));
|
||||
await processor.SendApprovalEmailsForProcessedRequests(SendApprovalEmail);
|
||||
@@ -114,7 +114,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque
|
||||
// This should be impossible
|
||||
if (user == null)
|
||||
{
|
||||
_logger.LogError($"User {authRequest.UserId} not found. Trusted device admin approval email not sent.");
|
||||
_logger.LogError("User {UserId} not found. Trusted device admin approval email not sent.", authRequest.UserId);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#nullable enable
|
||||
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
@@ -20,7 +18,7 @@ public class PolicyRequirementQuery(
|
||||
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
|
||||
}
|
||||
|
||||
var policyDetails = await GetPolicyDetails(userId);
|
||||
var policyDetails = await GetPolicyDetails(userId, factory.PolicyType);
|
||||
var filteredPolicies = policyDetails
|
||||
.Where(p => p.PolicyType == factory.PolicyType)
|
||||
.Where(factory.Enforce);
|
||||
@@ -48,8 +46,8 @@ public class PolicyRequirementQuery(
|
||||
return eligibleOrganizationUserIds;
|
||||
}
|
||||
|
||||
private Task<IEnumerable<PolicyDetails>> GetPolicyDetails(Guid userId)
|
||||
=> policyRepository.GetPolicyDetailsByUserId(userId);
|
||||
private async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetails(Guid userId, PolicyType policyType)
|
||||
=> await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType);
|
||||
|
||||
private async Task<IEnumerable<OrganizationPolicyDetails>> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType)
|
||||
=> await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType);
|
||||
|
||||
@@ -13,25 +13,11 @@ public class VNextSavePolicyCommand(
|
||||
IApplicationCacheService applicationCacheService,
|
||||
IEventService eventService,
|
||||
IPolicyRepository policyRepository,
|
||||
IEnumerable<IEnforceDependentPoliciesEvent> policyValidationEventHandlers,
|
||||
IEnumerable<IPolicyUpdateEvent> policyUpdateEventHandlers,
|
||||
TimeProvider timeProvider,
|
||||
IPolicyEventHandlerFactory policyEventHandlerFactory)
|
||||
: IVNextSavePolicyCommand
|
||||
{
|
||||
private readonly IReadOnlyDictionary<PolicyType, IEnforceDependentPoliciesEvent> _policyValidationEvents = MapToDictionary(policyValidationEventHandlers);
|
||||
|
||||
private static Dictionary<PolicyType, IEnforceDependentPoliciesEvent> MapToDictionary(IEnumerable<IEnforceDependentPoliciesEvent> policyValidationEventHandlers)
|
||||
{
|
||||
var policyValidationEventsDict = new Dictionary<PolicyType, IEnforceDependentPoliciesEvent>();
|
||||
foreach (var policyValidationEvent in policyValidationEventHandlers)
|
||||
{
|
||||
if (!policyValidationEventsDict.TryAdd(policyValidationEvent.Type, policyValidationEvent))
|
||||
{
|
||||
throw new Exception($"Duplicate PolicyValidationEvent for {policyValidationEvent.Type} policy.");
|
||||
}
|
||||
}
|
||||
return policyValidationEventsDict;
|
||||
}
|
||||
|
||||
public async Task<Policy> SaveAsync(SavePolicyModel policyRequest)
|
||||
{
|
||||
@@ -112,32 +98,26 @@ public class VNextSavePolicyCommand(
|
||||
Policy? currentPolicy,
|
||||
Dictionary<PolicyType, Policy> savedPoliciesDict)
|
||||
{
|
||||
var result = policyEventHandlerFactory.GetHandler<IEnforceDependentPoliciesEvent>(policyUpdateRequest.Type);
|
||||
var isCurrentlyEnabled = currentPolicy?.Enabled == true;
|
||||
var isBeingEnabled = policyUpdateRequest.Enabled && !isCurrentlyEnabled;
|
||||
var isBeingDisabled = !policyUpdateRequest.Enabled && isCurrentlyEnabled;
|
||||
|
||||
result.Switch(
|
||||
validator =>
|
||||
{
|
||||
var isCurrentlyEnabled = currentPolicy?.Enabled == true;
|
||||
|
||||
switch (policyUpdateRequest.Enabled)
|
||||
{
|
||||
case true when !isCurrentlyEnabled:
|
||||
ValidateEnablingRequirements(validator, savedPoliciesDict);
|
||||
return;
|
||||
case false when isCurrentlyEnabled:
|
||||
ValidateDisablingRequirements(validator, policyUpdateRequest.Type, savedPoliciesDict);
|
||||
break;
|
||||
}
|
||||
},
|
||||
_ => { });
|
||||
if (isBeingEnabled)
|
||||
{
|
||||
ValidateEnablingRequirements(policyUpdateRequest.Type, savedPoliciesDict);
|
||||
}
|
||||
else if (isBeingDisabled)
|
||||
{
|
||||
ValidateDisablingRequirements(policyUpdateRequest.Type, savedPoliciesDict);
|
||||
}
|
||||
}
|
||||
|
||||
private void ValidateDisablingRequirements(
|
||||
IEnforceDependentPoliciesEvent validator,
|
||||
PolicyType policyType,
|
||||
Dictionary<PolicyType, Policy> savedPoliciesDict)
|
||||
{
|
||||
var dependentPolicyTypes = _policyValidationEvents.Values
|
||||
var dependentPolicyTypes = policyUpdateEventHandlers
|
||||
.OfType<IEnforceDependentPoliciesEvent>()
|
||||
.Where(otherValidator => otherValidator.RequiredPolicies.Contains(policyType))
|
||||
.Select(otherValidator => otherValidator.Type)
|
||||
.Where(otherPolicyType => savedPoliciesDict.TryGetValue(otherPolicyType, out var savedPolicy) &&
|
||||
@@ -147,24 +127,31 @@ public class VNextSavePolicyCommand(
|
||||
switch (dependentPolicyTypes)
|
||||
{
|
||||
case { Count: 1 }:
|
||||
throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {validator.Type.GetName()} policy.");
|
||||
throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {policyType.GetName()} policy.");
|
||||
case { Count: > 1 }:
|
||||
throw new BadRequestException($"Turn off all of the policies that require the {validator.Type.GetName()} policy.");
|
||||
throw new BadRequestException($"Turn off all of the policies that require the {policyType.GetName()} policy.");
|
||||
}
|
||||
}
|
||||
|
||||
private static void ValidateEnablingRequirements(
|
||||
IEnforceDependentPoliciesEvent validator,
|
||||
private void ValidateEnablingRequirements(
|
||||
PolicyType policyType,
|
||||
Dictionary<PolicyType, Policy> savedPoliciesDict)
|
||||
{
|
||||
var missingRequiredPolicyTypes = validator.RequiredPolicies
|
||||
.Where(requiredPolicyType => savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true })
|
||||
.ToList();
|
||||
var result = policyEventHandlerFactory.GetHandler<IEnforceDependentPoliciesEvent>(policyType);
|
||||
|
||||
if (missingRequiredPolicyTypes.Count != 0)
|
||||
{
|
||||
throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {validator.Type.GetName()} policy.");
|
||||
}
|
||||
result.Switch(
|
||||
validator =>
|
||||
{
|
||||
var missingRequiredPolicyTypes = validator.RequiredPolicies
|
||||
.Where(requiredPolicyType => savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true })
|
||||
.ToList();
|
||||
|
||||
if (missingRequiredPolicyTypes.Count != 0)
|
||||
{
|
||||
throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {policyType.GetName()} policy.");
|
||||
}
|
||||
},
|
||||
_ => { /* Policy has no required dependencies */ });
|
||||
}
|
||||
|
||||
private async Task ExecutePreUpsertSideEffectAsync(
|
||||
|
||||
@@ -22,8 +22,10 @@ public static class PolicyServiceCollectionExtensions
|
||||
services.AddPolicyValidators();
|
||||
services.AddPolicyRequirements();
|
||||
services.AddPolicySideEffects();
|
||||
services.AddPolicyUpdateEvents();
|
||||
}
|
||||
|
||||
[Obsolete("Use AddPolicyUpdateEvents instead.")]
|
||||
private static void AddPolicyValidators(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<IPolicyValidator, TwoFactorAuthenticationPolicyValidator>();
|
||||
@@ -34,11 +36,23 @@ public static class PolicyServiceCollectionExtensions
|
||||
services.AddScoped<IPolicyValidator, FreeFamiliesForEnterprisePolicyValidator>();
|
||||
}
|
||||
|
||||
[Obsolete("Use AddPolicyUpdateEvents instead.")]
|
||||
private static void AddPolicySideEffects(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<IPostSavePolicySideEffect, OrganizationDataOwnershipPolicyValidator>();
|
||||
}
|
||||
|
||||
private static void AddPolicyUpdateEvents(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<IPolicyUpdateEvent, RequireSsoPolicyValidator>();
|
||||
services.AddScoped<IPolicyUpdateEvent, TwoFactorAuthenticationPolicyValidator>();
|
||||
services.AddScoped<IPolicyUpdateEvent, SingleOrgPolicyValidator>();
|
||||
services.AddScoped<IPolicyUpdateEvent, ResetPasswordPolicyValidator>();
|
||||
services.AddScoped<IPolicyUpdateEvent, MaximumVaultTimeoutPolicyValidator>();
|
||||
services.AddScoped<IPolicyUpdateEvent, FreeFamiliesForEnterprisePolicyValidator>();
|
||||
services.AddScoped<IPolicyUpdateEvent, OrganizationDataOwnershipPolicyValidator>();
|
||||
}
|
||||
|
||||
private static void AddPolicyRequirements(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<IPolicyRequirementFactory<IPolicyRequirement>, DisableSendPolicyRequirementFactory>();
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
|
||||
@@ -12,11 +13,16 @@ public class FreeFamiliesForEnterprisePolicyValidator(
|
||||
IOrganizationSponsorshipRepository organizationSponsorshipRepository,
|
||||
IMailService mailService,
|
||||
IOrganizationRepository organizationRepository)
|
||||
: IPolicyValidator
|
||||
: IPolicyValidator, IOnPolicyPreUpdateEvent
|
||||
{
|
||||
public PolicyType Type => PolicyType.FreeFamiliesSponsorshipPolicy;
|
||||
public IEnumerable<PolicyType> RequiredPolicies => [];
|
||||
|
||||
public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
|
||||
{
|
||||
await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy);
|
||||
}
|
||||
|
||||
public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
|
||||
{
|
||||
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
|
||||
|
||||
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
|
||||
|
||||
public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator
|
||||
public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator, IEnforceDependentPoliciesEvent
|
||||
{
|
||||
public PolicyType Type => PolicyType.MaximumVaultTimeout;
|
||||
public IEnumerable<PolicyType> RequiredPolicies => [PolicyType.SingleOrg];
|
||||
|
||||
@@ -1,24 +1,32 @@
|
||||
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
|
||||
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
|
||||
|
||||
/// <summary>
|
||||
/// Please do not extend or expand this validator. We're currently in the process of refactoring our policy validator pattern.
|
||||
/// This is a stop-gap solution for post-policy-save side effects, but it is not the long-term solution.
|
||||
/// </summary>
|
||||
public class OrganizationDataOwnershipPolicyValidator(
|
||||
IPolicyRepository policyRepository,
|
||||
ICollectionRepository collectionRepository,
|
||||
IEnumerable<IPolicyRequirementFactory<IPolicyRequirement>> factories,
|
||||
IFeatureService featureService)
|
||||
: OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect
|
||||
: OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect, IOnPolicyPostUpdateEvent
|
||||
{
|
||||
public PolicyType Type => PolicyType.OrganizationDataOwnership;
|
||||
|
||||
public async Task ExecutePostUpsertSideEffectAsync(
|
||||
SavePolicyModel policyRequest,
|
||||
Policy postUpsertedPolicyState,
|
||||
Policy? previousPolicyState)
|
||||
{
|
||||
await ExecuteSideEffectsAsync(policyRequest, postUpsertedPolicyState, previousPolicyState);
|
||||
}
|
||||
|
||||
public async Task ExecuteSideEffectsAsync(
|
||||
SavePolicyModel policyRequest,
|
||||
Policy postUpdatedPolicy,
|
||||
@@ -68,5 +76,4 @@ public class OrganizationDataOwnershipPolicyValidator(
|
||||
userOrgIds,
|
||||
defaultCollectionName);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Repositories;
|
||||
|
||||
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
|
||||
|
||||
public class RequireSsoPolicyValidator : IPolicyValidator
|
||||
public class RequireSsoPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent
|
||||
{
|
||||
private readonly ISsoConfigRepository _ssoConfigRepository;
|
||||
|
||||
@@ -20,6 +21,11 @@ public class RequireSsoPolicyValidator : IPolicyValidator
|
||||
public PolicyType Type => PolicyType.RequireSso;
|
||||
public IEnumerable<PolicyType> RequiredPolicies => [PolicyType.SingleOrg];
|
||||
|
||||
public async Task<string> ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
|
||||
{
|
||||
return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy);
|
||||
}
|
||||
|
||||
public async Task<string> ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
|
||||
{
|
||||
if (policyUpdate is not { Enabled: true })
|
||||
|
||||
@@ -4,12 +4,13 @@ using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Repositories;
|
||||
|
||||
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
|
||||
|
||||
public class ResetPasswordPolicyValidator : IPolicyValidator
|
||||
public class ResetPasswordPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent
|
||||
{
|
||||
private readonly ISsoConfigRepository _ssoConfigRepository;
|
||||
public PolicyType Type => PolicyType.ResetPassword;
|
||||
@@ -20,6 +21,11 @@ public class ResetPasswordPolicyValidator : IPolicyValidator
|
||||
_ssoConfigRepository = ssoConfigRepository;
|
||||
}
|
||||
|
||||
public async Task<string> ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
|
||||
{
|
||||
return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy);
|
||||
}
|
||||
|
||||
public async Task<string> ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
|
||||
{
|
||||
if (policyUpdate is not { Enabled: true } ||
|
||||
|
||||
@@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Repositories;
|
||||
using Bit.Core.Context;
|
||||
@@ -17,7 +18,7 @@ using Bit.Core.Services;
|
||||
|
||||
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
|
||||
|
||||
public class SingleOrgPolicyValidator : IPolicyValidator
|
||||
public class SingleOrgPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IOnPolicyPreUpdateEvent
|
||||
{
|
||||
public PolicyType Type => PolicyType.SingleOrg;
|
||||
private const string OrganizationNotFoundErrorMessage = "Organization not found.";
|
||||
@@ -57,6 +58,16 @@ public class SingleOrgPolicyValidator : IPolicyValidator
|
||||
|
||||
public IEnumerable<PolicyType> RequiredPolicies => [];
|
||||
|
||||
public async Task<string> ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
|
||||
{
|
||||
return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy);
|
||||
}
|
||||
|
||||
public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
|
||||
{
|
||||
await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy);
|
||||
}
|
||||
|
||||
public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
|
||||
{
|
||||
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })
|
||||
|
||||
@@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
|
||||
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Enums;
|
||||
@@ -16,7 +17,7 @@ using Bit.Core.Services;
|
||||
|
||||
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
|
||||
|
||||
public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
|
||||
public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator, IOnPolicyPreUpdateEvent
|
||||
{
|
||||
private readonly IOrganizationUserRepository _organizationUserRepository;
|
||||
private readonly IMailService _mailService;
|
||||
@@ -46,6 +47,11 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
|
||||
_revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand;
|
||||
}
|
||||
|
||||
public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
|
||||
{
|
||||
await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy);
|
||||
}
|
||||
|
||||
public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
|
||||
{
|
||||
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })
|
||||
|
||||
@@ -87,4 +87,13 @@ public interface IOrganizationUserRepository : IRepository<OrganizationUser, Gui
|
||||
Task<IEnumerable<OrganizationUserUserDetails>> GetManyDetailsByRoleAsync(Guid organizationId, OrganizationUserType role);
|
||||
|
||||
Task CreateManyAsync(IEnumerable<CreateOrganizationUser> organizationUserCollection);
|
||||
|
||||
/// <summary>
|
||||
/// It will only confirm if the user is in the `Accepted` state.
|
||||
///
|
||||
/// This is an idempotent operation.
|
||||
/// </summary>
|
||||
/// <param name="organizationUser">Accepted OrganizationUser to confirm</param>
|
||||
/// <returns>True, if the user was updated. False, if not performed.</returns>
|
||||
Task<bool> ConfirmOrganizationUserAsync(OrganizationUser organizationUser);
|
||||
}
|
||||
|
||||
@@ -20,17 +20,6 @@ public interface IPolicyRepository : IRepository<Policy, Guid>
|
||||
Task<Policy?> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type);
|
||||
Task<ICollection<Policy>> GetManyByOrganizationIdAsync(Guid organizationId);
|
||||
Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId);
|
||||
/// <summary>
|
||||
/// Gets all PolicyDetails for a user for all policy types.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Each PolicyDetail represents an OrganizationUser and a Policy which *may* be enforced
|
||||
/// against them. It only returns PolicyDetails for policies that are enabled and where the organization's plan
|
||||
/// supports policies. It also excludes "revoked invited" users who are not subject to policy enforcement.
|
||||
/// This is consumed by <see cref="IPolicyRequirementQuery"/> to create requirements for specific policy types.
|
||||
/// You probably do not want to call it directly.
|
||||
/// </remarks>
|
||||
Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId);
|
||||
|
||||
/// <summary>
|
||||
/// Retrieves <see cref="OrganizationPolicyDetails"/> of the specified <paramref name="policyType"/>
|
||||
|
||||
@@ -61,8 +61,9 @@ public static class OrganizationFactory
|
||||
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseOrganizationDomains),
|
||||
UseAdminSponsoredFamilies =
|
||||
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseAdminSponsoredFamilies),
|
||||
UseAutomaticUserConfirmation = claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseAutomaticUserConfirmation),
|
||||
UseDisableSMAdsForUsers =
|
||||
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseDisableSMAdsForUsers),
|
||||
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseDisableSMAdsForUsers),
|
||||
};
|
||||
|
||||
public static Organization Create(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Core.KeyManagement.Models.Response;
|
||||
using Bit.Core.KeyManagement.Models.Api.Response;
|
||||
using Bit.Core.Models.Api;
|
||||
|
||||
namespace Bit.Core.Auth.Models.Api.Response;
|
||||
|
||||
@@ -64,10 +64,12 @@ public static class InvoiceExtensions
|
||||
}
|
||||
}
|
||||
|
||||
var tax = invoice.TotalTaxes?.Sum(invoiceTotalTax => invoiceTotalTax.Amount) ?? 0;
|
||||
|
||||
// Add fallback tax from invoice-level tax if present and not already included
|
||||
if (invoice.Tax.HasValue && invoice.Tax.Value > 0)
|
||||
if (tax > 0)
|
||||
{
|
||||
var taxAmount = invoice.Tax.Value / 100m;
|
||||
var taxAmount = tax / 100m;
|
||||
items.Add($"1 × Tax (at ${taxAmount:F2} / month)");
|
||||
}
|
||||
|
||||
|
||||
25
src/Core/Billing/Extensions/SubscriptionExtensions.cs
Normal file
25
src/Core/Billing/Extensions/SubscriptionExtensions.cs
Normal file
@@ -0,0 +1,25 @@
|
||||
using Stripe;
|
||||
|
||||
namespace Bit.Core.Billing.Extensions;
|
||||
|
||||
public static class SubscriptionExtensions
|
||||
{
|
||||
/*
|
||||
* For the time being, this is the simplest migration approach from v45 to v48 as
|
||||
* we do not support multi-cadence subscriptions. Each subscription item should be on the
|
||||
* same billing cycle. If this changes, we'll need a significantly more robust approach.
|
||||
*
|
||||
* Because we can't guarantee a subscription will have items, this has to be nullable.
|
||||
*/
|
||||
public static (DateTime? Start, DateTime? End)? GetCurrentPeriod(this Subscription subscription)
|
||||
{
|
||||
var item = subscription.Items?.FirstOrDefault();
|
||||
return item is null ? null : (item.CurrentPeriodStart, item.CurrentPeriodEnd);
|
||||
}
|
||||
|
||||
public static DateTime? GetCurrentPeriodStart(this Subscription subscription) =>
|
||||
subscription.Items?.FirstOrDefault()?.CurrentPeriodStart;
|
||||
|
||||
public static DateTime? GetCurrentPeriodEnd(this Subscription subscription) =>
|
||||
subscription.Items?.FirstOrDefault()?.CurrentPeriodEnd;
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
using Stripe;
|
||||
|
||||
namespace Bit.Core.Billing.Extensions;
|
||||
|
||||
public static class UpcomingInvoiceOptionsExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Attempts to enable automatic tax for given upcoming invoice options.
|
||||
/// </summary>
|
||||
/// <param name="options"></param>
|
||||
/// <param name="customer">The existing customer to which the upcoming invoice belongs.</param>
|
||||
/// <param name="subscription">The existing subscription to which the upcoming invoice belongs.</param>
|
||||
/// <returns>Returns true when successful, false when conditions are not met.</returns>
|
||||
public static bool EnableAutomaticTax(
|
||||
this UpcomingInvoiceOptions options,
|
||||
Customer customer,
|
||||
Subscription subscription)
|
||||
{
|
||||
if (subscription != null && subscription.AutomaticTax.Enabled)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// We might only need to check the automatic tax status.
|
||||
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
|
||||
options.SubscriptionDefaultTaxRates = [];
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -43,6 +43,7 @@ public static class OrganizationLicenseConstants
|
||||
public const string Trial = nameof(Trial);
|
||||
public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies);
|
||||
public const string UseOrganizationDomains = nameof(UseOrganizationDomains);
|
||||
public const string UseAutomaticUserConfirmation = nameof(UseAutomaticUserConfirmation);
|
||||
public const string UseDisableSMAdsForUsers = nameof(UseDisableSMAdsForUsers);
|
||||
}
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory<Organizati
|
||||
new(nameof(OrganizationLicenseConstants.Trial), trial.ToString()),
|
||||
new(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString()),
|
||||
new(nameof(OrganizationLicenseConstants.UseOrganizationDomains), entity.UseOrganizationDomains.ToString()),
|
||||
new(nameof(OrganizationLicenseConstants.UseAutomaticUserConfirmation), entity.UseAutomaticUserConfirmation.ToString()),
|
||||
new(nameof(OrganizationLicenseConstants.UseDisableSMAdsForUsers), entity.UseDisableSMAdsForUsers.ToString()),
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Stripe;
|
||||
@@ -46,7 +47,7 @@ public class BillingHistoryInfo
|
||||
Url = inv.HostedInvoiceUrl;
|
||||
PdfUrl = inv.InvoicePdf;
|
||||
Number = inv.Number;
|
||||
Paid = inv.Paid;
|
||||
Paid = inv.Status == StripeConstants.InvoiceStatus.Paid;
|
||||
Amount = inv.Total / 100M;
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,8 @@ public abstract record Plan
|
||||
public SecretsManagerPlanFeatures SecretsManager { get; protected init; }
|
||||
public bool SupportsSecretsManager => SecretsManager != null;
|
||||
|
||||
public bool AutomaticUserConfirmation { get; init; }
|
||||
|
||||
public bool HasNonSeatBasedPasswordManagerPlan() =>
|
||||
PasswordManager is { StripePlanId: not null and not "", StripeSeatPlanId: null or "" };
|
||||
|
||||
|
||||
@@ -75,7 +75,13 @@ public class PreviewOrganizationTaxCommand(
|
||||
Quantity = purchase.SecretsManager.Seats
|
||||
}
|
||||
]);
|
||||
options.Coupon = CouponIDs.SecretsManagerStandalone;
|
||||
options.Discounts =
|
||||
[
|
||||
new InvoiceDiscountOptions
|
||||
{
|
||||
Coupon = CouponIDs.SecretsManagerStandalone
|
||||
}
|
||||
];
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -180,7 +186,10 @@ public class PreviewOrganizationTaxCommand(
|
||||
|
||||
if (subscription.Customer.Discount != null)
|
||||
{
|
||||
options.Coupon = subscription.Customer.Discount.Coupon.Id;
|
||||
options.Discounts =
|
||||
[
|
||||
new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id }
|
||||
];
|
||||
}
|
||||
|
||||
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
|
||||
@@ -277,7 +286,10 @@ public class PreviewOrganizationTaxCommand(
|
||||
|
||||
if (subscription.Customer.Discount != null)
|
||||
{
|
||||
options.Coupon = subscription.Customer.Discount.Coupon.Id;
|
||||
options.Discounts =
|
||||
[
|
||||
new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id }
|
||||
];
|
||||
}
|
||||
|
||||
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
|
||||
@@ -329,7 +341,7 @@ public class PreviewOrganizationTaxCommand(
|
||||
});
|
||||
|
||||
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
|
||||
Convert.ToDecimal(invoice.Tax) / 100,
|
||||
Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100,
|
||||
Convert.ToDecimal(invoice.Total) / 100);
|
||||
|
||||
private static InvoiceCreatePreviewOptions GetBaseOptions(
|
||||
|
||||
@@ -153,6 +153,7 @@ public class OrganizationLicense : ILicense
|
||||
public LicenseType? LicenseType { get; set; }
|
||||
public bool UseOrganizationDomains { get; set; }
|
||||
public bool UseAdminSponsoredFamilies { get; set; }
|
||||
public bool UseAutomaticUserConfirmation { get; set; }
|
||||
public bool UseDisableSMAdsForUsers { get; set; }
|
||||
public string Hash { get; set; }
|
||||
public string Signature { get; set; }
|
||||
@@ -228,6 +229,7 @@ public class OrganizationLicense : ILicense
|
||||
!p.Name.Equals(nameof(UseRiskInsights)) &&
|
||||
!p.Name.Equals(nameof(UseAdminSponsoredFamilies)) &&
|
||||
!p.Name.Equals(nameof(UseOrganizationDomains)) &&
|
||||
!p.Name.Equals(nameof(UseAutomaticUserConfirmation))) &&
|
||||
!p.Name.Equals(nameof(UseDisableSMAdsForUsers)))
|
||||
.OrderBy(p => p.Name)
|
||||
.Select(p => $"{p.Name}:{Core.Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}")
|
||||
@@ -423,6 +425,7 @@ public class OrganizationLicense : ILicense
|
||||
var smServiceAccounts = claimsPrincipal.GetValue<int?>(nameof(SmServiceAccounts));
|
||||
var useAdminSponsoredFamilies = claimsPrincipal.GetValue<bool>(nameof(UseAdminSponsoredFamilies));
|
||||
var useOrganizationDomains = claimsPrincipal.GetValue<bool>(nameof(UseOrganizationDomains));
|
||||
var useAutomaticUserConfirmation = claimsPrincipal.GetValue<bool>(nameof(UseAutomaticUserConfirmation));
|
||||
var UseDisableSMAdsForUsers = claimsPrincipal.GetValue<bool>(nameof(UseDisableSMAdsForUsers));
|
||||
|
||||
return issued <= DateTime.UtcNow &&
|
||||
@@ -454,6 +457,7 @@ public class OrganizationLicense : ILicense
|
||||
smServiceAccounts == organization.SmServiceAccounts &&
|
||||
useAdminSponsoredFamilies == organization.UseAdminSponsoredFamilies &&
|
||||
useOrganizationDomains == organization.UseOrganizationDomains &&
|
||||
useAutomaticUserConfirmation == organization.UseAutomaticUserConfirmation;
|
||||
UseDisableSMAdsForUsers == organization.UseDisableSMAdsForUsers;
|
||||
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace Bit.Core.Billing.Organizations.Models;
|
||||
|
||||
public class OrganizationSale
|
||||
{
|
||||
private OrganizationSale() { }
|
||||
internal OrganizationSale() { }
|
||||
|
||||
public void Deconstruct(
|
||||
out Organization organization,
|
||||
|
||||
@@ -162,17 +162,23 @@ public class GetOrganizationWarningsQuery(
|
||||
if (subscription is
|
||||
{
|
||||
Status: SubscriptionStatus.Trialing or SubscriptionStatus.Active,
|
||||
LatestInvoice: null or { Status: InvoiceStatus.Paid }
|
||||
} && (subscription.CurrentPeriodEnd - now).TotalDays <= 14)
|
||||
LatestInvoice: null or { Status: InvoiceStatus.Paid },
|
||||
Items.Data.Count: > 0
|
||||
})
|
||||
{
|
||||
return new ResellerRenewalWarning
|
||||
var currentPeriodEnd = subscription.GetCurrentPeriodEnd();
|
||||
|
||||
if (currentPeriodEnd != null && (currentPeriodEnd.Value - now).TotalDays <= 14)
|
||||
{
|
||||
Type = "upcoming",
|
||||
Upcoming = new ResellerRenewalWarning.UpcomingRenewal
|
||||
return new ResellerRenewalWarning
|
||||
{
|
||||
RenewalDate = subscription.CurrentPeriodEnd
|
||||
}
|
||||
};
|
||||
Type = "upcoming",
|
||||
Upcoming = new ResellerRenewalWarning.UpcomingRenewal
|
||||
{
|
||||
RenewalDate = currentPeriodEnd.Value
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (subscription is
|
||||
|
||||
@@ -45,12 +45,12 @@ public class OrganizationBillingService(
|
||||
? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType)
|
||||
: await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup);
|
||||
|
||||
var subscription = await CreateSubscriptionAsync(organization, customer, subscriptionSetup);
|
||||
var subscription = await CreateSubscriptionAsync(organization, customer, subscriptionSetup, customerSetup?.Coupon);
|
||||
|
||||
if (subscription.Status is StripeConstants.SubscriptionStatus.Trialing or StripeConstants.SubscriptionStatus.Active)
|
||||
{
|
||||
organization.Enabled = true;
|
||||
organization.ExpirationDate = subscription.CurrentPeriodEnd;
|
||||
organization.ExpirationDate = subscription.GetCurrentPeriodEnd();
|
||||
await organizationRepository.ReplaceAsync(organization);
|
||||
}
|
||||
}
|
||||
@@ -187,7 +187,6 @@ public class OrganizationBillingService(
|
||||
|
||||
var customerCreateOptions = new CustomerCreateOptions
|
||||
{
|
||||
Coupon = customerSetup.Coupon,
|
||||
Description = organization.DisplayBusinessName(),
|
||||
Email = organization.BillingEmail,
|
||||
Expand = ["tax", "tax_ids"],
|
||||
@@ -273,7 +272,7 @@ public class OrganizationBillingService(
|
||||
|
||||
customerCreateOptions.TaxIdData =
|
||||
[
|
||||
new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId }
|
||||
new CustomerTaxIdDataOptions { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId }
|
||||
];
|
||||
|
||||
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
|
||||
@@ -381,7 +380,8 @@ public class OrganizationBillingService(
|
||||
private async Task<Subscription> CreateSubscriptionAsync(
|
||||
Organization organization,
|
||||
Customer customer,
|
||||
SubscriptionSetup subscriptionSetup)
|
||||
SubscriptionSetup subscriptionSetup,
|
||||
string? coupon)
|
||||
{
|
||||
var plan = await pricingClient.GetPlanOrThrow(subscriptionSetup.PlanType);
|
||||
|
||||
@@ -444,6 +444,7 @@ public class OrganizationBillingService(
|
||||
{
|
||||
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
|
||||
Customer = customer.Id,
|
||||
Discounts = !string.IsNullOrEmpty(coupon) ? [new SubscriptionDiscountOptions { Coupon = coupon }] : null,
|
||||
Items = subscriptionItemOptionsList,
|
||||
Metadata = new Dictionary<string, string>
|
||||
{
|
||||
@@ -459,8 +460,9 @@ public class OrganizationBillingService(
|
||||
|
||||
var hasPaymentMethod = await hasPaymentMethodQuery.Run(organization);
|
||||
|
||||
// Only set trial_settings.end_behavior.missing_payment_method to "cancel" if there is no payment method
|
||||
if (!hasPaymentMethod)
|
||||
// Only set trial_settings.end_behavior.missing_payment_method to "cancel"
|
||||
// if there is no payment method AND there's an actual trial period
|
||||
if (!hasPaymentMethod && subscriptionCreateOptions.TrialPeriodDays > 0)
|
||||
{
|
||||
subscriptionCreateOptions.TrialSettings = new SubscriptionTrialSettingsOptions
|
||||
{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
using Bit.Core.Billing.Caches;
|
||||
using Bit.Core.Billing.Commands;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Payment.Models;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Entities;
|
||||
@@ -87,7 +88,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand(
|
||||
when subscription.Status == StripeConstants.SubscriptionStatus.Active:
|
||||
{
|
||||
user.Premium = true;
|
||||
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
|
||||
user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +60,6 @@ public class PreviewPremiumTaxCommand(
|
||||
});
|
||||
|
||||
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
|
||||
Convert.ToDecimal(invoice.Tax) / 100,
|
||||
Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100,
|
||||
Convert.ToDecimal(invoice.Total) / 100);
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Models;
|
||||
|
||||
public enum ClientMigrationProgress
|
||||
{
|
||||
Started = 1,
|
||||
MigrationRecordCreated = 2,
|
||||
SubscriptionEnded = 3,
|
||||
Completed = 4,
|
||||
|
||||
Reversing = 5,
|
||||
ResetOrganization = 6,
|
||||
RecreatedSubscription = 7,
|
||||
RemovedMigrationRecord = 8,
|
||||
Reversed = 9
|
||||
}
|
||||
|
||||
public class ClientMigrationTracker
|
||||
{
|
||||
public Guid ProviderId { get; set; }
|
||||
public Guid OrganizationId { get; set; }
|
||||
public string OrganizationName { get; set; }
|
||||
public ClientMigrationProgress Progress { get; set; } = ClientMigrationProgress.Started;
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
using Bit.Core.Billing.Providers.Entities;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Models;
|
||||
|
||||
public class ProviderMigrationResult
|
||||
{
|
||||
public Guid ProviderId { get; set; }
|
||||
public string ProviderName { get; set; }
|
||||
public string Result { get; set; }
|
||||
public List<ClientMigrationResult> Clients { get; set; }
|
||||
}
|
||||
|
||||
public class ClientMigrationResult
|
||||
{
|
||||
public Guid OrganizationId { get; set; }
|
||||
public string OrganizationName { get; set; }
|
||||
public string Result { get; set; }
|
||||
public ClientPreviousState PreviousState { get; set; }
|
||||
}
|
||||
|
||||
public class ClientPreviousState
|
||||
{
|
||||
public ClientPreviousState() { }
|
||||
|
||||
public ClientPreviousState(ClientOrganizationMigrationRecord migrationRecord)
|
||||
{
|
||||
PlanType = migrationRecord.PlanType.ToString();
|
||||
Seats = migrationRecord.Seats;
|
||||
MaxStorageGb = migrationRecord.MaxStorageGb;
|
||||
GatewayCustomerId = migrationRecord.GatewayCustomerId;
|
||||
GatewaySubscriptionId = migrationRecord.GatewaySubscriptionId;
|
||||
ExpirationDate = migrationRecord.ExpirationDate;
|
||||
MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats;
|
||||
Status = migrationRecord.Status.ToString();
|
||||
}
|
||||
|
||||
public string PlanType { get; set; }
|
||||
public int Seats { get; set; }
|
||||
public short? MaxStorageGb { get; set; }
|
||||
public string GatewayCustomerId { get; set; } = null!;
|
||||
public string GatewaySubscriptionId { get; set; } = null!;
|
||||
public DateTime? ExpirationDate { get; set; }
|
||||
public int? MaxAutoscaleSeats { get; set; }
|
||||
public string Status { get; set; }
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Models;
|
||||
|
||||
public enum ProviderMigrationProgress
|
||||
{
|
||||
Started = 1,
|
||||
NoClients = 2,
|
||||
ClientsMigrated = 3,
|
||||
TeamsPlanConfigured = 4,
|
||||
EnterprisePlanConfigured = 5,
|
||||
CustomerSetup = 6,
|
||||
SubscriptionSetup = 7,
|
||||
CreditApplied = 8,
|
||||
Completed = 9,
|
||||
}
|
||||
|
||||
public class ProviderMigrationTracker
|
||||
{
|
||||
public Guid ProviderId { get; set; }
|
||||
public string ProviderName { get; set; }
|
||||
public List<Guid> OrganizationIds { get; set; }
|
||||
public ProviderMigrationProgress Progress { get; set; } = ProviderMigrationProgress.Started;
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
using Bit.Core.Billing.Providers.Migration.Services;
|
||||
using Bit.Core.Billing.Providers.Migration.Services.Implementations;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration;
|
||||
|
||||
public static class ServiceCollectionExtensions
|
||||
{
|
||||
public static void AddProviderMigration(this IServiceCollection services)
|
||||
{
|
||||
services.AddTransient<IMigrationTrackerCache, MigrationTrackerDistributedCache>();
|
||||
services.AddTransient<IOrganizationMigrator, OrganizationMigrator>();
|
||||
services.AddTransient<IProviderMigrator, ProviderMigrator>();
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.Billing.Providers.Migration.Models;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Services;
|
||||
|
||||
public interface IMigrationTrackerCache
|
||||
{
|
||||
Task StartTracker(Provider provider);
|
||||
Task SetOrganizationIds(Guid providerId, IEnumerable<Guid> organizationIds);
|
||||
Task<ProviderMigrationTracker> GetTracker(Guid providerId);
|
||||
Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status);
|
||||
|
||||
Task StartTracker(Guid providerId, Organization organization);
|
||||
Task<ClientMigrationTracker> GetTracker(Guid providerId, Guid organizationId);
|
||||
Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status);
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Services;
|
||||
|
||||
public interface IOrganizationMigrator
|
||||
{
|
||||
Task Migrate(Guid providerId, Organization organization);
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
using Bit.Core.Billing.Providers.Migration.Models;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Services;
|
||||
|
||||
public interface IProviderMigrator
|
||||
{
|
||||
Task Migrate(Guid providerId);
|
||||
|
||||
Task<ProviderMigrationResult> GetResult(Guid providerId);
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
using System.Text.Json;
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.Billing.Providers.Migration.Models;
|
||||
using Microsoft.Extensions.Caching.Distributed;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Services.Implementations;
|
||||
|
||||
public class MigrationTrackerDistributedCache(
|
||||
[FromKeyedServices("persistent")]
|
||||
IDistributedCache distributedCache) : IMigrationTrackerCache
|
||||
{
|
||||
public async Task StartTracker(Provider provider) =>
|
||||
await SetAsync(new ProviderMigrationTracker
|
||||
{
|
||||
ProviderId = provider.Id,
|
||||
ProviderName = provider.Name
|
||||
});
|
||||
|
||||
public async Task SetOrganizationIds(Guid providerId, IEnumerable<Guid> organizationIds)
|
||||
{
|
||||
var tracker = await GetAsync(providerId);
|
||||
|
||||
tracker.OrganizationIds = organizationIds.ToList();
|
||||
|
||||
await SetAsync(tracker);
|
||||
}
|
||||
|
||||
public Task<ProviderMigrationTracker> GetTracker(Guid providerId) => GetAsync(providerId);
|
||||
|
||||
public async Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status)
|
||||
{
|
||||
var tracker = await GetAsync(providerId);
|
||||
|
||||
tracker.Progress = status;
|
||||
|
||||
await SetAsync(tracker);
|
||||
}
|
||||
|
||||
public async Task StartTracker(Guid providerId, Organization organization) =>
|
||||
await SetAsync(new ClientMigrationTracker
|
||||
{
|
||||
ProviderId = providerId,
|
||||
OrganizationId = organization.Id,
|
||||
OrganizationName = organization.Name
|
||||
});
|
||||
|
||||
public Task<ClientMigrationTracker> GetTracker(Guid providerId, Guid organizationId) =>
|
||||
GetAsync(providerId, organizationId);
|
||||
|
||||
public async Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status)
|
||||
{
|
||||
var tracker = await GetAsync(providerId, organizationId);
|
||||
|
||||
tracker.Progress = status;
|
||||
|
||||
await SetAsync(tracker);
|
||||
}
|
||||
|
||||
private static string GetProviderCacheKey(Guid providerId) => $"provider_{providerId}_migration";
|
||||
|
||||
private static string GetClientCacheKey(Guid providerId, Guid clientId) =>
|
||||
$"provider_{providerId}_client_{clientId}_migration";
|
||||
|
||||
private async Task<ProviderMigrationTracker> GetAsync(Guid providerId)
|
||||
{
|
||||
var cacheKey = GetProviderCacheKey(providerId);
|
||||
|
||||
var json = await distributedCache.GetStringAsync(cacheKey);
|
||||
|
||||
return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize<ProviderMigrationTracker>(json);
|
||||
}
|
||||
|
||||
private async Task<ClientMigrationTracker> GetAsync(Guid providerId, Guid organizationId)
|
||||
{
|
||||
var cacheKey = GetClientCacheKey(providerId, organizationId);
|
||||
|
||||
var json = await distributedCache.GetStringAsync(cacheKey);
|
||||
|
||||
return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize<ClientMigrationTracker>(json);
|
||||
}
|
||||
|
||||
private async Task SetAsync(ProviderMigrationTracker tracker)
|
||||
{
|
||||
var cacheKey = GetProviderCacheKey(tracker.ProviderId);
|
||||
|
||||
var json = JsonSerializer.Serialize(tracker);
|
||||
|
||||
await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions
|
||||
{
|
||||
SlidingExpiration = TimeSpan.FromMinutes(30)
|
||||
});
|
||||
}
|
||||
|
||||
private async Task SetAsync(ClientMigrationTracker tracker)
|
||||
{
|
||||
var cacheKey = GetClientCacheKey(tracker.ProviderId, tracker.OrganizationId);
|
||||
|
||||
var json = JsonSerializer.Serialize(tracker);
|
||||
|
||||
await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions
|
||||
{
|
||||
SlidingExpiration = TimeSpan.FromMinutes(30)
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,331 +0,0 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Providers.Entities;
|
||||
using Bit.Core.Billing.Providers.Migration.Models;
|
||||
using Bit.Core.Billing.Providers.Repositories;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Stripe;
|
||||
using Plan = Bit.Core.Models.StaticStore.Plan;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Services.Implementations;
|
||||
|
||||
public class OrganizationMigrator(
|
||||
IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository,
|
||||
ILogger<OrganizationMigrator> logger,
|
||||
IMigrationTrackerCache migrationTrackerCache,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IPricingClient pricingClient,
|
||||
IStripeAdapter stripeAdapter) : IOrganizationMigrator
|
||||
{
|
||||
private const string _cancellationComment = "Cancelled as part of provider migration to Consolidated Billing";
|
||||
|
||||
public async Task Migrate(Guid providerId, Organization organization)
|
||||
{
|
||||
logger.LogInformation("CB: Starting migration for organization ({OrganizationID})", organization.Id);
|
||||
|
||||
await migrationTrackerCache.StartTracker(providerId, organization);
|
||||
|
||||
await CreateMigrationRecordAsync(providerId, organization);
|
||||
|
||||
await CancelSubscriptionAsync(providerId, organization);
|
||||
|
||||
await UpdateOrganizationAsync(providerId, organization);
|
||||
}
|
||||
|
||||
#region Steps
|
||||
|
||||
private async Task CreateMigrationRecordAsync(Guid providerId, Organization organization)
|
||||
{
|
||||
logger.LogInformation("CB: Creating ClientOrganizationMigrationRecord for organization ({OrganizationID})", organization.Id);
|
||||
|
||||
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
|
||||
|
||||
if (migrationRecord != null)
|
||||
{
|
||||
logger.LogInformation(
|
||||
"CB: ClientOrganizationMigrationRecord already exists for organization ({OrganizationID}), deleting record",
|
||||
organization.Id);
|
||||
|
||||
await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord);
|
||||
}
|
||||
|
||||
await clientOrganizationMigrationRecordRepository.CreateAsync(new ClientOrganizationMigrationRecord
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
ProviderId = providerId,
|
||||
PlanType = organization.PlanType,
|
||||
Seats = organization.Seats ?? 0,
|
||||
MaxStorageGb = organization.MaxStorageGb,
|
||||
GatewayCustomerId = organization.GatewayCustomerId!,
|
||||
GatewaySubscriptionId = organization.GatewaySubscriptionId!,
|
||||
ExpirationDate = organization.ExpirationDate,
|
||||
MaxAutoscaleSeats = organization.MaxAutoscaleSeats,
|
||||
Status = organization.Status
|
||||
});
|
||||
|
||||
logger.LogInformation("CB: Created migration record for organization ({OrganizationID})", organization.Id);
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
|
||||
ClientMigrationProgress.MigrationRecordCreated);
|
||||
}
|
||||
|
||||
private async Task CancelSubscriptionAsync(Guid providerId, Organization organization)
|
||||
{
|
||||
logger.LogInformation("CB: Cancelling subscription for organization ({OrganizationID})", organization.Id);
|
||||
|
||||
var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId);
|
||||
|
||||
if (subscription is
|
||||
{
|
||||
Status:
|
||||
StripeConstants.SubscriptionStatus.Active or
|
||||
StripeConstants.SubscriptionStatus.PastDue or
|
||||
StripeConstants.SubscriptionStatus.Trialing
|
||||
})
|
||||
{
|
||||
await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
|
||||
new SubscriptionUpdateOptions { CancelAtPeriodEnd = false });
|
||||
|
||||
subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId,
|
||||
new SubscriptionCancelOptions
|
||||
{
|
||||
CancellationDetails = new SubscriptionCancellationDetailsOptions
|
||||
{
|
||||
Comment = _cancellationComment
|
||||
},
|
||||
InvoiceNow = true,
|
||||
Prorate = true,
|
||||
Expand = ["latest_invoice", "test_clock"]
|
||||
});
|
||||
|
||||
logger.LogInformation("CB: Cancelled subscription for organization ({OrganizationID})", organization.Id);
|
||||
|
||||
var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow;
|
||||
|
||||
var trialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now;
|
||||
|
||||
if (!trialing && subscription is { Status: StripeConstants.SubscriptionStatus.Canceled, CancellationDetails.Comment: _cancellationComment })
|
||||
{
|
||||
var latestInvoice = subscription.LatestInvoice;
|
||||
|
||||
if (latestInvoice.Status == "draft")
|
||||
{
|
||||
await stripeAdapter.InvoiceFinalizeInvoiceAsync(latestInvoice.Id,
|
||||
new InvoiceFinalizeOptions { AutoAdvance = true });
|
||||
|
||||
logger.LogInformation("CB: Finalized prorated invoice for organization ({OrganizationID})", organization.Id);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation(
|
||||
"CB: Did not need to cancel subscription for organization ({OrganizationID}) as it was inactive",
|
||||
organization.Id);
|
||||
}
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
|
||||
ClientMigrationProgress.SubscriptionEnded);
|
||||
}
|
||||
|
||||
private async Task UpdateOrganizationAsync(Guid providerId, Organization organization)
|
||||
{
|
||||
logger.LogInformation("CB: Bringing organization ({OrganizationID}) under provider management",
|
||||
organization.Id);
|
||||
|
||||
var plan = await pricingClient.GetPlanOrThrow(organization.Plan.Contains("Teams") ? PlanType.TeamsMonthly : PlanType.EnterpriseMonthly);
|
||||
|
||||
ResetOrganizationPlan(organization, plan);
|
||||
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
|
||||
organization.GatewaySubscriptionId = null;
|
||||
organization.ExpirationDate = null;
|
||||
organization.MaxAutoscaleSeats = null;
|
||||
organization.Status = OrganizationStatusType.Managed;
|
||||
|
||||
await organizationRepository.ReplaceAsync(organization);
|
||||
|
||||
logger.LogInformation("CB: Brought organization ({OrganizationID}) under provider management",
|
||||
organization.Id);
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
|
||||
ClientMigrationProgress.Completed);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Reverse
|
||||
|
||||
private async Task RemoveMigrationRecordAsync(Guid providerId, Organization organization)
|
||||
{
|
||||
logger.LogInformation("CB: Removing migration record for organization ({OrganizationID})", organization.Id);
|
||||
|
||||
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
|
||||
|
||||
if (migrationRecord != null)
|
||||
{
|
||||
await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord);
|
||||
|
||||
logger.LogInformation(
|
||||
"CB: Removed migration record for organization ({OrganizationID})",
|
||||
organization.Id);
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation("CB: Did not remove migration record for organization ({OrganizationID}) as it does not exist", organization.Id);
|
||||
}
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, ClientMigrationProgress.Reversed);
|
||||
}
|
||||
|
||||
private async Task RecreateSubscriptionAsync(Guid providerId, Organization organization)
|
||||
{
|
||||
logger.LogInformation("CB: Recreating subscription for organization ({OrganizationID})", organization.Id);
|
||||
|
||||
if (!string.IsNullOrEmpty(organization.GatewaySubscriptionId))
|
||||
{
|
||||
if (string.IsNullOrEmpty(organization.GatewayCustomerId))
|
||||
{
|
||||
logger.LogError(
|
||||
"CB: Cannot recreate subscription for organization ({OrganizationID}) as it does not have a Stripe customer",
|
||||
organization.Id);
|
||||
|
||||
throw new Exception();
|
||||
}
|
||||
|
||||
var customer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId,
|
||||
new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] });
|
||||
|
||||
var collectionMethod =
|
||||
customer.DefaultSource != null ||
|
||||
customer.InvoiceSettings?.DefaultPaymentMethod != null ||
|
||||
customer.Metadata.ContainsKey(Utilities.BraintreeCustomerIdKey)
|
||||
? StripeConstants.CollectionMethod.ChargeAutomatically
|
||||
: StripeConstants.CollectionMethod.SendInvoice;
|
||||
|
||||
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
|
||||
|
||||
var items = new List<SubscriptionItemOptions>
|
||||
{
|
||||
new ()
|
||||
{
|
||||
Price = plan.PasswordManager.StripeSeatPlanId,
|
||||
Quantity = organization.Seats
|
||||
}
|
||||
};
|
||||
|
||||
if (organization.MaxStorageGb.HasValue && plan.PasswordManager.BaseStorageGb.HasValue && organization.MaxStorageGb.Value > plan.PasswordManager.BaseStorageGb.Value)
|
||||
{
|
||||
var additionalStorage = organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb.Value;
|
||||
|
||||
items.Add(new SubscriptionItemOptions
|
||||
{
|
||||
Price = plan.PasswordManager.StripeStoragePlanId,
|
||||
Quantity = additionalStorage
|
||||
});
|
||||
}
|
||||
|
||||
var subscriptionCreateOptions = new SubscriptionCreateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true
|
||||
},
|
||||
Customer = customer.Id,
|
||||
CollectionMethod = collectionMethod,
|
||||
DaysUntilDue = collectionMethod == StripeConstants.CollectionMethod.SendInvoice ? 30 : null,
|
||||
Items = items,
|
||||
Metadata = new Dictionary<string, string>
|
||||
{
|
||||
[organization.GatewayIdField()] = organization.Id.ToString()
|
||||
},
|
||||
OffSession = true,
|
||||
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
|
||||
TrialPeriodDays = plan.TrialPeriodDays
|
||||
};
|
||||
|
||||
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
|
||||
|
||||
organization.GatewaySubscriptionId = subscription.Id;
|
||||
|
||||
await organizationRepository.ReplaceAsync(organization);
|
||||
|
||||
logger.LogInformation("CB: Recreated subscription for organization ({OrganizationID})", organization.Id);
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation(
|
||||
"CB: Did not recreate subscription for organization ({OrganizationID}) as it already exists",
|
||||
organization.Id);
|
||||
}
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
|
||||
ClientMigrationProgress.RecreatedSubscription);
|
||||
}
|
||||
|
||||
private async Task ReverseOrganizationUpdateAsync(Guid providerId, Organization organization)
|
||||
{
|
||||
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
|
||||
|
||||
if (migrationRecord == null)
|
||||
{
|
||||
logger.LogError(
|
||||
"CB: Cannot reverse migration for organization ({OrganizationID}) as it does not have a migration record",
|
||||
organization.Id);
|
||||
|
||||
throw new Exception();
|
||||
}
|
||||
|
||||
var plan = await pricingClient.GetPlanOrThrow(migrationRecord.PlanType);
|
||||
|
||||
ResetOrganizationPlan(organization, plan);
|
||||
organization.MaxStorageGb = migrationRecord.MaxStorageGb;
|
||||
organization.ExpirationDate = migrationRecord.ExpirationDate;
|
||||
organization.MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats;
|
||||
organization.Status = migrationRecord.Status;
|
||||
|
||||
await organizationRepository.ReplaceAsync(organization);
|
||||
|
||||
logger.LogInformation("CB: Reversed organization ({OrganizationID}) updates",
|
||||
organization.Id);
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
|
||||
ClientMigrationProgress.ResetOrganization);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Shared
|
||||
|
||||
private static void ResetOrganizationPlan(Organization organization, Plan plan)
|
||||
{
|
||||
organization.Plan = plan.Name;
|
||||
organization.PlanType = plan.Type;
|
||||
organization.MaxCollections = plan.PasswordManager.MaxCollections;
|
||||
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
|
||||
organization.UsePolicies = plan.HasPolicies;
|
||||
organization.UseSso = plan.HasSso;
|
||||
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
|
||||
organization.UseGroups = plan.HasGroups;
|
||||
organization.UseEvents = plan.HasEvents;
|
||||
organization.UseDirectory = plan.HasDirectory;
|
||||
organization.UseTotp = plan.HasTotp;
|
||||
organization.Use2fa = plan.Has2fa;
|
||||
organization.UseApi = plan.HasApi;
|
||||
organization.UseResetPassword = plan.HasResetPassword;
|
||||
organization.SelfHost = plan.HasSelfHost;
|
||||
organization.UsersGetPremium = plan.UsersGetPremium;
|
||||
organization.UseCustomPermissions = plan.HasCustomPermissions;
|
||||
organization.UseScim = plan.HasScim;
|
||||
organization.UseKeyConnector = plan.HasKeyConnector;
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -1,436 +0,0 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.AdminConsole.Enums.Provider;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Providers.Entities;
|
||||
using Bit.Core.Billing.Providers.Migration.Models;
|
||||
using Bit.Core.Billing.Providers.Models;
|
||||
using Bit.Core.Billing.Providers.Repositories;
|
||||
using Bit.Core.Billing.Providers.Services;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Stripe;
|
||||
|
||||
namespace Bit.Core.Billing.Providers.Migration.Services.Implementations;
|
||||
|
||||
public class ProviderMigrator(
|
||||
IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository,
|
||||
IOrganizationMigrator organizationMigrator,
|
||||
ILogger<ProviderMigrator> logger,
|
||||
IMigrationTrackerCache migrationTrackerCache,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IPaymentService paymentService,
|
||||
IProviderBillingService providerBillingService,
|
||||
IProviderOrganizationRepository providerOrganizationRepository,
|
||||
IProviderRepository providerRepository,
|
||||
IProviderPlanRepository providerPlanRepository,
|
||||
IStripeAdapter stripeAdapter) : IProviderMigrator
|
||||
{
|
||||
public async Task Migrate(Guid providerId)
|
||||
{
|
||||
var provider = await GetProviderAsync(providerId);
|
||||
|
||||
if (provider == null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
logger.LogInformation("CB: Starting migration for provider ({ProviderID})", providerId);
|
||||
|
||||
await migrationTrackerCache.StartTracker(provider);
|
||||
|
||||
var organizations = await GetClientsAsync(provider.Id);
|
||||
|
||||
if (organizations.Count == 0)
|
||||
{
|
||||
logger.LogInformation("CB: Skipping migration for provider ({ProviderID}) with no clients", providerId);
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.NoClients);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
await MigrateClientsAsync(providerId, organizations);
|
||||
|
||||
await ConfigureTeamsPlanAsync(providerId);
|
||||
|
||||
await ConfigureEnterprisePlanAsync(providerId);
|
||||
|
||||
await SetupCustomerAsync(provider);
|
||||
|
||||
await SetupSubscriptionAsync(provider);
|
||||
|
||||
await ApplyCreditAsync(provider);
|
||||
|
||||
await UpdateProviderAsync(provider);
|
||||
}
|
||||
|
||||
public async Task<ProviderMigrationResult> GetResult(Guid providerId)
|
||||
{
|
||||
var providerTracker = await migrationTrackerCache.GetTracker(providerId);
|
||||
|
||||
if (providerTracker == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (providerTracker.Progress == ProviderMigrationProgress.NoClients)
|
||||
{
|
||||
return new ProviderMigrationResult
|
||||
{
|
||||
ProviderId = providerTracker.ProviderId,
|
||||
ProviderName = providerTracker.ProviderName,
|
||||
Result = providerTracker.Progress.ToString()
|
||||
};
|
||||
}
|
||||
|
||||
var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId =>
|
||||
migrationTrackerCache.GetTracker(providerId, organizationId)));
|
||||
|
||||
var migrationRecordLookup = new Dictionary<Guid, ClientOrganizationMigrationRecord>();
|
||||
|
||||
foreach (var clientTracker in clientTrackers)
|
||||
{
|
||||
var migrationRecord =
|
||||
await clientOrganizationMigrationRecordRepository.GetByOrganizationId(clientTracker.OrganizationId);
|
||||
|
||||
migrationRecordLookup.Add(clientTracker.OrganizationId, migrationRecord);
|
||||
}
|
||||
|
||||
return new ProviderMigrationResult
|
||||
{
|
||||
ProviderId = providerTracker.ProviderId,
|
||||
ProviderName = providerTracker.ProviderName,
|
||||
Result = providerTracker.Progress.ToString(),
|
||||
Clients = clientTrackers.Select(tracker =>
|
||||
{
|
||||
var foundMigrationRecord = migrationRecordLookup.TryGetValue(tracker.OrganizationId, out var migrationRecord);
|
||||
return new ClientMigrationResult
|
||||
{
|
||||
OrganizationId = tracker.OrganizationId,
|
||||
OrganizationName = tracker.OrganizationName,
|
||||
Result = tracker.Progress.ToString(),
|
||||
PreviousState = foundMigrationRecord ? new ClientPreviousState(migrationRecord) : null
|
||||
};
|
||||
}).ToList(),
|
||||
};
|
||||
}
|
||||
|
||||
#region Steps
|
||||
|
||||
private async Task MigrateClientsAsync(Guid providerId, List<Organization> organizations)
|
||||
{
|
||||
logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId);
|
||||
|
||||
var organizationIds = organizations.Select(organization => organization.Id);
|
||||
|
||||
await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds);
|
||||
|
||||
foreach (var organization in organizations)
|
||||
{
|
||||
var tracker = await migrationTrackerCache.GetTracker(providerId, organization.Id);
|
||||
|
||||
if (tracker is not { Progress: ClientMigrationProgress.Completed })
|
||||
{
|
||||
await organizationMigrator.Migrate(providerId, organization);
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogInformation("CB: Migrated clients for provider ({ProviderID})", providerId);
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId,
|
||||
ProviderMigrationProgress.ClientsMigrated);
|
||||
}
|
||||
|
||||
private async Task ConfigureTeamsPlanAsync(Guid providerId)
|
||||
{
|
||||
logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId);
|
||||
|
||||
var organizations = await GetClientsAsync(providerId);
|
||||
|
||||
var teamsSeats = organizations
|
||||
.Where(IsTeams)
|
||||
.Sum(client => client.Seats) ?? 0;
|
||||
|
||||
var teamsProviderPlan = (await providerPlanRepository.GetByProviderId(providerId))
|
||||
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly);
|
||||
|
||||
if (teamsProviderPlan == null)
|
||||
{
|
||||
await providerPlanRepository.CreateAsync(new ProviderPlan
|
||||
{
|
||||
ProviderId = providerId,
|
||||
PlanType = PlanType.TeamsMonthly,
|
||||
SeatMinimum = teamsSeats,
|
||||
PurchasedSeats = 0,
|
||||
AllocatedSeats = teamsSeats
|
||||
});
|
||||
|
||||
logger.LogInformation("CB: Created Teams plan for provider ({ProviderID}) with a seat minimum of {Seats}",
|
||||
providerId, teamsSeats);
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation("CB: Teams plan already exists for provider ({ProviderID}), updating seat minimum", providerId);
|
||||
|
||||
teamsProviderPlan.SeatMinimum = teamsSeats;
|
||||
teamsProviderPlan.AllocatedSeats = teamsSeats;
|
||||
|
||||
await providerPlanRepository.ReplaceAsync(teamsProviderPlan);
|
||||
|
||||
logger.LogInformation("CB: Updated Teams plan for provider ({ProviderID}) to seat minimum of {Seats}",
|
||||
providerId, teamsProviderPlan.SeatMinimum);
|
||||
}
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.TeamsPlanConfigured);
|
||||
}
|
||||
|
||||
private async Task ConfigureEnterprisePlanAsync(Guid providerId)
|
||||
{
|
||||
logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId);
|
||||
|
||||
var organizations = await GetClientsAsync(providerId);
|
||||
|
||||
var enterpriseSeats = organizations
|
||||
.Where(IsEnterprise)
|
||||
.Sum(client => client.Seats) ?? 0;
|
||||
|
||||
var enterpriseProviderPlan = (await providerPlanRepository.GetByProviderId(providerId))
|
||||
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly);
|
||||
|
||||
if (enterpriseProviderPlan == null)
|
||||
{
|
||||
await providerPlanRepository.CreateAsync(new ProviderPlan
|
||||
{
|
||||
ProviderId = providerId,
|
||||
PlanType = PlanType.EnterpriseMonthly,
|
||||
SeatMinimum = enterpriseSeats,
|
||||
PurchasedSeats = 0,
|
||||
AllocatedSeats = enterpriseSeats
|
||||
});
|
||||
|
||||
logger.LogInformation("CB: Created Enterprise plan for provider ({ProviderID}) with a seat minimum of {Seats}",
|
||||
providerId, enterpriseSeats);
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation("CB: Enterprise plan already exists for provider ({ProviderID}), updating seat minimum", providerId);
|
||||
|
||||
enterpriseProviderPlan.SeatMinimum = enterpriseSeats;
|
||||
enterpriseProviderPlan.AllocatedSeats = enterpriseSeats;
|
||||
|
||||
await providerPlanRepository.ReplaceAsync(enterpriseProviderPlan);
|
||||
|
||||
logger.LogInformation("CB: Updated Enterprise plan for provider ({ProviderID}) to seat minimum of {Seats}",
|
||||
providerId, enterpriseProviderPlan.SeatMinimum);
|
||||
}
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.EnterprisePlanConfigured);
|
||||
}
|
||||
|
||||
private async Task SetupCustomerAsync(Provider provider)
|
||||
{
|
||||
if (string.IsNullOrEmpty(provider.GatewayCustomerId))
|
||||
{
|
||||
var organizations = await GetClientsAsync(provider.Id);
|
||||
|
||||
var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId));
|
||||
|
||||
if (sampleOrganization == null)
|
||||
{
|
||||
logger.LogInformation(
|
||||
"CB: Could not find sample organization for provider ({ProviderID}) that has a Stripe customer",
|
||||
provider.Id);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
var taxInfo = await paymentService.GetTaxInfoAsync(sampleOrganization);
|
||||
|
||||
// Create dummy payment source for legacy migration - this migrator is deprecated and will be removed
|
||||
var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token");
|
||||
|
||||
var customer = await providerBillingService.SetupCustomer(provider, null, null);
|
||||
|
||||
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
|
||||
{
|
||||
Coupon = StripeConstants.CouponIDs.LegacyMSPDiscount
|
||||
});
|
||||
|
||||
provider.GatewayCustomerId = customer.Id;
|
||||
|
||||
await providerRepository.ReplaceAsync(provider);
|
||||
|
||||
logger.LogInformation("CB: Setup Stripe customer for provider ({ProviderID})", provider.Id);
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation("CB: Stripe customer already exists for provider ({ProviderID})", provider.Id);
|
||||
}
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CustomerSetup);
|
||||
}
|
||||
|
||||
private async Task SetupSubscriptionAsync(Provider provider)
|
||||
{
|
||||
if (string.IsNullOrEmpty(provider.GatewaySubscriptionId))
|
||||
{
|
||||
if (!string.IsNullOrEmpty(provider.GatewayCustomerId))
|
||||
{
|
||||
var subscription = await providerBillingService.SetupSubscription(provider);
|
||||
|
||||
provider.GatewaySubscriptionId = subscription.Id;
|
||||
|
||||
await providerRepository.ReplaceAsync(provider);
|
||||
|
||||
logger.LogInformation("CB: Setup Stripe subscription for provider ({ProviderID})", provider.Id);
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation(
|
||||
"CB: Could not set up Stripe subscription for provider ({ProviderID}) with no Stripe customer",
|
||||
provider.Id);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogInformation("CB: Stripe subscription already exists for provider ({ProviderID})", provider.Id);
|
||||
|
||||
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
|
||||
|
||||
var enterpriseSeatMinimum = providerPlans
|
||||
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly)?
|
||||
.SeatMinimum ?? 0;
|
||||
|
||||
var teamsSeatMinimum = providerPlans
|
||||
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)?
|
||||
.SeatMinimum ?? 0;
|
||||
|
||||
var updateSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand(
|
||||
provider,
|
||||
[
|
||||
(Plan: PlanType.EnterpriseMonthly, SeatsMinimum: enterpriseSeatMinimum),
|
||||
(Plan: PlanType.TeamsMonthly, SeatsMinimum: teamsSeatMinimum)
|
||||
]);
|
||||
await providerBillingService.UpdateSeatMinimums(updateSeatMinimumsCommand);
|
||||
|
||||
logger.LogInformation(
|
||||
"CB: Updated Stripe subscription for provider ({ProviderID}) with current seat minimums", provider.Id);
|
||||
}
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.SubscriptionSetup);
|
||||
}
|
||||
|
||||
private async Task ApplyCreditAsync(Provider provider)
|
||||
{
|
||||
var organizations = await GetClientsAsync(provider.Id);
|
||||
|
||||
var organizationCustomers =
|
||||
await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId)));
|
||||
|
||||
var organizationCancellationCredit = organizationCustomers.Sum(customer => customer.Balance);
|
||||
|
||||
if (organizationCancellationCredit != 0)
|
||||
{
|
||||
await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId,
|
||||
new CustomerBalanceTransactionCreateOptions
|
||||
{
|
||||
Amount = organizationCancellationCredit,
|
||||
Currency = "USD",
|
||||
Description = "Unused, prorated time for client organization subscriptions."
|
||||
});
|
||||
}
|
||||
|
||||
var migrationRecords = await Task.WhenAll(organizations.Select(organization =>
|
||||
clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id)));
|
||||
|
||||
var legacyOrganizationMigrationRecords = migrationRecords.Where(migrationRecord =>
|
||||
migrationRecord.PlanType is
|
||||
PlanType.EnterpriseAnnually2020 or
|
||||
PlanType.TeamsAnnually2020);
|
||||
|
||||
var legacyOrganizationCredit = legacyOrganizationMigrationRecords.Sum(migrationRecord => migrationRecord.Seats) * 12 * -100;
|
||||
|
||||
if (legacyOrganizationCredit < 0)
|
||||
{
|
||||
await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId,
|
||||
new CustomerBalanceTransactionCreateOptions
|
||||
{
|
||||
Amount = legacyOrganizationCredit,
|
||||
Currency = "USD",
|
||||
Description = "1 year rebate for legacy client organizations."
|
||||
});
|
||||
}
|
||||
|
||||
logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit + legacyOrganizationCredit, provider.Id);
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CreditApplied);
|
||||
}
|
||||
|
||||
private async Task UpdateProviderAsync(Provider provider)
|
||||
{
|
||||
provider.Status = ProviderStatusType.Billable;
|
||||
|
||||
await providerRepository.ReplaceAsync(provider);
|
||||
|
||||
logger.LogInformation("CB: Completed migration for provider ({ProviderID})", provider.Id);
|
||||
|
||||
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.Completed);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Utilities
|
||||
|
||||
private async Task<List<Organization>> GetClientsAsync(Guid providerId)
|
||||
{
|
||||
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId);
|
||||
|
||||
return (await Task.WhenAll(providerOrganizations.Select(providerOrganization =>
|
||||
organizationRepository.GetByIdAsync(providerOrganization.OrganizationId))))
|
||||
.ToList();
|
||||
}
|
||||
|
||||
private async Task<Provider> GetProviderAsync(Guid providerId)
|
||||
{
|
||||
var provider = await providerRepository.GetByIdAsync(providerId);
|
||||
|
||||
if (provider == null)
|
||||
{
|
||||
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it does not exist", providerId);
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
if (provider.Type != ProviderType.Msp)
|
||||
{
|
||||
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not an MSP", providerId);
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
if (provider.Status == ProviderStatusType.Created)
|
||||
{
|
||||
return provider;
|
||||
}
|
||||
|
||||
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not in the 'Created' state", providerId);
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private static bool IsEnterprise(Organization organization) => organization.Plan.Contains("Enterprise");
|
||||
private static bool IsTeams(Organization organization) => organization.Plan.Contains("Teams");
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
using Bit.Core.Billing.Caches;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Models.Sales;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
@@ -108,7 +109,7 @@ public class PremiumUserBillingService(
|
||||
when subscription.Status == StripeConstants.SubscriptionStatus.Active:
|
||||
{
|
||||
user.Premium = true;
|
||||
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
|
||||
user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Billing.Commands;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Repositories;
|
||||
@@ -65,7 +66,7 @@ public class RestartSubscriptionCommand(
|
||||
{
|
||||
organization.GatewaySubscriptionId = subscription.Id;
|
||||
organization.Enabled = true;
|
||||
organization.ExpirationDate = subscription.CurrentPeriodEnd;
|
||||
organization.ExpirationDate = subscription.GetCurrentPeriodEnd();
|
||||
organization.RevisionDate = DateTime.UtcNow;
|
||||
await organizationRepository.ReplaceAsync(organization);
|
||||
break;
|
||||
@@ -82,7 +83,7 @@ public class RestartSubscriptionCommand(
|
||||
{
|
||||
user.GatewaySubscriptionId = subscription.Id;
|
||||
user.Premium = true;
|
||||
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
|
||||
user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd();
|
||||
user.RevisionDate = DateTime.UtcNow;
|
||||
await userRepository.ReplaceAsync(user);
|
||||
break;
|
||||
|
||||
@@ -140,6 +140,7 @@ public static class FeatureFlagKeys
|
||||
public const string EventBasedOrganizationIntegrations = "event-based-organization-integrations";
|
||||
public const string SeparateCustomRolePermissions = "pm-19917-separate-custom-role-permissions";
|
||||
public const string CreateDefaultLocation = "pm-19467-create-default-location";
|
||||
public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users";
|
||||
public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache";
|
||||
|
||||
/* Auth Team */
|
||||
@@ -160,6 +161,7 @@ public static class FeatureFlagKeys
|
||||
public const string InlineMenuFieldQualification = "inline-menu-field-qualification";
|
||||
public const string InlineMenuPositioningImprovements = "inline-menu-positioning-improvements";
|
||||
public const string SSHAgent = "ssh-agent";
|
||||
public const string SSHAgentV2 = "ssh-agent-v2";
|
||||
public const string SSHVersionCheckQAOverride = "ssh-version-check-qa-override";
|
||||
public const string GenerateIdentityFillScriptRefactor = "generate-identity-fill-script-refactor";
|
||||
public const string DelayFido2PageScriptInitWithinMv2 = "delay-fido2-page-script-init-within-mv2";
|
||||
@@ -192,6 +194,7 @@ public static class FeatureFlagKeys
|
||||
public const string UserkeyRotationV2 = "userkey-rotation-v2";
|
||||
public const string SSHKeyItemVaultItem = "ssh-key-vault-item";
|
||||
public const string UserSdkForDecryption = "use-sdk-for-decryption";
|
||||
public const string EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation";
|
||||
public const string PM17987_BlockType0 = "pm-17987-block-type-0";
|
||||
public const string ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings";
|
||||
public const string UnlockWithMasterPasswordUnlockData = "pm-23246-unlock-with-master-password-unlock-data";
|
||||
@@ -199,24 +202,18 @@ public static class FeatureFlagKeys
|
||||
public const string LinuxBiometricsV2 = "pm-26340-linux-biometrics-v2";
|
||||
public const string NoLogoutOnKdfChange = "pm-23995-no-logout-on-kdf-change";
|
||||
public const string DisableType0Decryption = "pm-25174-disable-type-0-decryption";
|
||||
public const string ConsolidatedSessionTimeoutComponent = "pm-26056-consolidated-session-timeout-component";
|
||||
|
||||
/* Mobile Team */
|
||||
public const string NativeCarouselFlow = "native-carousel-flow";
|
||||
public const string NativeCreateAccountFlow = "native-create-account-flow";
|
||||
public const string AndroidImportLoginsFlow = "import-logins-flow";
|
||||
public const string AppReviewPrompt = "app-review-prompt";
|
||||
public const string AndroidMutualTls = "mutual-tls";
|
||||
public const string SingleTapPasskeyCreation = "single-tap-passkey-creation";
|
||||
public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication";
|
||||
public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync";
|
||||
public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias";
|
||||
public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias";
|
||||
public const string EnablePMFlightRecorder = "enable-pm-flight-recorder";
|
||||
public const string MobileErrorReporting = "mobile-error-reporting";
|
||||
public const string AndroidChromeAutofill = "android-chrome-autofill";
|
||||
public const string UserManagedPrivilegedApps = "pm-18970-user-managed-privileged-apps";
|
||||
public const string EnablePMPreloginSettings = "enable-pm-prelogin-settings";
|
||||
public const string AppIntents = "app-intents";
|
||||
public const string SendAccess = "pm-19394-send-access-control";
|
||||
public const string CxpImportMobile = "cxp-import-mobile";
|
||||
public const string CxpExportMobile = "cxp-export-mobile";
|
||||
@@ -229,6 +226,7 @@ public static class FeatureFlagKeys
|
||||
/* Tools Team */
|
||||
public const string DesktopSendUIRefresh = "desktop-send-ui-refresh";
|
||||
public const string UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators";
|
||||
public const string ChromiumImporterWithABE = "pm-25855-chromium-importer-abe";
|
||||
|
||||
/* Vault Team */
|
||||
public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge";
|
||||
@@ -247,6 +245,7 @@ public static class FeatureFlagKeys
|
||||
|
||||
/* DIRT Team */
|
||||
public const string PM22887_RiskInsightsActivityTab = "pm-22887-risk-insights-activity-tab";
|
||||
public const string EventManagementForDataDogAndCrowdStrike = "event-management-for-datadog-and-crowdstrike";
|
||||
|
||||
public static List<string> GetAllKeys()
|
||||
{
|
||||
|
||||
@@ -57,7 +57,7 @@
|
||||
<PackageReference Include="Serilog.Sinks.SyslogMessages" Version="4.0.0" />
|
||||
<PackageReference Include="AspNetCoreRateLimit" Version="5.0.0" />
|
||||
<PackageReference Include="Braintree" Version="5.28.0" />
|
||||
<PackageReference Include="Stripe.net" Version="45.14.0" />
|
||||
<PackageReference Include="Stripe.net" Version="48.5.0" />
|
||||
<PackageReference Include="Otp.NET" Version="1.4.0" />
|
||||
<PackageReference Include="YubicoDotNetClient" Version="1.2.0" />
|
||||
<PackageReference Include="Microsoft.Extensions.Caching.StackExchangeRedis" Version="8.0.10" />
|
||||
|
||||
@@ -3,6 +3,7 @@ using System.Text.Json;
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Models;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.AspNetCore.Identity;
|
||||
|
||||
@@ -21,6 +22,9 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
|
||||
[MaxLength(256)]
|
||||
public string Email { get; set; } = null!;
|
||||
public bool EmailVerified { get; set; }
|
||||
/// <summary>
|
||||
/// The server-side master-password hash
|
||||
/// </summary>
|
||||
[MaxLength(300)]
|
||||
public string? MasterPassword { get; set; }
|
||||
[MaxLength(50)]
|
||||
@@ -41,9 +45,30 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
|
||||
/// organization membership.
|
||||
/// </summary>
|
||||
public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow;
|
||||
/// <summary>
|
||||
/// The master-password-sealed user key.
|
||||
/// </summary>
|
||||
public string? Key { get; set; }
|
||||
/// <summary>
|
||||
/// The raw public key, without a signature from the user's signature key.
|
||||
/// </summary>
|
||||
public string? PublicKey { get; set; }
|
||||
/// <summary>
|
||||
/// User key wrapped private key.
|
||||
/// </summary>
|
||||
public string? PrivateKey { get; set; }
|
||||
/// <summary>
|
||||
/// The public key, signed by the user's signature key.
|
||||
/// </summary>
|
||||
public string? SignedPublicKey { get; set; }
|
||||
/// <summary>
|
||||
/// The security version is included in the security state, but needs COSE parsing
|
||||
/// </summary>
|
||||
public int? SecurityVersion { get; set; }
|
||||
/// <summary>
|
||||
/// The security state is a signed object attesting to the version of the user's account.
|
||||
/// </summary>
|
||||
public string? SecurityState { get; set; }
|
||||
public bool Premium { get; set; }
|
||||
public DateTime? PremiumExpirationDate { get; set; }
|
||||
public DateTime? RenewalReminderDate { get; set; }
|
||||
@@ -180,6 +205,12 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
|
||||
return Premium;
|
||||
}
|
||||
|
||||
public int GetSecurityVersion()
|
||||
{
|
||||
// If no security version is set, it is version 1. The minimum initialized version is 2.
|
||||
return SecurityVersion ?? 1;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Serializes the C# object to the User.TwoFactorProviders property in JSON format.
|
||||
/// </summary>
|
||||
@@ -243,4 +274,14 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
|
||||
{
|
||||
return MasterPassword != null;
|
||||
}
|
||||
|
||||
public PublicKeyEncryptionKeyPairData GetPublicKeyEncryptionKeyPair()
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(PrivateKey) || string.IsNullOrWhiteSpace(PublicKey))
|
||||
{
|
||||
throw new InvalidOperationException("User public key encryption key pair is not fully initialized.");
|
||||
}
|
||||
|
||||
return new PublicKeyEncryptionKeyPairData(PrivateKey, PublicKey, SignedPublicKey);
|
||||
}
|
||||
}
|
||||
|
||||
6
src/Core/Enums/PushNotificationLogOutReason.cs
Normal file
6
src/Core/Enums/PushNotificationLogOutReason.cs
Normal file
@@ -0,0 +1,6 @@
|
||||
namespace Bit.Core.Enums;
|
||||
|
||||
public enum PushNotificationLogOutReason : byte
|
||||
{
|
||||
KdfChange = 0
|
||||
}
|
||||
@@ -107,7 +107,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable
|
||||
throw new Exception("Job failed to start after 10 retries.");
|
||||
}
|
||||
|
||||
_logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}");
|
||||
_logger.LogWarning(e, "Exception while trying to schedule job: {JobName}", job.FullName);
|
||||
var random = new Random();
|
||||
await Task.Delay(random.Next(50, 250));
|
||||
}
|
||||
@@ -125,7 +125,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable
|
||||
continue;
|
||||
}
|
||||
|
||||
_logger.LogInformation($"Deleting old job with key {key}");
|
||||
_logger.LogInformation("Deleting old job with key {Key}", key);
|
||||
await _scheduler.DeleteJob(key);
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable
|
||||
continue;
|
||||
}
|
||||
|
||||
_logger.LogInformation($"Unscheduling old trigger with key {key}");
|
||||
_logger.LogInformation("Unscheduling old trigger with key {Key}", key);
|
||||
await _scheduler.UnscheduleJob(key);
|
||||
}
|
||||
}
|
||||
|
||||
30
src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs
Normal file
30
src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs
Normal file
@@ -0,0 +1,30 @@
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.KeyManagement.Enums;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.Utilities;
|
||||
|
||||
|
||||
namespace Bit.Core.KeyManagement.Entities;
|
||||
|
||||
public class UserSignatureKeyPair : ITableObject<Guid>, IRevisable
|
||||
{
|
||||
public Guid Id { get; set; }
|
||||
public Guid UserId { get; set; }
|
||||
public SignatureAlgorithm SignatureAlgorithm { get; set; }
|
||||
|
||||
public required string VerifyingKey { get; set; }
|
||||
public required string SigningKey { get; set; }
|
||||
|
||||
public DateTime CreationDate { get; set; } = DateTime.UtcNow;
|
||||
public DateTime RevisionDate { get; set; } = DateTime.UtcNow;
|
||||
|
||||
public void SetNewId()
|
||||
{
|
||||
Id = CoreHelpers.GenerateComb();
|
||||
}
|
||||
|
||||
public SignatureKeyPairData ToSignatureKeyPairData()
|
||||
{
|
||||
return new SignatureKeyPairData(SignatureAlgorithm, SigningKey, VerifyingKey);
|
||||
}
|
||||
}
|
||||
9
src/Core/KeyManagement/Enums/SignatureAlgorithm.cs
Normal file
9
src/Core/KeyManagement/Enums/SignatureAlgorithm.cs
Normal file
@@ -0,0 +1,9 @@
|
||||
namespace Bit.Core.KeyManagement.Enums;
|
||||
|
||||
// <summary>
|
||||
// Represents the algorithm / digital signature scheme used for a signature key pair.
|
||||
// </summary>
|
||||
public enum SignatureAlgorithm : byte
|
||||
{
|
||||
Ed25519 = 0
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.Platform.Push;
|
||||
@@ -18,17 +19,22 @@ public class ChangeKdfCommand : IChangeKdfCommand
|
||||
private readonly IUserRepository _userRepository;
|
||||
private readonly IdentityErrorDescriber _identityErrorDescriber;
|
||||
private readonly ILogger<ChangeKdfCommand> _logger;
|
||||
private readonly IFeatureService _featureService;
|
||||
|
||||
public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService, IUserRepository userRepository, IdentityErrorDescriber describer, ILogger<ChangeKdfCommand> logger)
|
||||
public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService,
|
||||
IUserRepository userRepository, IdentityErrorDescriber describer, ILogger<ChangeKdfCommand> logger,
|
||||
IFeatureService featureService)
|
||||
{
|
||||
_userService = userService;
|
||||
_pushService = pushService;
|
||||
_userRepository = userRepository;
|
||||
_identityErrorDescriber = describer;
|
||||
_logger = logger;
|
||||
_featureService = featureService;
|
||||
}
|
||||
|
||||
public async Task<IdentityResult> ChangeKdfAsync(User user, string masterPasswordAuthenticationHash, MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData)
|
||||
public async Task<IdentityResult> ChangeKdfAsync(User user, string masterPasswordAuthenticationHash,
|
||||
MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(user);
|
||||
if (!await _userService.CheckPasswordAsync(user, masterPasswordAuthenticationHash))
|
||||
@@ -37,8 +43,8 @@ public class ChangeKdfCommand : IChangeKdfCommand
|
||||
}
|
||||
|
||||
// Validate to prevent user account from becoming un-decryptable from invalid parameters
|
||||
//
|
||||
// Prevent a de-synced salt value from creating an un-decryptable unlock method
|
||||
//
|
||||
// Prevent a de-synced salt value from creating an un-decryptable unlock method
|
||||
authenticationData.ValidateSaltUnchangedForUser(user);
|
||||
unlockData.ValidateSaltUnchangedForUser(user);
|
||||
|
||||
@@ -47,12 +53,15 @@ public class ChangeKdfCommand : IChangeKdfCommand
|
||||
{
|
||||
throw new BadRequestException("KDF settings must be equal for authentication and unlock.");
|
||||
}
|
||||
|
||||
var validationErrors = KdfSettingsValidator.Validate(unlockData.Kdf);
|
||||
if (validationErrors.Any())
|
||||
{
|
||||
throw new BadRequestException("KDF settings are invalid.");
|
||||
}
|
||||
|
||||
var logoutOnKdfChange = !_featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange);
|
||||
|
||||
// Update the user with the new KDF settings
|
||||
// This updates the authentication data and unlock data for the user separately. Currently these still
|
||||
// use shared values for KDF settings and salt.
|
||||
@@ -68,7 +77,8 @@ public class ChangeKdfCommand : IChangeKdfCommand
|
||||
// This entire operation MUST be atomic to prevent a user from being locked out of their account.
|
||||
// Salt is ensured to be the same as unlock data, and the value stored in the account and not updated.
|
||||
// KDF is ensured to be the same as unlock data above and updated below.
|
||||
var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash);
|
||||
var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash,
|
||||
refreshStamp: logoutOnKdfChange);
|
||||
if (!result.Succeeded)
|
||||
{
|
||||
_logger.LogWarning("Change KDF failed for user {userId}.", user.Id);
|
||||
@@ -88,7 +98,17 @@ public class ChangeKdfCommand : IChangeKdfCommand
|
||||
user.LastKdfChangeDate = now;
|
||||
|
||||
await _userRepository.ReplaceAsync(user);
|
||||
await _pushService.PushLogOutAsync(user.Id);
|
||||
if (logoutOnKdfChange)
|
||||
{
|
||||
await _pushService.PushLogOutAsync(user.Id);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Clients that support the new feature flag will ignore the logout when it matches the reason and the feature flag is enabled.
|
||||
await _pushService.PushLogOutAsync(user.Id, reason: PushNotificationLogOutReason.KdfChange);
|
||||
await _pushService.PushSyncSettingsAsync(user.Id);
|
||||
}
|
||||
|
||||
return IdentityResult.Success;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
using Bit.Core.KeyManagement.Commands.Interfaces;
|
||||
using Bit.Core.KeyManagement.Kdf;
|
||||
using Bit.Core.KeyManagement.Kdf.Implementations;
|
||||
using Bit.Core.KeyManagement.Queries;
|
||||
using Bit.Core.KeyManagement.Queries.Interfaces;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace Bit.Core.KeyManagement;
|
||||
@@ -11,6 +13,7 @@ public static class KeyManagementServiceCollectionExtensions
|
||||
public static void AddKeyManagementServices(this IServiceCollection services)
|
||||
{
|
||||
services.AddKeyManagementCommands();
|
||||
services.AddKeyManagementQueries();
|
||||
services.AddSendPasswordServices();
|
||||
}
|
||||
|
||||
@@ -19,4 +22,9 @@ public static class KeyManagementServiceCollectionExtensions
|
||||
services.AddScoped<IRegenerateUserAsymmetricKeysCommand, RegenerateUserAsymmetricKeysCommand>();
|
||||
services.AddScoped<IChangeKdfCommand, ChangeKdfCommand>();
|
||||
}
|
||||
|
||||
private static void AddKeyManagementQueries(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<IUserAccountKeysQuery, UserAccountKeysQuery>();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Api.Request;
|
||||
|
||||
public class SecurityStateModel
|
||||
{
|
||||
[StringLength(1000)]
|
||||
[JsonPropertyName("securityState")]
|
||||
public required string SecurityState { get; set; }
|
||||
[JsonPropertyName("securityVersion")]
|
||||
public required int SecurityVersion { get; set; }
|
||||
|
||||
public SecurityStateData ToSecurityState()
|
||||
{
|
||||
return new SecurityStateData
|
||||
{
|
||||
SecurityState = SecurityState,
|
||||
SecurityVersion = SecurityVersion
|
||||
};
|
||||
}
|
||||
|
||||
public static SecurityStateModel FromSecurityStateData(SecurityStateData data)
|
||||
{
|
||||
return new SecurityStateModel
|
||||
{
|
||||
SecurityState = data.SecurityState,
|
||||
SecurityVersion = data.SecurityVersion
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Utilities;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Response;
|
||||
namespace Bit.Core.KeyManagement.Models.Api.Response;
|
||||
|
||||
public class MasterPasswordUnlockResponseModel
|
||||
{
|
||||
@@ -0,0 +1,48 @@
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Core.KeyManagement.Models.Api.Request;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.Models.Api;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Api.Response;
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// This response model is used to return the asymmetric encryption keys,
|
||||
/// and signature keys of an entity. This includes the private keys of the key pairs,
|
||||
/// (private key, signing key), and the public keys of the key pairs (unsigned public key,
|
||||
/// signed public key, verification key).
|
||||
/// </summary>
|
||||
public class PrivateKeysResponseModel : ResponseModel
|
||||
{
|
||||
// Not all accounts have signature keys, but all accounts have public encryption keys.
|
||||
[JsonPropertyName("signatureKeyPair")]
|
||||
public SignatureKeyPairResponseModel? SignatureKeyPair { get; set; }
|
||||
|
||||
[JsonPropertyName("publicKeyEncryptionKeyPair")]
|
||||
public required PublicKeyEncryptionKeyPairResponseModel PublicKeyEncryptionKeyPair { get; set; }
|
||||
|
||||
[JsonPropertyName("securityState")]
|
||||
public SecurityStateModel? SecurityState { get; set; }
|
||||
|
||||
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
|
||||
public PrivateKeysResponseModel(UserAccountKeysData accountKeys) : base("privateKeys")
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(accountKeys);
|
||||
PublicKeyEncryptionKeyPair = new PublicKeyEncryptionKeyPairResponseModel(accountKeys.PublicKeyEncryptionKeyPairData);
|
||||
|
||||
if (accountKeys.SignatureKeyPairData != null && accountKeys.SecurityStateData != null)
|
||||
{
|
||||
SignatureKeyPair = new SignatureKeyPairResponseModel(accountKeys.SignatureKeyPairData);
|
||||
SecurityState = SecurityStateModel.FromSecurityStateData(accountKeys.SecurityStateData!);
|
||||
}
|
||||
}
|
||||
|
||||
[JsonConstructor]
|
||||
public PrivateKeysResponseModel(SignatureKeyPairResponseModel? signatureKeyPair, PublicKeyEncryptionKeyPairResponseModel publicKeyEncryptionKeyPair, SecurityStateModel? securityState)
|
||||
: base("privateKeys")
|
||||
{
|
||||
SignatureKeyPair = signatureKeyPair;
|
||||
PublicKeyEncryptionKeyPair = publicKeyEncryptionKeyPair ?? throw new ArgumentNullException(nameof(publicKeyEncryptionKeyPair));
|
||||
SecurityState = securityState;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.Models.Api;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Api.Response;
|
||||
|
||||
|
||||
public class PublicKeyEncryptionKeyPairResponseModel : ResponseModel
|
||||
{
|
||||
[JsonPropertyName("wrappedPrivateKey")]
|
||||
public required string WrappedPrivateKey { get; set; }
|
||||
[JsonPropertyName("publicKey")]
|
||||
public required string PublicKey { get; set; }
|
||||
[JsonPropertyName("signedPublicKey")]
|
||||
public string? SignedPublicKey { get; set; }
|
||||
|
||||
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
|
||||
public PublicKeyEncryptionKeyPairResponseModel(PublicKeyEncryptionKeyPairData keyPair)
|
||||
: base("publicKeyEncryptionKeyPair")
|
||||
{
|
||||
WrappedPrivateKey = keyPair.WrappedPrivateKey;
|
||||
PublicKey = keyPair.PublicKey;
|
||||
SignedPublicKey = keyPair.SignedPublicKey;
|
||||
}
|
||||
|
||||
[JsonConstructor]
|
||||
public PublicKeyEncryptionKeyPairResponseModel(string wrappedPrivateKey, string publicKey, string? signedPublicKey)
|
||||
: base("publicKeyEncryptionKeyPair")
|
||||
{
|
||||
WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey));
|
||||
PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey));
|
||||
SignedPublicKey = signedPublicKey;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.Models.Api;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Api.Response;
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// This response model is used to return the public keys of a user, to any other registered user or entity on the server.
|
||||
/// It can contain public keys (signature/encryption), and proofs between the two. It does not contain (encrypted) private keys.
|
||||
/// </summary>
|
||||
public class PublicKeysResponseModel : ResponseModel
|
||||
{
|
||||
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
|
||||
public PublicKeysResponseModel(UserAccountKeysData accountKeys)
|
||||
: base("publicKeys")
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(accountKeys);
|
||||
PublicKey = accountKeys.PublicKeyEncryptionKeyPairData.PublicKey;
|
||||
|
||||
if (accountKeys.SignatureKeyPairData != null)
|
||||
{
|
||||
SignedPublicKey = accountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey;
|
||||
VerifyingKey = accountKeys.SignatureKeyPairData.VerifyingKey;
|
||||
}
|
||||
}
|
||||
|
||||
public string? VerifyingKey { get; set; }
|
||||
public string? SignedPublicKey { get; set; }
|
||||
public required string PublicKey { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.Models.Api;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Api.Response;
|
||||
|
||||
|
||||
public class SignatureKeyPairResponseModel : ResponseModel
|
||||
{
|
||||
[JsonPropertyName("wrappedSigningKey")]
|
||||
public required string WrappedSigningKey { get; set; }
|
||||
[JsonPropertyName("verifyingKey")]
|
||||
public required string VerifyingKey { get; set; }
|
||||
|
||||
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
|
||||
public SignatureKeyPairResponseModel(SignatureKeyPairData signatureKeyPair)
|
||||
: base("signatureKeyPair")
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(signatureKeyPair);
|
||||
WrappedSigningKey = signatureKeyPair.WrappedSigningKey;
|
||||
VerifyingKey = signatureKeyPair.VerifyingKey;
|
||||
}
|
||||
|
||||
|
||||
[JsonConstructor]
|
||||
public SignatureKeyPairResponseModel(string wrappedSigningKey, string verifyingKey)
|
||||
: base("signatureKeyPair")
|
||||
{
|
||||
WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey));
|
||||
VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey));
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
namespace Bit.Core.KeyManagement.Models.Response;
|
||||
namespace Bit.Core.KeyManagement.Models.Api.Response;
|
||||
|
||||
public class UserDecryptionResponseModel
|
||||
{
|
||||
@@ -1,4 +1,5 @@
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Exceptions;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
@@ -12,7 +13,7 @@ public class MasterPasswordAuthenticationData
|
||||
{
|
||||
if (user.GetMasterPasswordSalt() != Salt)
|
||||
{
|
||||
throw new ArgumentException("Invalid master password salt.");
|
||||
throw new BadRequestException("Invalid master password salt.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#nullable enable
|
||||
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Exceptions;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
@@ -14,7 +13,7 @@ public class MasterPasswordUnlockData
|
||||
{
|
||||
if (user.GetMasterPasswordSalt() != Salt)
|
||||
{
|
||||
throw new ArgumentException("Invalid master password salt.");
|
||||
throw new BadRequestException("Invalid master password salt.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
|
||||
public class PublicKeyEncryptionKeyPairData
|
||||
{
|
||||
public required string WrappedPrivateKey { get; set; }
|
||||
public string? SignedPublicKey { get; set; }
|
||||
public required string PublicKey { get; set; }
|
||||
|
||||
[JsonConstructor]
|
||||
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
|
||||
public PublicKeyEncryptionKeyPairData(string wrappedPrivateKey, string publicKey, string? signedPublicKey = null)
|
||||
{
|
||||
WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey));
|
||||
PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey));
|
||||
SignedPublicKey = signedPublicKey;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,4 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
|
||||
using Bit.Core.Auth.Entities;
|
||||
using Bit.Core.Auth.Models.Data;
|
||||
using Bit.Core.Entities;
|
||||
@@ -12,21 +10,19 @@ namespace Bit.Core.KeyManagement.Models.Data;
|
||||
public class RotateUserAccountKeysData
|
||||
{
|
||||
// Authentication for this requests
|
||||
public string OldMasterKeyAuthenticationHash { get; set; }
|
||||
public required string OldMasterKeyAuthenticationHash { get; set; }
|
||||
|
||||
// Other keys encrypted by the userkey
|
||||
public string UserKeyEncryptedAccountPrivateKey { get; set; }
|
||||
public string AccountPublicKey { get; set; }
|
||||
public required UserAccountKeysData AccountKeys { get; set; }
|
||||
|
||||
// All methods to get to the userkey
|
||||
public MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; }
|
||||
public IEnumerable<EmergencyAccess> EmergencyAccesses { get; set; }
|
||||
public IReadOnlyList<OrganizationUser> OrganizationUsers { get; set; }
|
||||
public IEnumerable<WebAuthnLoginRotateKeyData> WebAuthnKeys { get; set; }
|
||||
public IEnumerable<Device> DeviceKeys { get; set; }
|
||||
public required MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; }
|
||||
public required IEnumerable<EmergencyAccess> EmergencyAccesses { get; set; }
|
||||
public required IReadOnlyList<OrganizationUser> OrganizationUsers { get; set; }
|
||||
public required IEnumerable<WebAuthnLoginRotateKeyData> WebAuthnKeys { get; set; }
|
||||
public required IEnumerable<Device> DeviceKeys { get; set; }
|
||||
|
||||
// User vault data encrypted by the userkey
|
||||
public IEnumerable<Cipher> Ciphers { get; set; }
|
||||
public IEnumerable<Folder> Folders { get; set; }
|
||||
public IReadOnlyList<Send> Sends { get; set; }
|
||||
public required IEnumerable<Cipher> Ciphers { get; set; }
|
||||
public required IEnumerable<Folder> Folders { get; set; }
|
||||
public required IReadOnlyList<Send> Sends { get; set; }
|
||||
}
|
||||
|
||||
10
src/Core/KeyManagement/Models/Data/SecurityStateData.cs
Normal file
10
src/Core/KeyManagement/Models/Data/SecurityStateData.cs
Normal file
@@ -0,0 +1,10 @@
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
public class SecurityStateData
|
||||
{
|
||||
public required string SecurityState { get; set; }
|
||||
// The security version is included in the security state, but needs COSE parsing,
|
||||
// so this is a separate copy that can be used directly.
|
||||
public required int SecurityVersion { get; set; }
|
||||
}
|
||||
21
src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs
Normal file
21
src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
using System.Text.Json.Serialization;
|
||||
using Bit.Core.KeyManagement.Enums;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
public class SignatureKeyPairData
|
||||
{
|
||||
public required SignatureAlgorithm SignatureAlgorithm { get; set; }
|
||||
public required string WrappedSigningKey { get; set; }
|
||||
public required string VerifyingKey { get; set; }
|
||||
|
||||
[JsonConstructor]
|
||||
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
|
||||
public SignatureKeyPairData(SignatureAlgorithm signatureAlgorithm, string wrappedSigningKey, string verifyingKey)
|
||||
{
|
||||
SignatureAlgorithm = signatureAlgorithm;
|
||||
WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey));
|
||||
VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
namespace Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
|
||||
public class UserAccountKeysData
|
||||
{
|
||||
public required PublicKeyEncryptionKeyPairData PublicKeyEncryptionKeyPairData { get; set; }
|
||||
public SignatureKeyPairData? SignatureKeyPairData { get; set; }
|
||||
public SecurityStateData? SecurityStateData { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Queries.Interfaces;
|
||||
|
||||
public interface IUserAccountKeysQuery
|
||||
{
|
||||
Task<UserAccountKeysData> Run(User user);
|
||||
}
|
||||
35
src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs
Normal file
35
src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs
Normal file
@@ -0,0 +1,35 @@
|
||||
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.KeyManagement.Queries.Interfaces;
|
||||
using Bit.Core.KeyManagement.Repositories;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Queries;
|
||||
|
||||
|
||||
public class UserAccountKeysQuery(IUserSignatureKeyPairRepository signatureKeyPairRepository) : IUserAccountKeysQuery
|
||||
{
|
||||
public async Task<UserAccountKeysData> Run(User user)
|
||||
{
|
||||
if (user.GetSecurityVersion() < 2)
|
||||
{
|
||||
return new UserAccountKeysData
|
||||
{
|
||||
PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(),
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return new UserAccountKeysData
|
||||
{
|
||||
PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(),
|
||||
SignatureKeyPairData = await signatureKeyPairRepository.GetByUserIdAsync(user.Id),
|
||||
SecurityStateData = new SecurityStateData
|
||||
{
|
||||
SecurityState = user.SecurityState!,
|
||||
SecurityVersion = user.GetSecurityVersion(),
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
using Bit.Core.KeyManagement.Entities;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.KeyManagement.UserKey;
|
||||
using Bit.Core.Repositories;
|
||||
|
||||
namespace Bit.Core.KeyManagement.Repositories;
|
||||
|
||||
public interface IUserSignatureKeyPairRepository : IRepository<UserSignatureKeyPair, Guid>
|
||||
{
|
||||
public Task<SignatureKeyPairData?> GetByUserIdAsync(Guid userId);
|
||||
public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signatureKeyPair);
|
||||
public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signatureKeyPair);
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
using Bit.Core.Auth.Repositories;
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
using Bit.Core.Auth.Repositories;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.KeyManagement.Models.Data;
|
||||
using Bit.Core.KeyManagement.Repositories;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
@@ -25,6 +30,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
|
||||
private readonly IdentityErrorDescriber _identityErrorDescriber;
|
||||
private readonly IWebAuthnCredentialRepository _credentialRepository;
|
||||
private readonly IPasswordHasher<User> _passwordHasher;
|
||||
private readonly IUserSignatureKeyPairRepository _userSignatureKeyPairRepository;
|
||||
private readonly IFeatureService _featureService;
|
||||
|
||||
/// <summary>
|
||||
/// Instantiates a new <see cref="RotateUserAccountKeysCommand"/>
|
||||
@@ -36,16 +43,19 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
|
||||
/// <param name="sendRepository">Provides a method to update re-encrypted send data</param>
|
||||
/// <param name="emergencyAccessRepository">Provides a method to update re-encrypted emergency access data</param>
|
||||
/// <param name="organizationUserRepository">Provides a method to update re-encrypted organization user data</param>
|
||||
/// <param name="deviceRepository">Provides a method to update re-encrypted device keys</param>
|
||||
/// <param name="passwordHasher">Hashes the new master password</param>
|
||||
/// <param name="pushService">Logs out user from other devices after successful rotation</param>
|
||||
/// <param name="errors">Provides a password mismatch error if master password hash validation fails</param>
|
||||
/// <param name="credentialRepository">Provides a method to update re-encrypted WebAuthn keys</param>
|
||||
/// <param name="userSignatureKeyPairRepository">Provides a method to update re-encrypted signature keys</param>
|
||||
public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository,
|
||||
ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository,
|
||||
IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository,
|
||||
IDeviceRepository deviceRepository,
|
||||
IPasswordHasher<User> passwordHasher,
|
||||
IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository,
|
||||
IUserSignatureKeyPairRepository userSignatureKeyPairRepository,
|
||||
IFeatureService featureService)
|
||||
{
|
||||
_userService = userService;
|
||||
@@ -60,6 +70,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
|
||||
_identityErrorDescriber = errors;
|
||||
_credentialRepository = credentialRepository;
|
||||
_passwordHasher = passwordHasher;
|
||||
_userSignatureKeyPairRepository = userSignatureKeyPairRepository;
|
||||
_featureService = featureService;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
@@ -80,50 +92,106 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
|
||||
user.LastKeyRotationDate = now;
|
||||
user.SecurityStamp = Guid.NewGuid().ToString();
|
||||
|
||||
if (
|
||||
!model.MasterPasswordUnlockData.ValidateForUser(user)
|
||||
)
|
||||
List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions = [];
|
||||
|
||||
await UpdateAccountKeysAsync(model, user, saveEncryptedDataActions);
|
||||
UpdateUnlockMethods(model, user, saveEncryptedDataActions);
|
||||
UpdateUserData(model, user, saveEncryptedDataActions);
|
||||
|
||||
await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions);
|
||||
await _pushService.PushLogOutAsync(user.Id);
|
||||
return IdentityResult.Success;
|
||||
}
|
||||
|
||||
public async Task RotateV2AccountKeysAsync(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
|
||||
{
|
||||
ValidateV2Encryption(model);
|
||||
await ValidateVerifyingKeyUnchangedAsync(model, user);
|
||||
|
||||
saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.UpdateForKeyRotation(user.Id, model.AccountKeys.SignatureKeyPairData));
|
||||
user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey;
|
||||
user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState;
|
||||
user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion;
|
||||
}
|
||||
|
||||
public void UpgradeV1ToV2Keys(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
|
||||
{
|
||||
ValidateV2Encryption(model);
|
||||
saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.SetUserSignatureKeyPair(user.Id, model.AccountKeys.SignatureKeyPairData));
|
||||
user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey;
|
||||
user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState;
|
||||
user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion;
|
||||
}
|
||||
|
||||
public async Task UpdateAccountKeysAsync(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
|
||||
{
|
||||
ValidatePublicKeyEncryptionKeyPairUnchanged(model, user);
|
||||
|
||||
if (IsV2EncryptionUserAsync(user))
|
||||
{
|
||||
throw new InvalidOperationException("The provided master password unlock data is not valid for this user.");
|
||||
await RotateV2AccountKeysAsync(model, user, saveEncryptedDataActions);
|
||||
}
|
||||
if (
|
||||
model.AccountPublicKey != user.PublicKey
|
||||
)
|
||||
else if (model.AccountKeys.SignatureKeyPairData != null)
|
||||
{
|
||||
throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric keypair is currently not supported during key rotation.");
|
||||
UpgradeV1ToV2Keys(model, user, saveEncryptedDataActions);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.AesCbc256_HmacSha256_B64)
|
||||
{
|
||||
throw new InvalidOperationException("The provided account private key was not wrapped with AES-256-CBC-HMAC");
|
||||
}
|
||||
// V1 user to V1 user rotation needs to further changes, the private key was re-encrypted.
|
||||
}
|
||||
|
||||
user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey;
|
||||
user.PrivateKey = model.UserKeyEncryptedAccountPrivateKey;
|
||||
user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash);
|
||||
user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint;
|
||||
// Private key is re-wrapped with new user key by client
|
||||
user.PrivateKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey;
|
||||
}
|
||||
|
||||
public void UpdateUserData(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
|
||||
{
|
||||
// The revision date has to be updated so that de-synced clients don't accidentally post over the re-encrypted data
|
||||
// with an old-user key-encrypted copy
|
||||
var now = DateTime.UtcNow;
|
||||
|
||||
List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions = new();
|
||||
if (model.Ciphers.Any())
|
||||
{
|
||||
saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers));
|
||||
var ciphersWithUpdatedDate = model.Ciphers.ToList().Select(c => { c.RevisionDate = now; return c; });
|
||||
saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, ciphersWithUpdatedDate));
|
||||
}
|
||||
|
||||
if (model.Folders.Any())
|
||||
{
|
||||
saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, model.Folders));
|
||||
var foldersWithUpdatedDate = model.Folders.ToList().Select(f => { f.RevisionDate = now; return f; });
|
||||
saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, foldersWithUpdatedDate));
|
||||
}
|
||||
|
||||
if (model.Sends.Any())
|
||||
{
|
||||
saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, model.Sends));
|
||||
var sendsWithUpdatedDate = model.Sends.ToList().Select(s => { s.RevisionDate = now; return s; });
|
||||
saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, sendsWithUpdatedDate));
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateUnlockMethods(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
|
||||
{
|
||||
if (!model.MasterPasswordUnlockData.ValidateForUser(user))
|
||||
{
|
||||
throw new InvalidOperationException("The provided master password unlock data is not valid for this user.");
|
||||
}
|
||||
// Update master password authentication & unlock
|
||||
user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey;
|
||||
user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash);
|
||||
user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint;
|
||||
|
||||
if (model.EmergencyAccesses.Any())
|
||||
{
|
||||
saveEncryptedDataActions.Add(
|
||||
_emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses));
|
||||
saveEncryptedDataActions.Add(_emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses));
|
||||
}
|
||||
|
||||
if (model.OrganizationUsers.Any())
|
||||
{
|
||||
saveEncryptedDataActions.Add(
|
||||
_organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers));
|
||||
saveEncryptedDataActions.Add(_organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers));
|
||||
}
|
||||
|
||||
if (model.WebAuthnKeys.Any())
|
||||
@@ -135,9 +203,80 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
|
||||
{
|
||||
saveEncryptedDataActions.Add(_deviceRepository.UpdateKeysForRotationAsync(user.Id, model.DeviceKeys));
|
||||
}
|
||||
}
|
||||
|
||||
await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions);
|
||||
await _pushService.PushLogOutAsync(user.Id);
|
||||
return IdentityResult.Success;
|
||||
private bool IsV2EncryptionUserAsync(User user)
|
||||
{
|
||||
// Returns whether the user is a V2 user based on the private key's encryption type.
|
||||
ArgumentNullException.ThrowIfNull(user);
|
||||
var isPrivateKeyEncryptionV2 = GetEncryptionType(user.PrivateKey) == EncryptionType.XChaCha20Poly1305_B64;
|
||||
return isPrivateKeyEncryptionV2;
|
||||
}
|
||||
|
||||
private async Task ValidateVerifyingKeyUnchangedAsync(RotateUserAccountKeysData model, User user)
|
||||
{
|
||||
var currentSignatureKeyPair = await _userSignatureKeyPairRepository.GetByUserIdAsync(user.Id) ?? throw new InvalidOperationException("User does not have a signature key pair.");
|
||||
if (model.AccountKeys.SignatureKeyPairData.VerifyingKey != currentSignatureKeyPair!.VerifyingKey)
|
||||
{
|
||||
throw new InvalidOperationException("The provided verifying key does not match the user's current verifying key.");
|
||||
}
|
||||
}
|
||||
|
||||
private static void ValidatePublicKeyEncryptionKeyPairUnchanged(RotateUserAccountKeysData model, User user)
|
||||
{
|
||||
var publicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.PublicKey;
|
||||
if (publicKey != user.PublicKey)
|
||||
{
|
||||
throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric key pair is currently not supported during key rotation.");
|
||||
}
|
||||
}
|
||||
|
||||
private static void ValidateV2Encryption(RotateUserAccountKeysData model)
|
||||
{
|
||||
if (model.AccountKeys.SignatureKeyPairData == null)
|
||||
{
|
||||
throw new InvalidOperationException("Signature key pair data is required for V2 encryption.");
|
||||
}
|
||||
if (GetEncryptionType(model.AccountKeys.SignatureKeyPairData.WrappedSigningKey) != EncryptionType.XChaCha20Poly1305_B64)
|
||||
{
|
||||
throw new InvalidOperationException("The provided signing key data is not wrapped with XChaCha20-Poly1305.");
|
||||
}
|
||||
if (string.IsNullOrEmpty(model.AccountKeys.SignatureKeyPairData.VerifyingKey))
|
||||
{
|
||||
throw new InvalidOperationException("The provided signature key pair data does not contain a valid verifying key.");
|
||||
}
|
||||
|
||||
if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.XChaCha20Poly1305_B64)
|
||||
{
|
||||
throw new InvalidOperationException("The provided private key encryption key is not wrapped with XChaCha20-Poly1305.");
|
||||
}
|
||||
if (string.IsNullOrEmpty(model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey))
|
||||
{
|
||||
throw new InvalidOperationException("No signed public key provided, but the user already has a signature key pair.");
|
||||
}
|
||||
if (model.AccountKeys.SecurityStateData == null || string.IsNullOrEmpty(model.AccountKeys.SecurityStateData.SecurityState))
|
||||
{
|
||||
throw new InvalidOperationException("No signed security state provider for V2 user");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Helper method to convert an encryption type string to an enum value.
|
||||
/// </summary>
|
||||
private static EncryptionType GetEncryptionType(string encString)
|
||||
{
|
||||
var parts = encString.Split('.');
|
||||
if (parts.Length == 1)
|
||||
{
|
||||
throw new ArgumentException("Invalid encryption type string.");
|
||||
}
|
||||
if (byte.TryParse(parts[0], out var encryptionTypeNumber))
|
||||
{
|
||||
if (Enum.IsDefined(typeof(EncryptionType), encryptionTypeNumber))
|
||||
{
|
||||
return (EncryptionType)encryptionTypeNumber;
|
||||
}
|
||||
}
|
||||
throw new ArgumentException("Invalid encryption type string.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Stripe;
|
||||
|
||||
namespace Bit.Core.Models.Business;
|
||||
@@ -36,8 +37,13 @@ public class SubscriptionInfo
|
||||
Status = sub.Status;
|
||||
TrialStartDate = sub.TrialStart;
|
||||
TrialEndDate = sub.TrialEnd;
|
||||
PeriodStartDate = sub.CurrentPeriodStart;
|
||||
PeriodEndDate = sub.CurrentPeriodEnd;
|
||||
var currentPeriod = sub.GetCurrentPeriod();
|
||||
if (currentPeriod != null)
|
||||
{
|
||||
var (start, end) = currentPeriod.Value;
|
||||
PeriodStartDate = start;
|
||||
PeriodEndDate = end;
|
||||
}
|
||||
CancelledDate = sub.CanceledAt;
|
||||
CancelAtEndDate = sub.CancelAtPeriodEnd;
|
||||
Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired";
|
||||
|
||||
@@ -97,3 +97,9 @@ public class ProviderBankAccountVerifiedPushNotification
|
||||
public Guid ProviderId { get; set; }
|
||||
public Guid AdminId { get; set; }
|
||||
}
|
||||
|
||||
public class LogOutPushNotification
|
||||
{
|
||||
public Guid UserId { get; set; }
|
||||
public PushNotificationLogOutReason? Reason { get; set; }
|
||||
}
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
// FIXME: Update this file to be null safe and then delete the line below
|
||||
#nullable disable
|
||||
|
||||
namespace Bit.Core.Models.BitStripe;
|
||||
|
||||
// Stripe's SubscriptionListOptions model has a complex input for date filters.
|
||||
// It expects a dictionary, and has lots of validation rules around what can have a value and what can't.
|
||||
// To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT.
|
||||
// ___
|
||||
// Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model.
|
||||
public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions
|
||||
{
|
||||
public DateTime? CurrentPeriodEndDate { get; set; }
|
||||
public string CurrentPeriodEndRange { get; set; } = "lt";
|
||||
public bool SelectAll { get; set; }
|
||||
public new Stripe.DateRangeOptions CurrentPeriodEnd
|
||||
{
|
||||
get
|
||||
{
|
||||
return CurrentPeriodEndDate.HasValue ?
|
||||
new Stripe.DateRangeOptions()
|
||||
{
|
||||
LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null,
|
||||
GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null
|
||||
} :
|
||||
null;
|
||||
}
|
||||
}
|
||||
|
||||
public Stripe.SubscriptionListOptions ToStripeApiOptions()
|
||||
{
|
||||
var stripeApiOptions = (Stripe.SubscriptionListOptions)this;
|
||||
|
||||
if (SelectAll)
|
||||
{
|
||||
stripeApiOptions.EndingBefore = null;
|
||||
stripeApiOptions.StartingAfter = null;
|
||||
}
|
||||
|
||||
if (CurrentPeriodEndDate.HasValue)
|
||||
{
|
||||
stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions()
|
||||
{
|
||||
LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null,
|
||||
GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null
|
||||
};
|
||||
}
|
||||
|
||||
return stripeApiOptions;
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,7 @@ public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISel
|
||||
.ToDictionary(i => i.SponsoringOrganizationUserId);
|
||||
if (!organizationSponsorshipsDict.Any())
|
||||
{
|
||||
_logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}");
|
||||
_logger.LogInformation("No existing sponsorships to sync for organization {organizationId}", organizationId);
|
||||
return;
|
||||
}
|
||||
var syncedSponsorships = new List<OrganizationSponsorshipData>();
|
||||
|
||||
@@ -167,18 +167,17 @@ public interface IPushNotificationService
|
||||
ExcludeCurrentContext = false,
|
||||
});
|
||||
|
||||
Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false)
|
||||
=> PushAsync(new PushNotification<UserPushNotification>
|
||||
Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false,
|
||||
PushNotificationLogOutReason? reason = null)
|
||||
=> PushAsync(new PushNotification<LogOutPushNotification>
|
||||
{
|
||||
Type = PushType.LogOut,
|
||||
Target = NotificationTarget.User,
|
||||
TargetId = userId,
|
||||
Payload = new UserPushNotification
|
||||
Payload = new LogOutPushNotification
|
||||
{
|
||||
UserId = userId,
|
||||
#pragma warning disable BWP0001 // Type or member is obsolete
|
||||
Date = TimeProvider.GetUtcNow().UtcDateTime,
|
||||
#pragma warning restore BWP0001 // Type or member is obsolete
|
||||
Reason = reason
|
||||
},
|
||||
ExcludeCurrentContext = excludeCurrentContextFromPush,
|
||||
});
|
||||
|
||||
@@ -55,7 +55,7 @@ public enum PushType : byte
|
||||
[NotificationInfo("not-specified", typeof(Models.UserPushNotification))]
|
||||
SyncSettings = 10,
|
||||
|
||||
[NotificationInfo("not-specified", typeof(Models.UserPushNotification))]
|
||||
[NotificationInfo("not-specified", typeof(Models.LogOutPushNotification))]
|
||||
LogOut = 11,
|
||||
|
||||
[NotificationInfo("@bitwarden/team-tools-dev", typeof(Models.SyncSendPushNotification))]
|
||||
|
||||
28
src/Core/SecretsManager/Entities/SecretVersion.cs
Normal file
28
src/Core/SecretsManager/Entities/SecretVersion.cs
Normal file
@@ -0,0 +1,28 @@
|
||||
#nullable enable
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Utilities;
|
||||
|
||||
namespace Bit.Core.SecretsManager.Entities;
|
||||
|
||||
public class SecretVersion : ITableObject<Guid>
|
||||
{
|
||||
public Guid Id { get; set; }
|
||||
|
||||
public Guid SecretId { get; set; }
|
||||
|
||||
public string Value { get; set; } = string.Empty;
|
||||
|
||||
public DateTime VersionDate { get; set; }
|
||||
|
||||
public Guid? EditorServiceAccountId { get; set; }
|
||||
|
||||
public Guid? EditorOrganizationUserId { get; set; }
|
||||
|
||||
public void SetNewId()
|
||||
{
|
||||
if (Id == default(Guid))
|
||||
{
|
||||
Id = CoreHelpers.GenerateComb();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,58 +3,47 @@
|
||||
|
||||
using Bit.Core.Models.BitStripe;
|
||||
using Stripe;
|
||||
using Stripe.Tax;
|
||||
|
||||
namespace Bit.Core.Services;
|
||||
|
||||
public interface IStripeAdapter
|
||||
{
|
||||
Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions);
|
||||
Task<Stripe.Customer> CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null);
|
||||
Task<Stripe.Customer> CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null);
|
||||
Task<Stripe.Customer> CustomerDeleteAsync(string id);
|
||||
Task<List<PaymentMethod>> CustomerListPaymentMethods(string id, CustomerListPaymentMethodsOptions options = null);
|
||||
Task<Customer> CustomerCreateAsync(CustomerCreateOptions customerCreateOptions);
|
||||
Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null);
|
||||
Task<Customer> CustomerGetAsync(string id, CustomerGetOptions options = null);
|
||||
Task<Customer> CustomerUpdateAsync(string id, CustomerUpdateOptions options = null);
|
||||
Task<Customer> CustomerDeleteAsync(string id);
|
||||
Task<List<PaymentMethod>> CustomerListPaymentMethods(string id, CustomerPaymentMethodListOptions options = null);
|
||||
Task<CustomerBalanceTransaction> CustomerBalanceTransactionCreate(string customerId,
|
||||
CustomerBalanceTransactionCreateOptions options);
|
||||
Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions);
|
||||
Task<Stripe.Subscription> SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null);
|
||||
|
||||
/// <summary>
|
||||
/// Retrieves a subscription object for a provider.
|
||||
/// </summary>
|
||||
/// <param name="id">The subscription ID.</param>
|
||||
/// <param name="providerId">The provider ID.</param>
|
||||
/// <param name="options">Additional options.</param>
|
||||
/// <returns>The subscription object.</returns>
|
||||
/// <exception cref="InvalidOperationException">Thrown when the subscription doesn't belong to the provider.</exception>
|
||||
Task<Stripe.Subscription> ProviderSubscriptionGetAsync(string id, Guid providerId, Stripe.SubscriptionGetOptions options = null);
|
||||
|
||||
Task<List<Stripe.Subscription>> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions);
|
||||
Task<Stripe.Subscription> SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null);
|
||||
Task<Stripe.Subscription> SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null);
|
||||
Task<Stripe.Invoice> InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options);
|
||||
Task<Stripe.Invoice> InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options);
|
||||
Task<List<Stripe.Invoice>> InvoiceListAsync(StripeInvoiceListOptions options);
|
||||
Task<Stripe.Invoice> InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options);
|
||||
Task<List<Stripe.Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options);
|
||||
Task<Stripe.Invoice> InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options);
|
||||
Task<Stripe.Invoice> InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options);
|
||||
Task<Stripe.Invoice> InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options);
|
||||
Task<Stripe.Invoice> InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null);
|
||||
Task<Stripe.Invoice> InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null);
|
||||
Task<Stripe.Invoice> InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null);
|
||||
IEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options);
|
||||
IAsyncEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options);
|
||||
Task<Stripe.PaymentMethod> PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null);
|
||||
Task<Stripe.PaymentMethod> PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null);
|
||||
Task<Stripe.TaxId> TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options);
|
||||
Task<Stripe.TaxId> TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null);
|
||||
Task<Stripe.StripeList<Stripe.Tax.Registration>> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null);
|
||||
Task<Stripe.StripeList<Stripe.Charge>> ChargeListAsync(Stripe.ChargeListOptions options);
|
||||
Task<Stripe.Refund> RefundCreateAsync(Stripe.RefundCreateOptions options);
|
||||
Task<Stripe.Card> CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null);
|
||||
Task<Stripe.BankAccount> BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null);
|
||||
Task<Stripe.BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null);
|
||||
Task<Stripe.StripeList<Stripe.Price>> PriceListAsync(Stripe.PriceListOptions options = null);
|
||||
Task<Subscription> SubscriptionCreateAsync(SubscriptionCreateOptions subscriptionCreateOptions);
|
||||
Task<Subscription> SubscriptionGetAsync(string id, SubscriptionGetOptions options = null);
|
||||
Task<Subscription> SubscriptionUpdateAsync(string id, SubscriptionUpdateOptions options = null);
|
||||
Task<Subscription> SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null);
|
||||
Task<Invoice> InvoiceGetAsync(string id, InvoiceGetOptions options);
|
||||
Task<List<Invoice>> InvoiceListAsync(StripeInvoiceListOptions options);
|
||||
Task<Invoice> InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options);
|
||||
Task<List<Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options);
|
||||
Task<Invoice> InvoiceUpdateAsync(string id, InvoiceUpdateOptions options);
|
||||
Task<Invoice> InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options);
|
||||
Task<Invoice> InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options);
|
||||
Task<Invoice> InvoicePayAsync(string id, InvoicePayOptions options = null);
|
||||
Task<Invoice> InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null);
|
||||
Task<Invoice> InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null);
|
||||
IEnumerable<PaymentMethod> PaymentMethodListAutoPaging(PaymentMethodListOptions options);
|
||||
IAsyncEnumerable<PaymentMethod> PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options);
|
||||
Task<PaymentMethod> PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null);
|
||||
Task<PaymentMethod> PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null);
|
||||
Task<TaxId> TaxIdCreateAsync(string id, TaxIdCreateOptions options);
|
||||
Task<TaxId> TaxIdDeleteAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null);
|
||||
Task<StripeList<Registration>> TaxRegistrationsListAsync(RegistrationListOptions options = null);
|
||||
Task<StripeList<Charge>> ChargeListAsync(ChargeListOptions options);
|
||||
Task<Refund> RefundCreateAsync(RefundCreateOptions options);
|
||||
Task<Card> CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null);
|
||||
Task<BankAccount> BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null);
|
||||
Task<BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null);
|
||||
Task<StripeList<Price>> PriceListAsync(PriceListOptions options = null);
|
||||
Task<SetupIntent> SetupIntentCreate(SetupIntentCreateOptions options);
|
||||
Task<List<SetupIntent>> SetupIntentList(SetupIntentListOptions options);
|
||||
Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null);
|
||||
|
||||
@@ -9,18 +9,18 @@ namespace Bit.Core.Services;
|
||||
|
||||
public class StripeAdapter : IStripeAdapter
|
||||
{
|
||||
private readonly Stripe.CustomerService _customerService;
|
||||
private readonly Stripe.SubscriptionService _subscriptionService;
|
||||
private readonly Stripe.InvoiceService _invoiceService;
|
||||
private readonly Stripe.PaymentMethodService _paymentMethodService;
|
||||
private readonly Stripe.TaxIdService _taxIdService;
|
||||
private readonly Stripe.ChargeService _chargeService;
|
||||
private readonly Stripe.RefundService _refundService;
|
||||
private readonly Stripe.CardService _cardService;
|
||||
private readonly Stripe.BankAccountService _bankAccountService;
|
||||
private readonly Stripe.PlanService _planService;
|
||||
private readonly Stripe.PriceService _priceService;
|
||||
private readonly Stripe.SetupIntentService _setupIntentService;
|
||||
private readonly CustomerService _customerService;
|
||||
private readonly SubscriptionService _subscriptionService;
|
||||
private readonly InvoiceService _invoiceService;
|
||||
private readonly PaymentMethodService _paymentMethodService;
|
||||
private readonly TaxIdService _taxIdService;
|
||||
private readonly ChargeService _chargeService;
|
||||
private readonly RefundService _refundService;
|
||||
private readonly CardService _cardService;
|
||||
private readonly BankAccountService _bankAccountService;
|
||||
private readonly PlanService _planService;
|
||||
private readonly PriceService _priceService;
|
||||
private readonly SetupIntentService _setupIntentService;
|
||||
private readonly Stripe.TestHelpers.TestClockService _testClockService;
|
||||
private readonly CustomerBalanceTransactionService _customerBalanceTransactionService;
|
||||
private readonly Stripe.Tax.RegistrationService _taxRegistrationService;
|
||||
@@ -28,17 +28,17 @@ public class StripeAdapter : IStripeAdapter
|
||||
|
||||
public StripeAdapter()
|
||||
{
|
||||
_customerService = new Stripe.CustomerService();
|
||||
_subscriptionService = new Stripe.SubscriptionService();
|
||||
_invoiceService = new Stripe.InvoiceService();
|
||||
_paymentMethodService = new Stripe.PaymentMethodService();
|
||||
_taxIdService = new Stripe.TaxIdService();
|
||||
_chargeService = new Stripe.ChargeService();
|
||||
_refundService = new Stripe.RefundService();
|
||||
_cardService = new Stripe.CardService();
|
||||
_bankAccountService = new Stripe.BankAccountService();
|
||||
_priceService = new Stripe.PriceService();
|
||||
_planService = new Stripe.PlanService();
|
||||
_customerService = new CustomerService();
|
||||
_subscriptionService = new SubscriptionService();
|
||||
_invoiceService = new InvoiceService();
|
||||
_paymentMethodService = new PaymentMethodService();
|
||||
_taxIdService = new TaxIdService();
|
||||
_chargeService = new ChargeService();
|
||||
_refundService = new RefundService();
|
||||
_cardService = new CardService();
|
||||
_bankAccountService = new BankAccountService();
|
||||
_priceService = new PriceService();
|
||||
_planService = new PlanService();
|
||||
_setupIntentService = new SetupIntentService();
|
||||
_testClockService = new Stripe.TestHelpers.TestClockService();
|
||||
_customerBalanceTransactionService = new CustomerBalanceTransactionService();
|
||||
@@ -46,28 +46,31 @@ public class StripeAdapter : IStripeAdapter
|
||||
_calculationService = new CalculationService();
|
||||
}
|
||||
|
||||
public Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions options)
|
||||
public Task<Customer> CustomerCreateAsync(CustomerCreateOptions options)
|
||||
{
|
||||
return _customerService.CreateAsync(options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Customer> CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null)
|
||||
public Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) =>
|
||||
_customerService.DeleteDiscountAsync(customerId, options);
|
||||
|
||||
public Task<Customer> CustomerGetAsync(string id, CustomerGetOptions options = null)
|
||||
{
|
||||
return _customerService.GetAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Customer> CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null)
|
||||
public Task<Customer> CustomerUpdateAsync(string id, CustomerUpdateOptions options = null)
|
||||
{
|
||||
return _customerService.UpdateAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Customer> CustomerDeleteAsync(string id)
|
||||
public Task<Customer> CustomerDeleteAsync(string id)
|
||||
{
|
||||
return _customerService.DeleteAsync(id);
|
||||
}
|
||||
|
||||
public async Task<List<PaymentMethod>> CustomerListPaymentMethods(string id,
|
||||
CustomerListPaymentMethodsOptions options = null)
|
||||
CustomerPaymentMethodListOptions options = null)
|
||||
{
|
||||
var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options);
|
||||
return paymentMethods.Data;
|
||||
@@ -77,12 +80,12 @@ public class StripeAdapter : IStripeAdapter
|
||||
CustomerBalanceTransactionCreateOptions options)
|
||||
=> await _customerBalanceTransactionService.CreateAsync(customerId, options);
|
||||
|
||||
public Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options)
|
||||
public Task<Subscription> SubscriptionCreateAsync(SubscriptionCreateOptions options)
|
||||
{
|
||||
return _subscriptionService.CreateAsync(options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Subscription> SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null)
|
||||
public Task<Subscription> SubscriptionGetAsync(string id, SubscriptionGetOptions options = null)
|
||||
{
|
||||
return _subscriptionService.GetAsync(id, options);
|
||||
}
|
||||
@@ -101,28 +104,23 @@ public class StripeAdapter : IStripeAdapter
|
||||
throw new InvalidOperationException("Subscription does not belong to the provider.");
|
||||
}
|
||||
|
||||
public Task<Stripe.Subscription> SubscriptionUpdateAsync(string id,
|
||||
Stripe.SubscriptionUpdateOptions options = null)
|
||||
public Task<Subscription> SubscriptionUpdateAsync(string id,
|
||||
SubscriptionUpdateOptions options = null)
|
||||
{
|
||||
return _subscriptionService.UpdateAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Subscription> SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null)
|
||||
public Task<Subscription> SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null)
|
||||
{
|
||||
return _subscriptionService.CancelAsync(Id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Invoice> InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options)
|
||||
{
|
||||
return _invoiceService.UpcomingAsync(options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Invoice> InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options)
|
||||
public Task<Invoice> InvoiceGetAsync(string id, InvoiceGetOptions options)
|
||||
{
|
||||
return _invoiceService.GetAsync(id, options);
|
||||
}
|
||||
|
||||
public async Task<List<Stripe.Invoice>> InvoiceListAsync(StripeInvoiceListOptions options)
|
||||
public async Task<List<Invoice>> InvoiceListAsync(StripeInvoiceListOptions options)
|
||||
{
|
||||
if (!options.SelectAll)
|
||||
{
|
||||
@@ -131,7 +129,7 @@ public class StripeAdapter : IStripeAdapter
|
||||
|
||||
options.Limit = 100;
|
||||
|
||||
var invoices = new List<Stripe.Invoice>();
|
||||
var invoices = new List<Invoice>();
|
||||
|
||||
await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions()))
|
||||
{
|
||||
@@ -146,120 +144,104 @@ public class StripeAdapter : IStripeAdapter
|
||||
return _invoiceService.CreatePreviewAsync(options);
|
||||
}
|
||||
|
||||
public async Task<List<Stripe.Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options)
|
||||
public async Task<List<Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options)
|
||||
=> (await _invoiceService.SearchAsync(options)).Data;
|
||||
|
||||
public Task<Stripe.Invoice> InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options)
|
||||
public Task<Invoice> InvoiceUpdateAsync(string id, InvoiceUpdateOptions options)
|
||||
{
|
||||
return _invoiceService.UpdateAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Invoice> InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options)
|
||||
public Task<Invoice> InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options)
|
||||
{
|
||||
return _invoiceService.FinalizeInvoiceAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Invoice> InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options)
|
||||
public Task<Invoice> InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options)
|
||||
{
|
||||
return _invoiceService.SendInvoiceAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Invoice> InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null)
|
||||
public Task<Invoice> InvoicePayAsync(string id, InvoicePayOptions options = null)
|
||||
{
|
||||
return _invoiceService.PayAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Invoice> InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null)
|
||||
public Task<Invoice> InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null)
|
||||
{
|
||||
return _invoiceService.DeleteAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Invoice> InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null)
|
||||
public Task<Invoice> InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null)
|
||||
{
|
||||
return _invoiceService.VoidInvoiceAsync(id, options);
|
||||
}
|
||||
|
||||
public IEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options)
|
||||
public IEnumerable<PaymentMethod> PaymentMethodListAutoPaging(PaymentMethodListOptions options)
|
||||
{
|
||||
return _paymentMethodService.ListAutoPaging(options);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options)
|
||||
public IAsyncEnumerable<PaymentMethod> PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options)
|
||||
=> _paymentMethodService.ListAutoPagingAsync(options);
|
||||
|
||||
public Task<Stripe.PaymentMethod> PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null)
|
||||
public Task<PaymentMethod> PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null)
|
||||
{
|
||||
return _paymentMethodService.AttachAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.PaymentMethod> PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null)
|
||||
public Task<PaymentMethod> PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null)
|
||||
{
|
||||
return _paymentMethodService.DetachAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Plan> PlanGetAsync(string id, Stripe.PlanGetOptions options = null)
|
||||
public Task<Plan> PlanGetAsync(string id, PlanGetOptions options = null)
|
||||
{
|
||||
return _planService.GetAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.TaxId> TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options)
|
||||
public Task<TaxId> TaxIdCreateAsync(string id, TaxIdCreateOptions options)
|
||||
{
|
||||
return _taxIdService.CreateAsync(id, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.TaxId> TaxIdDeleteAsync(string customerId, string taxIdId,
|
||||
Stripe.TaxIdDeleteOptions options = null)
|
||||
public Task<TaxId> TaxIdDeleteAsync(string customerId, string taxIdId,
|
||||
TaxIdDeleteOptions options = null)
|
||||
{
|
||||
return _taxIdService.DeleteAsync(customerId, taxIdId);
|
||||
}
|
||||
|
||||
public Task<Stripe.StripeList<Stripe.Tax.Registration>> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null)
|
||||
public Task<StripeList<Registration>> TaxRegistrationsListAsync(RegistrationListOptions options = null)
|
||||
{
|
||||
return _taxRegistrationService.ListAsync(options);
|
||||
}
|
||||
|
||||
public Task<Stripe.StripeList<Stripe.Charge>> ChargeListAsync(Stripe.ChargeListOptions options)
|
||||
public Task<StripeList<Charge>> ChargeListAsync(ChargeListOptions options)
|
||||
{
|
||||
return _chargeService.ListAsync(options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Refund> RefundCreateAsync(Stripe.RefundCreateOptions options)
|
||||
public Task<Refund> RefundCreateAsync(RefundCreateOptions options)
|
||||
{
|
||||
return _refundService.CreateAsync(options);
|
||||
}
|
||||
|
||||
public Task<Stripe.Card> CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null)
|
||||
public Task<Card> CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null)
|
||||
{
|
||||
return _cardService.DeleteAsync(customerId, cardId, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.BankAccount> BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null)
|
||||
public Task<BankAccount> BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null)
|
||||
{
|
||||
return _bankAccountService.CreateAsync(customerId, options);
|
||||
}
|
||||
|
||||
public Task<Stripe.BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null)
|
||||
public Task<BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null)
|
||||
{
|
||||
return _bankAccountService.DeleteAsync(customerId, bankAccount, options);
|
||||
}
|
||||
|
||||
public async Task<List<Stripe.Subscription>> SubscriptionListAsync(StripeSubscriptionListOptions options)
|
||||
{
|
||||
if (!options.SelectAll)
|
||||
{
|
||||
return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data;
|
||||
}
|
||||
|
||||
options.Limit = 100;
|
||||
var items = new List<Stripe.Subscription>();
|
||||
await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions()))
|
||||
{
|
||||
items.Add(i);
|
||||
}
|
||||
return items;
|
||||
}
|
||||
|
||||
public async Task<Stripe.StripeList<Stripe.Price>> PriceListAsync(Stripe.PriceListOptions options = null)
|
||||
public async Task<StripeList<Price>> PriceListAsync(PriceListOptions options = null)
|
||||
{
|
||||
return await _priceService.ListAsync(options);
|
||||
}
|
||||
|
||||
@@ -65,19 +65,20 @@ public class StripePaymentService : IPaymentService
|
||||
bool applySponsorship)
|
||||
{
|
||||
var existingPlan = await _pricingClient.GetPlanOrThrow(org.PlanType);
|
||||
var sponsoredPlan = sponsorship?.PlanSponsorshipType != null ?
|
||||
Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) :
|
||||
null;
|
||||
var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship);
|
||||
var sponsoredPlan = sponsorship?.PlanSponsorshipType != null
|
||||
? Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)
|
||||
: null;
|
||||
var subscriptionUpdate =
|
||||
new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship);
|
||||
|
||||
await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, true);
|
||||
|
||||
var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId);
|
||||
org.ExpirationDate = sub.CurrentPeriodEnd;
|
||||
org.ExpirationDate = sub.GetCurrentPeriodEnd();
|
||||
|
||||
if (sponsorship is not null)
|
||||
{
|
||||
sponsorship.ValidUntil = sub.CurrentPeriodEnd;
|
||||
sponsorship.ValidUntil = sub.GetCurrentPeriodEnd();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +101,8 @@ public class StripePaymentService : IPaymentService
|
||||
|
||||
if (sub.Status == SubscriptionStatuses.Canceled)
|
||||
{
|
||||
throw new BadRequestException("You do not have an active subscription. Reinstate your subscription to make changes.");
|
||||
throw new BadRequestException(
|
||||
"You do not have an active subscription. Reinstate your subscription to make changes.");
|
||||
}
|
||||
|
||||
var existingCoupon = sub.Customer.Discount?.Coupon?.Id;
|
||||
@@ -191,24 +193,24 @@ public class StripePaymentService : IPaymentService
|
||||
throw;
|
||||
}
|
||||
}
|
||||
else if (!invoice.Paid)
|
||||
else if (invoice.Status != StripeConstants.InvoiceStatus.Paid)
|
||||
{
|
||||
// Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h
|
||||
invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId);
|
||||
paymentIntentClientSecret = null;
|
||||
}
|
||||
|
||||
}
|
||||
finally
|
||||
{
|
||||
// Change back the subscription collection method and/or days until due
|
||||
if (collectionMethod != "send_invoice" || daysUntilDue == null)
|
||||
{
|
||||
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions
|
||||
{
|
||||
CollectionMethod = collectionMethod,
|
||||
DaysUntilDue = daysUntilDue,
|
||||
});
|
||||
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id,
|
||||
new SubscriptionUpdateOptions
|
||||
{
|
||||
CollectionMethod = collectionMethod,
|
||||
DaysUntilDue = daysUntilDue,
|
||||
});
|
||||
}
|
||||
|
||||
var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId);
|
||||
@@ -218,9 +220,15 @@ public class StripePaymentService : IPaymentService
|
||||
if (!string.IsNullOrEmpty(existingCoupon) && string.IsNullOrEmpty(newCoupon))
|
||||
{
|
||||
// Re-add the lost coupon due to the update.
|
||||
await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, new CustomerUpdateOptions
|
||||
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions
|
||||
{
|
||||
Coupon = existingCoupon
|
||||
Discounts =
|
||||
[
|
||||
new SubscriptionDiscountOptions
|
||||
{
|
||||
Coupon = existingCoupon
|
||||
}
|
||||
]
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -352,7 +360,7 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card";
|
||||
var hasDefaultValidSource = customer.DefaultSource != null &&
|
||||
(customer.DefaultSource is Card || customer.DefaultSource is BankAccount);
|
||||
(customer.DefaultSource is Card || customer.DefaultSource is BankAccount);
|
||||
if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource)
|
||||
{
|
||||
cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id;
|
||||
@@ -365,12 +373,11 @@ public class StripePaymentService : IPaymentService
|
||||
}
|
||||
catch
|
||||
{
|
||||
await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions
|
||||
{
|
||||
AutoAdvance = false
|
||||
});
|
||||
await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id,
|
||||
new InvoiceFinalizeOptions { AutoAdvance = false });
|
||||
await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id);
|
||||
}
|
||||
|
||||
throw new BadRequestException("No payment method is available.");
|
||||
}
|
||||
}
|
||||
@@ -381,14 +388,9 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
// Finalize the invoice (from Draft) w/o auto-advance so we
|
||||
// can attempt payment manually.
|
||||
invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions
|
||||
{
|
||||
AutoAdvance = false,
|
||||
});
|
||||
var invoicePayOptions = new InvoicePayOptions
|
||||
{
|
||||
PaymentMethod = cardPaymentMethodId,
|
||||
};
|
||||
invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id,
|
||||
new InvoiceFinalizeOptions { AutoAdvance = false, });
|
||||
var invoicePayOptions = new InvoicePayOptions { PaymentMethod = cardPaymentMethodId, };
|
||||
if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false)
|
||||
{
|
||||
invoicePayOptions.PaidOutOfBand = true;
|
||||
@@ -403,13 +405,15 @@ public class StripePaymentService : IPaymentService
|
||||
SubmitForSettlement = true,
|
||||
PayPal = new Braintree.TransactionOptionsPayPalRequest
|
||||
{
|
||||
CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}"
|
||||
CustomField =
|
||||
$"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}"
|
||||
}
|
||||
},
|
||||
CustomFields = new Dictionary<string, string>
|
||||
{
|
||||
[subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
|
||||
[subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
|
||||
[subscriber.BraintreeCloudRegionField()] =
|
||||
_globalSettings.BaseServiceUri.CloudRegion
|
||||
}
|
||||
});
|
||||
|
||||
@@ -442,9 +446,9 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
// SCA required, get intent client secret
|
||||
var invoiceGetOptions = new InvoiceGetOptions();
|
||||
invoiceGetOptions.AddExpand("payment_intent");
|
||||
invoiceGetOptions.AddExpand("confirmation_secret");
|
||||
invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions);
|
||||
paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret;
|
||||
paymentIntentClientSecret = invoice?.ConfirmationSecret?.ClientSecret;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -458,6 +462,7 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id);
|
||||
}
|
||||
|
||||
if (invoice != null)
|
||||
{
|
||||
if (invoice.Status == "paid")
|
||||
@@ -479,10 +484,8 @@ public class StripePaymentService : IPaymentService
|
||||
// Assumption: Customer balance should now be $0, otherwise payment would not have failed.
|
||||
if (customer.Balance == 0)
|
||||
{
|
||||
await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
|
||||
{
|
||||
Balance = invoice.StartingBalance
|
||||
});
|
||||
await _stripeAdapter.CustomerUpdateAsync(customer.Id,
|
||||
new CustomerUpdateOptions { Balance = invoice.StartingBalance });
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -496,6 +499,7 @@ public class StripePaymentService : IPaymentService
|
||||
// Let the caller perform any subscription change cleanup
|
||||
throw;
|
||||
}
|
||||
|
||||
return paymentIntentClientSecret;
|
||||
}
|
||||
|
||||
@@ -526,10 +530,10 @@ public class StripePaymentService : IPaymentService
|
||||
|
||||
try
|
||||
{
|
||||
var canceledSub = endOfPeriod ?
|
||||
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id,
|
||||
new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) :
|
||||
await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions());
|
||||
var canceledSub = endOfPeriod
|
||||
? await _stripeAdapter.SubscriptionUpdateAsync(sub.Id,
|
||||
new SubscriptionUpdateOptions { CancelAtPeriodEnd = true })
|
||||
: await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions());
|
||||
if (!canceledSub.CanceledAt.HasValue)
|
||||
{
|
||||
throw new GatewayException("Unable to cancel subscription.");
|
||||
@@ -580,7 +584,7 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
Customer customer = null;
|
||||
var customerExists = subscriber.Gateway == GatewayType.Stripe &&
|
||||
!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId);
|
||||
!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId);
|
||||
if (customerExists)
|
||||
{
|
||||
customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId);
|
||||
@@ -595,10 +599,10 @@ public class StripePaymentService : IPaymentService
|
||||
subscriber.Gateway = GatewayType.Stripe;
|
||||
subscriber.GatewayCustomerId = customer.Id;
|
||||
}
|
||||
await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
|
||||
{
|
||||
Balance = customer.Balance - (long)(creditAmount * 100)
|
||||
});
|
||||
|
||||
await _stripeAdapter.CustomerUpdateAsync(customer.Id,
|
||||
new CustomerUpdateOptions { Balance = customer.Balance - (long)(creditAmount * 100) });
|
||||
|
||||
return !customerExists;
|
||||
}
|
||||
|
||||
@@ -630,50 +634,45 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
var subscriptionInfo = new SubscriptionInfo();
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
|
||||
{
|
||||
var customerGetOptions = new CustomerGetOptions();
|
||||
customerGetOptions.AddExpand("discount.coupon.applies_to");
|
||||
var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions);
|
||||
|
||||
if (customer.Discount != null)
|
||||
{
|
||||
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(customer.Discount);
|
||||
}
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId))
|
||||
if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
|
||||
{
|
||||
return subscriptionInfo;
|
||||
}
|
||||
|
||||
var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions
|
||||
var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId,
|
||||
new SubscriptionGetOptions { Expand = ["customer", "discounts", "test_clock"] });
|
||||
|
||||
subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(subscription);
|
||||
|
||||
var discount = subscription.Customer.Discount ?? subscription.Discounts.FirstOrDefault();
|
||||
|
||||
if (discount != null)
|
||||
{
|
||||
Expand = ["test_clock"]
|
||||
});
|
||||
|
||||
if (sub != null)
|
||||
{
|
||||
subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub);
|
||||
|
||||
var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(sub);
|
||||
|
||||
if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue)
|
||||
{
|
||||
subscriptionInfo.Subscription.SuspensionDate = suspensionDate;
|
||||
subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate;
|
||||
}
|
||||
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(discount);
|
||||
}
|
||||
|
||||
if (sub is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
|
||||
var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(subscription);
|
||||
|
||||
if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue)
|
||||
{
|
||||
subscriptionInfo.Subscription.SuspensionDate = suspensionDate;
|
||||
subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate;
|
||||
}
|
||||
|
||||
if (subscription is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
|
||||
{
|
||||
return subscriptionInfo;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var upcomingInvoiceOptions = new UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId };
|
||||
var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync(upcomingInvoiceOptions);
|
||||
var invoiceCreatePreviewOptions = new InvoiceCreatePreviewOptions
|
||||
{
|
||||
Customer = subscriber.GatewayCustomerId,
|
||||
Subscription = subscriber.GatewaySubscriptionId
|
||||
};
|
||||
|
||||
var upcomingInvoice = await _stripeAdapter.InvoiceCreatePreviewAsync(invoiceCreatePreviewOptions);
|
||||
|
||||
if (upcomingInvoice != null)
|
||||
{
|
||||
@@ -682,7 +681,12 @@ public class StripePaymentService : IPaymentService
|
||||
}
|
||||
catch (StripeException ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Encountered an unexpected Stripe error");
|
||||
_logger.LogWarning(
|
||||
ex,
|
||||
"Failed to retrieve upcoming invoice for customer {CustomerId}, subscription {SubscriptionId}. Error Code: {ErrorCode}",
|
||||
subscriber.GatewayCustomerId,
|
||||
subscriber.GatewaySubscriptionId,
|
||||
ex.StripeError?.Code);
|
||||
}
|
||||
|
||||
return subscriptionInfo;
|
||||
@@ -788,7 +792,11 @@ public class StripePaymentService : IPaymentService
|
||||
if (taxInfo.TaxIdType == StripeConstants.TaxIdType.SpanishNIF)
|
||||
{
|
||||
await _stripeAdapter.TaxIdCreateAsync(customer.Id,
|
||||
new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, Value = $"ES{taxInfo.TaxIdNumber}" });
|
||||
new TaxIdCreateOptions
|
||||
{
|
||||
Type = StripeConstants.TaxIdType.EUVAT,
|
||||
Value = $"ES{taxInfo.TaxIdNumber}"
|
||||
});
|
||||
}
|
||||
}
|
||||
catch (StripeException e)
|
||||
@@ -829,7 +837,8 @@ public class StripePaymentService : IPaymentService
|
||||
await HasSecretsManagerStandaloneAsync(gatewayCustomerId: organization.GatewayCustomerId,
|
||||
organizationHasSecretsManager: organization.UseSecretsManager);
|
||||
|
||||
private async Task<bool> HasSecretsManagerStandaloneAsync(string gatewayCustomerId, bool organizationHasSecretsManager)
|
||||
private async Task<bool> HasSecretsManagerStandaloneAsync(string gatewayCustomerId,
|
||||
bool organizationHasSecretsManager)
|
||||
{
|
||||
if (string.IsNullOrEmpty(gatewayCustomerId))
|
||||
{
|
||||
@@ -894,26 +903,14 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
var options = new InvoiceCreatePreviewOptions
|
||||
{
|
||||
AutomaticTax = new InvoiceAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true,
|
||||
},
|
||||
AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true, },
|
||||
Currency = "usd",
|
||||
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
|
||||
{
|
||||
Items =
|
||||
[
|
||||
new()
|
||||
{
|
||||
Quantity = 1,
|
||||
Plan = StripeConstants.Prices.PremiumAnnually
|
||||
},
|
||||
|
||||
new()
|
||||
{
|
||||
Quantity = parameters.PasswordManager.AdditionalStorage,
|
||||
Plan = "storage-gb-annually"
|
||||
}
|
||||
new InvoiceSubscriptionDetailsItemOptions { Quantity = 1, Plan = StripeConstants.Prices.PremiumAnnually },
|
||||
new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.PasswordManager.AdditionalStorage, Plan = StripeConstants.Prices.StoragePlanPersonal }
|
||||
]
|
||||
},
|
||||
CustomerDetails = new InvoiceCustomerDetailsOptions
|
||||
@@ -940,12 +937,9 @@ public class StripePaymentService : IPaymentService
|
||||
throw new BadRequestException("billingPreviewInvalidTaxIdError");
|
||||
}
|
||||
|
||||
options.CustomerDetails.TaxIds = [
|
||||
new InvoiceCustomerDetailsTaxIdOptions
|
||||
{
|
||||
Type = taxIdType,
|
||||
Value = parameters.TaxInformation.TaxId
|
||||
}
|
||||
options.CustomerDetails.TaxIds =
|
||||
[
|
||||
new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId }
|
||||
];
|
||||
|
||||
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
|
||||
@@ -964,7 +958,7 @@ public class StripePaymentService : IPaymentService
|
||||
|
||||
if (gatewayCustomer.Discount != null)
|
||||
{
|
||||
options.Coupon = gatewayCustomer.Discount.Coupon.Id;
|
||||
options.Discounts = [new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id }];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -972,24 +966,31 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId);
|
||||
|
||||
if (gatewaySubscription?.Discount != null)
|
||||
if (gatewaySubscription?.Discounts is { Count: > 0 })
|
||||
{
|
||||
options.Coupon ??= gatewaySubscription.Discount.Coupon.Id;
|
||||
options.Discounts = gatewaySubscription.Discounts.Select(x => new InvoiceDiscountOptions { Coupon = x.Coupon.Id }).ToList();
|
||||
}
|
||||
}
|
||||
|
||||
if (options.Discounts is { Count: > 0 })
|
||||
{
|
||||
options.Discounts = options.Discounts.DistinctBy(invoiceDiscountOptions => invoiceDiscountOptions.Coupon).ToList();
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options);
|
||||
|
||||
var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
|
||||
? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
|
||||
var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount);
|
||||
|
||||
var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
|
||||
? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
|
||||
: 0M;
|
||||
|
||||
var result = new PreviewInvoiceResponseModel(
|
||||
effectiveTaxRate,
|
||||
invoice.TotalExcludingTax.ToMajor() ?? 0,
|
||||
invoice.Tax.ToMajor() ?? 0,
|
||||
tax.ToMajor(),
|
||||
invoice.Total.ToMajor());
|
||||
return result;
|
||||
}
|
||||
@@ -1003,7 +1004,8 @@ public class StripePaymentService : IPaymentService
|
||||
parameters.TaxInformation.Country);
|
||||
throw new BadRequestException("billingPreviewInvalidTaxIdError");
|
||||
default:
|
||||
_logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
|
||||
_logger.LogError(e,
|
||||
"Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
|
||||
parameters.TaxInformation.TaxId,
|
||||
parameters.TaxInformation.Country);
|
||||
throw new BadRequestException("billingPreviewInvoiceError");
|
||||
@@ -1101,12 +1103,9 @@ public class StripePaymentService : IPaymentService
|
||||
throw new BadRequestException("billingTaxIdTypeInferenceError");
|
||||
}
|
||||
|
||||
options.CustomerDetails.TaxIds = [
|
||||
new InvoiceCustomerDetailsTaxIdOptions
|
||||
{
|
||||
Type = taxIdType,
|
||||
Value = parameters.TaxInformation.TaxId
|
||||
}
|
||||
options.CustomerDetails.TaxIds =
|
||||
[
|
||||
new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId }
|
||||
];
|
||||
|
||||
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
|
||||
@@ -1127,7 +1126,10 @@ public class StripePaymentService : IPaymentService
|
||||
|
||||
if (gatewayCustomer.Discount != null)
|
||||
{
|
||||
options.Coupon = gatewayCustomer.Discount.Coupon.Id;
|
||||
options.Discounts =
|
||||
[
|
||||
new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id }
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1135,9 +1137,10 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId);
|
||||
|
||||
if (gatewaySubscription?.Discount != null)
|
||||
if (gatewaySubscription?.Discounts != null)
|
||||
{
|
||||
options.Coupon ??= gatewaySubscription.Discount.Coupon.Id;
|
||||
options.Discounts = gatewaySubscription.Discounts
|
||||
.Select(discount => new InvoiceDiscountOptions { Coupon = discount.Coupon.Id }).ToList();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1152,14 +1155,16 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options);
|
||||
|
||||
var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
|
||||
? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
|
||||
var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount);
|
||||
|
||||
var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
|
||||
? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
|
||||
: 0M;
|
||||
|
||||
var result = new PreviewInvoiceResponseModel(
|
||||
effectiveTaxRate,
|
||||
invoice.TotalExcludingTax.ToMajor() ?? 0,
|
||||
invoice.Tax.ToMajor() ?? 0,
|
||||
tax.ToMajor(),
|
||||
invoice.Total.ToMajor());
|
||||
return result;
|
||||
}
|
||||
@@ -1173,7 +1178,8 @@ public class StripePaymentService : IPaymentService
|
||||
parameters.TaxInformation.Country);
|
||||
throw new BadRequestException("billingPreviewInvalidTaxIdError");
|
||||
default:
|
||||
_logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
|
||||
_logger.LogError(e,
|
||||
"Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
|
||||
parameters.TaxInformation.TaxId,
|
||||
parameters.TaxInformation.Country);
|
||||
throw new BadRequestException("billingPreviewInvoiceError");
|
||||
@@ -1207,7 +1213,9 @@ public class StripePaymentService : IPaymentService
|
||||
braintreeCustomer.DefaultPaymentMethod);
|
||||
}
|
||||
}
|
||||
catch (Braintree.Exceptions.NotFoundException) { }
|
||||
catch (Braintree.Exceptions.NotFoundException)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card")
|
||||
@@ -1246,12 +1254,15 @@ public class StripePaymentService : IPaymentService
|
||||
{
|
||||
customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options);
|
||||
}
|
||||
catch (StripeException) { }
|
||||
catch (StripeException)
|
||||
{
|
||||
}
|
||||
|
||||
return customer;
|
||||
}
|
||||
|
||||
private async Task<IEnumerable<BillingHistoryInfo.BillingTransaction>> GetBillingTransactionsAsync(ISubscriber subscriber, int? limit = null)
|
||||
private async Task<IEnumerable<BillingHistoryInfo.BillingTransaction>> GetBillingTransactionsAsync(
|
||||
ISubscriber subscriber, int? limit = null)
|
||||
{
|
||||
var transactions = subscriber switch
|
||||
{
|
||||
|
||||
@@ -17,6 +17,6 @@ public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute
|
||||
|
||||
var logger = context.HttpContext.RequestServices
|
||||
.GetRequiredService<ILogger<LoggingExceptionHandlerFilterAttribute>>();
|
||||
logger.LogError(0, exception, exception.Message);
|
||||
logger.LogError(0, exception, "Unhandled exception");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user