1
0
mirror of https://github.com/bitwarden/server synced 2026-01-04 17:43:53 +00:00

Merge branch 'main' into SM-1571-DisableSMAdsForUsers

This commit is contained in:
cd-bitwarden
2025-10-22 12:28:04 -04:00
committed by GitHub
295 changed files with 42165 additions and 4881 deletions

View File

@@ -129,11 +129,15 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable
/// </summary>
public bool SyncSeats { get; set; }
/// If set to true, user accounts created within the organization are automatically confirmed without requiring additional verification steps.
/// </summary>
public bool UseAutomaticUserConfirmation { get; set; }
/// <summary>
/// If set to true, disables Secrets Manager ads for users in the organization
/// </summary>
public bool UseDisableSMAdsForUsers { get; set; }
public void SetNewId()
{
if (Id == default(Guid))

View File

@@ -28,6 +28,7 @@ public class OrganizationAbility
UseRiskInsights = organization.UseRiskInsights;
UseOrganizationDomains = organization.UseOrganizationDomains;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation;
UseDisableSMAdsForUsers = organization.UseDisableSMAdsForUsers;
}
@@ -50,5 +51,6 @@ public class OrganizationAbility
public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
public bool UseAutomaticUserConfirmation { get; set; }
public bool UseDisableSMAdsForUsers { get; set; }
}

View File

@@ -66,4 +66,5 @@ public class OrganizationUserOrganizationDetails
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
public bool? IsAdminInitiated { get; set; }
public bool UseAutomaticUserConfirmation { get; set; }
}

View File

@@ -51,4 +51,5 @@ public class ProviderUserOrganizationDetails
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
public ProviderType ProviderType { get; set; }
public bool UseAutomaticUserConfirmation { get; set; }
}

View File

@@ -89,7 +89,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque
AuthRequestExpiresAfter = _globalSettings.PasswordlessAuth.AdminRequestExpiration
}
);
processor.Process((Exception e) => _logger.LogError(e.Message));
processor.Process((Exception e) => _logger.LogError("Error processing organization auth request: {Message}", e.Message));
await processor.Save((IEnumerable<OrganizationAdminAuthRequest> authRequests) => _authRequestRepository.UpdateManyAsync(authRequests));
await processor.SendPushNotifications((ar) => _pushNotificationService.PushAuthRequestResponseAsync(ar));
await processor.SendApprovalEmailsForProcessedRequests(SendApprovalEmail);
@@ -114,7 +114,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque
// This should be impossible
if (user == null)
{
_logger.LogError($"User {authRequest.UserId} not found. Trusted device admin approval email not sent.");
_logger.LogError("User {UserId} not found. Trusted device admin approval email not sent.", authRequest.UserId);
return;
}

View File

@@ -1,6 +1,4 @@
#nullable enable
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories;
@@ -20,7 +18,7 @@ public class PolicyRequirementQuery(
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
}
var policyDetails = await GetPolicyDetails(userId);
var policyDetails = await GetPolicyDetails(userId, factory.PolicyType);
var filteredPolicies = policyDetails
.Where(p => p.PolicyType == factory.PolicyType)
.Where(factory.Enforce);
@@ -48,8 +46,8 @@ public class PolicyRequirementQuery(
return eligibleOrganizationUserIds;
}
private Task<IEnumerable<PolicyDetails>> GetPolicyDetails(Guid userId)
=> policyRepository.GetPolicyDetailsByUserId(userId);
private async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetails(Guid userId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType);
private async Task<IEnumerable<OrganizationPolicyDetails>> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType);

View File

@@ -13,25 +13,11 @@ public class VNextSavePolicyCommand(
IApplicationCacheService applicationCacheService,
IEventService eventService,
IPolicyRepository policyRepository,
IEnumerable<IEnforceDependentPoliciesEvent> policyValidationEventHandlers,
IEnumerable<IPolicyUpdateEvent> policyUpdateEventHandlers,
TimeProvider timeProvider,
IPolicyEventHandlerFactory policyEventHandlerFactory)
: IVNextSavePolicyCommand
{
private readonly IReadOnlyDictionary<PolicyType, IEnforceDependentPoliciesEvent> _policyValidationEvents = MapToDictionary(policyValidationEventHandlers);
private static Dictionary<PolicyType, IEnforceDependentPoliciesEvent> MapToDictionary(IEnumerable<IEnforceDependentPoliciesEvent> policyValidationEventHandlers)
{
var policyValidationEventsDict = new Dictionary<PolicyType, IEnforceDependentPoliciesEvent>();
foreach (var policyValidationEvent in policyValidationEventHandlers)
{
if (!policyValidationEventsDict.TryAdd(policyValidationEvent.Type, policyValidationEvent))
{
throw new Exception($"Duplicate PolicyValidationEvent for {policyValidationEvent.Type} policy.");
}
}
return policyValidationEventsDict;
}
public async Task<Policy> SaveAsync(SavePolicyModel policyRequest)
{
@@ -112,32 +98,26 @@ public class VNextSavePolicyCommand(
Policy? currentPolicy,
Dictionary<PolicyType, Policy> savedPoliciesDict)
{
var result = policyEventHandlerFactory.GetHandler<IEnforceDependentPoliciesEvent>(policyUpdateRequest.Type);
var isCurrentlyEnabled = currentPolicy?.Enabled == true;
var isBeingEnabled = policyUpdateRequest.Enabled && !isCurrentlyEnabled;
var isBeingDisabled = !policyUpdateRequest.Enabled && isCurrentlyEnabled;
result.Switch(
validator =>
{
var isCurrentlyEnabled = currentPolicy?.Enabled == true;
switch (policyUpdateRequest.Enabled)
{
case true when !isCurrentlyEnabled:
ValidateEnablingRequirements(validator, savedPoliciesDict);
return;
case false when isCurrentlyEnabled:
ValidateDisablingRequirements(validator, policyUpdateRequest.Type, savedPoliciesDict);
break;
}
},
_ => { });
if (isBeingEnabled)
{
ValidateEnablingRequirements(policyUpdateRequest.Type, savedPoliciesDict);
}
else if (isBeingDisabled)
{
ValidateDisablingRequirements(policyUpdateRequest.Type, savedPoliciesDict);
}
}
private void ValidateDisablingRequirements(
IEnforceDependentPoliciesEvent validator,
PolicyType policyType,
Dictionary<PolicyType, Policy> savedPoliciesDict)
{
var dependentPolicyTypes = _policyValidationEvents.Values
var dependentPolicyTypes = policyUpdateEventHandlers
.OfType<IEnforceDependentPoliciesEvent>()
.Where(otherValidator => otherValidator.RequiredPolicies.Contains(policyType))
.Select(otherValidator => otherValidator.Type)
.Where(otherPolicyType => savedPoliciesDict.TryGetValue(otherPolicyType, out var savedPolicy) &&
@@ -147,24 +127,31 @@ public class VNextSavePolicyCommand(
switch (dependentPolicyTypes)
{
case { Count: 1 }:
throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {validator.Type.GetName()} policy.");
throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {policyType.GetName()} policy.");
case { Count: > 1 }:
throw new BadRequestException($"Turn off all of the policies that require the {validator.Type.GetName()} policy.");
throw new BadRequestException($"Turn off all of the policies that require the {policyType.GetName()} policy.");
}
}
private static void ValidateEnablingRequirements(
IEnforceDependentPoliciesEvent validator,
private void ValidateEnablingRequirements(
PolicyType policyType,
Dictionary<PolicyType, Policy> savedPoliciesDict)
{
var missingRequiredPolicyTypes = validator.RequiredPolicies
.Where(requiredPolicyType => savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true })
.ToList();
var result = policyEventHandlerFactory.GetHandler<IEnforceDependentPoliciesEvent>(policyType);
if (missingRequiredPolicyTypes.Count != 0)
{
throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {validator.Type.GetName()} policy.");
}
result.Switch(
validator =>
{
var missingRequiredPolicyTypes = validator.RequiredPolicies
.Where(requiredPolicyType => savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true })
.ToList();
if (missingRequiredPolicyTypes.Count != 0)
{
throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {policyType.GetName()} policy.");
}
},
_ => { /* Policy has no required dependencies */ });
}
private async Task ExecutePreUpsertSideEffectAsync(

View File

@@ -22,8 +22,10 @@ public static class PolicyServiceCollectionExtensions
services.AddPolicyValidators();
services.AddPolicyRequirements();
services.AddPolicySideEffects();
services.AddPolicyUpdateEvents();
}
[Obsolete("Use AddPolicyUpdateEvents instead.")]
private static void AddPolicyValidators(this IServiceCollection services)
{
services.AddScoped<IPolicyValidator, TwoFactorAuthenticationPolicyValidator>();
@@ -34,11 +36,23 @@ public static class PolicyServiceCollectionExtensions
services.AddScoped<IPolicyValidator, FreeFamiliesForEnterprisePolicyValidator>();
}
[Obsolete("Use AddPolicyUpdateEvents instead.")]
private static void AddPolicySideEffects(this IServiceCollection services)
{
services.AddScoped<IPostSavePolicySideEffect, OrganizationDataOwnershipPolicyValidator>();
}
private static void AddPolicyUpdateEvents(this IServiceCollection services)
{
services.AddScoped<IPolicyUpdateEvent, RequireSsoPolicyValidator>();
services.AddScoped<IPolicyUpdateEvent, TwoFactorAuthenticationPolicyValidator>();
services.AddScoped<IPolicyUpdateEvent, SingleOrgPolicyValidator>();
services.AddScoped<IPolicyUpdateEvent, ResetPasswordPolicyValidator>();
services.AddScoped<IPolicyUpdateEvent, MaximumVaultTimeoutPolicyValidator>();
services.AddScoped<IPolicyUpdateEvent, FreeFamiliesForEnterprisePolicyValidator>();
services.AddScoped<IPolicyUpdateEvent, OrganizationDataOwnershipPolicyValidator>();
}
private static void AddPolicyRequirements(this IServiceCollection services)
{
services.AddScoped<IPolicyRequirementFactory<IPolicyRequirement>, DisableSendPolicyRequirementFactory>();

View File

@@ -3,6 +3,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
@@ -12,11 +13,16 @@ public class FreeFamiliesForEnterprisePolicyValidator(
IOrganizationSponsorshipRepository organizationSponsorshipRepository,
IMailService mailService,
IOrganizationRepository organizationRepository)
: IPolicyValidator
: IPolicyValidator, IOnPolicyPreUpdateEvent
{
public PolicyType Type => PolicyType.FreeFamiliesSponsorshipPolicy;
public IEnumerable<PolicyType> RequiredPolicies => [];
public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
{
await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy);
}
public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
{
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })

View File

@@ -3,10 +3,11 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator
public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator, IEnforceDependentPoliciesEvent
{
public PolicyType Type => PolicyType.MaximumVaultTimeout;
public IEnumerable<PolicyType> RequiredPolicies => [PolicyType.SingleOrg];

View File

@@ -1,24 +1,32 @@

using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Repositories;
using Bit.Core.Services;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
/// <summary>
/// Please do not extend or expand this validator. We're currently in the process of refactoring our policy validator pattern.
/// This is a stop-gap solution for post-policy-save side effects, but it is not the long-term solution.
/// </summary>
public class OrganizationDataOwnershipPolicyValidator(
IPolicyRepository policyRepository,
ICollectionRepository collectionRepository,
IEnumerable<IPolicyRequirementFactory<IPolicyRequirement>> factories,
IFeatureService featureService)
: OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect
: OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect, IOnPolicyPostUpdateEvent
{
public PolicyType Type => PolicyType.OrganizationDataOwnership;
public async Task ExecutePostUpsertSideEffectAsync(
SavePolicyModel policyRequest,
Policy postUpsertedPolicyState,
Policy? previousPolicyState)
{
await ExecuteSideEffectsAsync(policyRequest, postUpsertedPolicyState, previousPolicyState);
}
public async Task ExecuteSideEffectsAsync(
SavePolicyModel policyRequest,
Policy postUpdatedPolicy,
@@ -68,5 +76,4 @@ public class OrganizationDataOwnershipPolicyValidator(
userOrgIds,
defaultCollectionName);
}
}

View File

@@ -3,12 +3,13 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Repositories;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
public class RequireSsoPolicyValidator : IPolicyValidator
public class RequireSsoPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent
{
private readonly ISsoConfigRepository _ssoConfigRepository;
@@ -20,6 +21,11 @@ public class RequireSsoPolicyValidator : IPolicyValidator
public PolicyType Type => PolicyType.RequireSso;
public IEnumerable<PolicyType> RequiredPolicies => [PolicyType.SingleOrg];
public async Task<string> ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
{
return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy);
}
public async Task<string> ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
{
if (policyUpdate is not { Enabled: true })

View File

@@ -4,12 +4,13 @@ using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Repositories;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
public class ResetPasswordPolicyValidator : IPolicyValidator
public class ResetPasswordPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent
{
private readonly ISsoConfigRepository _ssoConfigRepository;
public PolicyType Type => PolicyType.ResetPassword;
@@ -20,6 +21,11 @@ public class ResetPasswordPolicyValidator : IPolicyValidator
_ssoConfigRepository = ssoConfigRepository;
}
public async Task<string> ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
{
return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy);
}
public async Task<string> ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
{
if (policyUpdate is not { Enabled: true } ||

View File

@@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Repositories;
using Bit.Core.Context;
@@ -17,7 +18,7 @@ using Bit.Core.Services;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
public class SingleOrgPolicyValidator : IPolicyValidator
public class SingleOrgPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IOnPolicyPreUpdateEvent
{
public PolicyType Type => PolicyType.SingleOrg;
private const string OrganizationNotFoundErrorMessage = "Organization not found.";
@@ -57,6 +58,16 @@ public class SingleOrgPolicyValidator : IPolicyValidator
public IEnumerable<PolicyType> RequiredPolicies => [];
public async Task<string> ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
{
return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy);
}
public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
{
await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy);
}
public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
{
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })

View File

@@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Context;
using Bit.Core.Enums;
@@ -16,7 +17,7 @@ using Bit.Core.Services;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator, IOnPolicyPreUpdateEvent
{
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IMailService _mailService;
@@ -46,6 +47,11 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
_revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand;
}
public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy)
{
await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy);
}
public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
{
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })

View File

@@ -87,4 +87,13 @@ public interface IOrganizationUserRepository : IRepository<OrganizationUser, Gui
Task<IEnumerable<OrganizationUserUserDetails>> GetManyDetailsByRoleAsync(Guid organizationId, OrganizationUserType role);
Task CreateManyAsync(IEnumerable<CreateOrganizationUser> organizationUserCollection);
/// <summary>
/// It will only confirm if the user is in the `Accepted` state.
///
/// This is an idempotent operation.
/// </summary>
/// <param name="organizationUser">Accepted OrganizationUser to confirm</param>
/// <returns>True, if the user was updated. False, if not performed.</returns>
Task<bool> ConfirmOrganizationUserAsync(OrganizationUser organizationUser);
}

View File

@@ -20,17 +20,6 @@ public interface IPolicyRepository : IRepository<Policy, Guid>
Task<Policy?> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type);
Task<ICollection<Policy>> GetManyByOrganizationIdAsync(Guid organizationId);
Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId);
/// <summary>
/// Gets all PolicyDetails for a user for all policy types.
/// </summary>
/// <remarks>
/// Each PolicyDetail represents an OrganizationUser and a Policy which *may* be enforced
/// against them. It only returns PolicyDetails for policies that are enabled and where the organization's plan
/// supports policies. It also excludes "revoked invited" users who are not subject to policy enforcement.
/// This is consumed by <see cref="IPolicyRequirementQuery"/> to create requirements for specific policy types.
/// You probably do not want to call it directly.
/// </remarks>
Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId);
/// <summary>
/// Retrieves <see cref="OrganizationPolicyDetails"/> of the specified <paramref name="policyType"/>

View File

@@ -61,8 +61,9 @@ public static class OrganizationFactory
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseOrganizationDomains),
UseAdminSponsoredFamilies =
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseAdminSponsoredFamilies),
UseAutomaticUserConfirmation = claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseAutomaticUserConfirmation),
UseDisableSMAdsForUsers =
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseDisableSMAdsForUsers),
claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseDisableSMAdsForUsers),
};
public static Organization Create(

View File

@@ -1,5 +1,5 @@
using System.Text.Json.Serialization;
using Bit.Core.KeyManagement.Models.Response;
using Bit.Core.KeyManagement.Models.Api.Response;
using Bit.Core.Models.Api;
namespace Bit.Core.Auth.Models.Api.Response;

View File

@@ -64,10 +64,12 @@ public static class InvoiceExtensions
}
}
var tax = invoice.TotalTaxes?.Sum(invoiceTotalTax => invoiceTotalTax.Amount) ?? 0;
// Add fallback tax from invoice-level tax if present and not already included
if (invoice.Tax.HasValue && invoice.Tax.Value > 0)
if (tax > 0)
{
var taxAmount = invoice.Tax.Value / 100m;
var taxAmount = tax / 100m;
items.Add($"1 × Tax (at ${taxAmount:F2} / month)");
}

View File

@@ -0,0 +1,25 @@
using Stripe;
namespace Bit.Core.Billing.Extensions;
public static class SubscriptionExtensions
{
/*
* For the time being, this is the simplest migration approach from v45 to v48 as
* we do not support multi-cadence subscriptions. Each subscription item should be on the
* same billing cycle. If this changes, we'll need a significantly more robust approach.
*
* Because we can't guarantee a subscription will have items, this has to be nullable.
*/
public static (DateTime? Start, DateTime? End)? GetCurrentPeriod(this Subscription subscription)
{
var item = subscription.Items?.FirstOrDefault();
return item is null ? null : (item.CurrentPeriodStart, item.CurrentPeriodEnd);
}
public static DateTime? GetCurrentPeriodStart(this Subscription subscription) =>
subscription.Items?.FirstOrDefault()?.CurrentPeriodStart;
public static DateTime? GetCurrentPeriodEnd(this Subscription subscription) =>
subscription.Items?.FirstOrDefault()?.CurrentPeriodEnd;
}

View File

@@ -1,35 +0,0 @@
using Stripe;
namespace Bit.Core.Billing.Extensions;
public static class UpcomingInvoiceOptionsExtensions
{
/// <summary>
/// Attempts to enable automatic tax for given upcoming invoice options.
/// </summary>
/// <param name="options"></param>
/// <param name="customer">The existing customer to which the upcoming invoice belongs.</param>
/// <param name="subscription">The existing subscription to which the upcoming invoice belongs.</param>
/// <returns>Returns true when successful, false when conditions are not met.</returns>
public static bool EnableAutomaticTax(
this UpcomingInvoiceOptions options,
Customer customer,
Subscription subscription)
{
if (subscription != null && subscription.AutomaticTax.Enabled)
{
return false;
}
// We might only need to check the automatic tax status.
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{
return false;
}
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
options.SubscriptionDefaultTaxRates = [];
return true;
}
}

View File

@@ -43,6 +43,7 @@ public static class OrganizationLicenseConstants
public const string Trial = nameof(Trial);
public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies);
public const string UseOrganizationDomains = nameof(UseOrganizationDomains);
public const string UseAutomaticUserConfirmation = nameof(UseAutomaticUserConfirmation);
public const string UseDisableSMAdsForUsers = nameof(UseDisableSMAdsForUsers);
}

View File

@@ -56,6 +56,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory<Organizati
new(nameof(OrganizationLicenseConstants.Trial), trial.ToString()),
new(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString()),
new(nameof(OrganizationLicenseConstants.UseOrganizationDomains), entity.UseOrganizationDomains.ToString()),
new(nameof(OrganizationLicenseConstants.UseAutomaticUserConfirmation), entity.UseAutomaticUserConfirmation.ToString()),
new(nameof(OrganizationLicenseConstants.UseDisableSMAdsForUsers), entity.UseDisableSMAdsForUsers.ToString()),
};

View File

@@ -1,6 +1,7 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.Billing.Constants;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Stripe;
@@ -46,7 +47,7 @@ public class BillingHistoryInfo
Url = inv.HostedInvoiceUrl;
PdfUrl = inv.InvoicePdf;
Number = inv.Number;
Paid = inv.Paid;
Paid = inv.Status == StripeConstants.InvoiceStatus.Paid;
Amount = inv.Total / 100M;
}

View File

@@ -43,6 +43,8 @@ public abstract record Plan
public SecretsManagerPlanFeatures SecretsManager { get; protected init; }
public bool SupportsSecretsManager => SecretsManager != null;
public bool AutomaticUserConfirmation { get; init; }
public bool HasNonSeatBasedPasswordManagerPlan() =>
PasswordManager is { StripePlanId: not null and not "", StripeSeatPlanId: null or "" };

View File

@@ -75,7 +75,13 @@ public class PreviewOrganizationTaxCommand(
Quantity = purchase.SecretsManager.Seats
}
]);
options.Coupon = CouponIDs.SecretsManagerStandalone;
options.Discounts =
[
new InvoiceDiscountOptions
{
Coupon = CouponIDs.SecretsManagerStandalone
}
];
break;
default:
@@ -180,7 +186,10 @@ public class PreviewOrganizationTaxCommand(
if (subscription.Customer.Discount != null)
{
options.Coupon = subscription.Customer.Discount.Coupon.Id;
options.Discounts =
[
new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id }
];
}
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
@@ -277,7 +286,10 @@ public class PreviewOrganizationTaxCommand(
if (subscription.Customer.Discount != null)
{
options.Coupon = subscription.Customer.Discount.Coupon.Id;
options.Discounts =
[
new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id }
];
}
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
@@ -329,7 +341,7 @@ public class PreviewOrganizationTaxCommand(
});
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
Convert.ToDecimal(invoice.Tax) / 100,
Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100,
Convert.ToDecimal(invoice.Total) / 100);
private static InvoiceCreatePreviewOptions GetBaseOptions(

View File

@@ -153,6 +153,7 @@ public class OrganizationLicense : ILicense
public LicenseType? LicenseType { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
public bool UseAutomaticUserConfirmation { get; set; }
public bool UseDisableSMAdsForUsers { get; set; }
public string Hash { get; set; }
public string Signature { get; set; }
@@ -228,6 +229,7 @@ public class OrganizationLicense : ILicense
!p.Name.Equals(nameof(UseRiskInsights)) &&
!p.Name.Equals(nameof(UseAdminSponsoredFamilies)) &&
!p.Name.Equals(nameof(UseOrganizationDomains)) &&
!p.Name.Equals(nameof(UseAutomaticUserConfirmation))) &&
!p.Name.Equals(nameof(UseDisableSMAdsForUsers)))
.OrderBy(p => p.Name)
.Select(p => $"{p.Name}:{Core.Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}")
@@ -423,6 +425,7 @@ public class OrganizationLicense : ILicense
var smServiceAccounts = claimsPrincipal.GetValue<int?>(nameof(SmServiceAccounts));
var useAdminSponsoredFamilies = claimsPrincipal.GetValue<bool>(nameof(UseAdminSponsoredFamilies));
var useOrganizationDomains = claimsPrincipal.GetValue<bool>(nameof(UseOrganizationDomains));
var useAutomaticUserConfirmation = claimsPrincipal.GetValue<bool>(nameof(UseAutomaticUserConfirmation));
var UseDisableSMAdsForUsers = claimsPrincipal.GetValue<bool>(nameof(UseDisableSMAdsForUsers));
return issued <= DateTime.UtcNow &&
@@ -454,6 +457,7 @@ public class OrganizationLicense : ILicense
smServiceAccounts == organization.SmServiceAccounts &&
useAdminSponsoredFamilies == organization.UseAdminSponsoredFamilies &&
useOrganizationDomains == organization.UseOrganizationDomains &&
useAutomaticUserConfirmation == organization.UseAutomaticUserConfirmation;
UseDisableSMAdsForUsers == organization.UseDisableSMAdsForUsers;
}

View File

@@ -9,7 +9,7 @@ namespace Bit.Core.Billing.Organizations.Models;
public class OrganizationSale
{
private OrganizationSale() { }
internal OrganizationSale() { }
public void Deconstruct(
out Organization organization,

View File

@@ -162,17 +162,23 @@ public class GetOrganizationWarningsQuery(
if (subscription is
{
Status: SubscriptionStatus.Trialing or SubscriptionStatus.Active,
LatestInvoice: null or { Status: InvoiceStatus.Paid }
} && (subscription.CurrentPeriodEnd - now).TotalDays <= 14)
LatestInvoice: null or { Status: InvoiceStatus.Paid },
Items.Data.Count: > 0
})
{
return new ResellerRenewalWarning
var currentPeriodEnd = subscription.GetCurrentPeriodEnd();
if (currentPeriodEnd != null && (currentPeriodEnd.Value - now).TotalDays <= 14)
{
Type = "upcoming",
Upcoming = new ResellerRenewalWarning.UpcomingRenewal
return new ResellerRenewalWarning
{
RenewalDate = subscription.CurrentPeriodEnd
}
};
Type = "upcoming",
Upcoming = new ResellerRenewalWarning.UpcomingRenewal
{
RenewalDate = currentPeriodEnd.Value
}
};
}
}
if (subscription is

View File

@@ -45,12 +45,12 @@ public class OrganizationBillingService(
? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType)
: await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup);
var subscription = await CreateSubscriptionAsync(organization, customer, subscriptionSetup);
var subscription = await CreateSubscriptionAsync(organization, customer, subscriptionSetup, customerSetup?.Coupon);
if (subscription.Status is StripeConstants.SubscriptionStatus.Trialing or StripeConstants.SubscriptionStatus.Active)
{
organization.Enabled = true;
organization.ExpirationDate = subscription.CurrentPeriodEnd;
organization.ExpirationDate = subscription.GetCurrentPeriodEnd();
await organizationRepository.ReplaceAsync(organization);
}
}
@@ -187,7 +187,6 @@ public class OrganizationBillingService(
var customerCreateOptions = new CustomerCreateOptions
{
Coupon = customerSetup.Coupon,
Description = organization.DisplayBusinessName(),
Email = organization.BillingEmail,
Expand = ["tax", "tax_ids"],
@@ -273,7 +272,7 @@ public class OrganizationBillingService(
customerCreateOptions.TaxIdData =
[
new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId }
new CustomerTaxIdDataOptions { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId }
];
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
@@ -381,7 +380,8 @@ public class OrganizationBillingService(
private async Task<Subscription> CreateSubscriptionAsync(
Organization organization,
Customer customer,
SubscriptionSetup subscriptionSetup)
SubscriptionSetup subscriptionSetup,
string? coupon)
{
var plan = await pricingClient.GetPlanOrThrow(subscriptionSetup.PlanType);
@@ -444,6 +444,7 @@ public class OrganizationBillingService(
{
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id,
Discounts = !string.IsNullOrEmpty(coupon) ? [new SubscriptionDiscountOptions { Coupon = coupon }] : null,
Items = subscriptionItemOptionsList,
Metadata = new Dictionary<string, string>
{
@@ -459,8 +460,9 @@ public class OrganizationBillingService(
var hasPaymentMethod = await hasPaymentMethodQuery.Run(organization);
// Only set trial_settings.end_behavior.missing_payment_method to "cancel" if there is no payment method
if (!hasPaymentMethod)
// Only set trial_settings.end_behavior.missing_payment_method to "cancel"
// if there is no payment method AND there's an actual trial period
if (!hasPaymentMethod && subscriptionCreateOptions.TrialPeriodDays > 0)
{
subscriptionCreateOptions.TrialSettings = new SubscriptionTrialSettingsOptions
{

View File

@@ -1,6 +1,7 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
@@ -87,7 +88,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand(
when subscription.Status == StripeConstants.SubscriptionStatus.Active:
{
user.Premium = true;
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd();
break;
}
}

View File

@@ -60,6 +60,6 @@ public class PreviewPremiumTaxCommand(
});
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
Convert.ToDecimal(invoice.Tax) / 100,
Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100,
Convert.ToDecimal(invoice.Total) / 100);
}

View File

@@ -1,26 +0,0 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
namespace Bit.Core.Billing.Providers.Migration.Models;
public enum ClientMigrationProgress
{
Started = 1,
MigrationRecordCreated = 2,
SubscriptionEnded = 3,
Completed = 4,
Reversing = 5,
ResetOrganization = 6,
RecreatedSubscription = 7,
RemovedMigrationRecord = 8,
Reversed = 9
}
public class ClientMigrationTracker
{
public Guid ProviderId { get; set; }
public Guid OrganizationId { get; set; }
public string OrganizationName { get; set; }
public ClientMigrationProgress Progress { get; set; } = ClientMigrationProgress.Started;
}

View File

@@ -1,48 +0,0 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.Billing.Providers.Entities;
namespace Bit.Core.Billing.Providers.Migration.Models;
public class ProviderMigrationResult
{
public Guid ProviderId { get; set; }
public string ProviderName { get; set; }
public string Result { get; set; }
public List<ClientMigrationResult> Clients { get; set; }
}
public class ClientMigrationResult
{
public Guid OrganizationId { get; set; }
public string OrganizationName { get; set; }
public string Result { get; set; }
public ClientPreviousState PreviousState { get; set; }
}
public class ClientPreviousState
{
public ClientPreviousState() { }
public ClientPreviousState(ClientOrganizationMigrationRecord migrationRecord)
{
PlanType = migrationRecord.PlanType.ToString();
Seats = migrationRecord.Seats;
MaxStorageGb = migrationRecord.MaxStorageGb;
GatewayCustomerId = migrationRecord.GatewayCustomerId;
GatewaySubscriptionId = migrationRecord.GatewaySubscriptionId;
ExpirationDate = migrationRecord.ExpirationDate;
MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats;
Status = migrationRecord.Status.ToString();
}
public string PlanType { get; set; }
public int Seats { get; set; }
public short? MaxStorageGb { get; set; }
public string GatewayCustomerId { get; set; } = null!;
public string GatewaySubscriptionId { get; set; } = null!;
public DateTime? ExpirationDate { get; set; }
public int? MaxAutoscaleSeats { get; set; }
public string Status { get; set; }
}

View File

@@ -1,25 +0,0 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
namespace Bit.Core.Billing.Providers.Migration.Models;
public enum ProviderMigrationProgress
{
Started = 1,
NoClients = 2,
ClientsMigrated = 3,
TeamsPlanConfigured = 4,
EnterprisePlanConfigured = 5,
CustomerSetup = 6,
SubscriptionSetup = 7,
CreditApplied = 8,
Completed = 9,
}
public class ProviderMigrationTracker
{
public Guid ProviderId { get; set; }
public string ProviderName { get; set; }
public List<Guid> OrganizationIds { get; set; }
public ProviderMigrationProgress Progress { get; set; } = ProviderMigrationProgress.Started;
}

View File

@@ -1,15 +0,0 @@
using Bit.Core.Billing.Providers.Migration.Services;
using Bit.Core.Billing.Providers.Migration.Services.Implementations;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Billing.Providers.Migration;
public static class ServiceCollectionExtensions
{
public static void AddProviderMigration(this IServiceCollection services)
{
services.AddTransient<IMigrationTrackerCache, MigrationTrackerDistributedCache>();
services.AddTransient<IOrganizationMigrator, OrganizationMigrator>();
services.AddTransient<IProviderMigrator, ProviderMigrator>();
}
}

View File

@@ -1,17 +0,0 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Providers.Migration.Models;
namespace Bit.Core.Billing.Providers.Migration.Services;
public interface IMigrationTrackerCache
{
Task StartTracker(Provider provider);
Task SetOrganizationIds(Guid providerId, IEnumerable<Guid> organizationIds);
Task<ProviderMigrationTracker> GetTracker(Guid providerId);
Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status);
Task StartTracker(Guid providerId, Organization organization);
Task<ClientMigrationTracker> GetTracker(Guid providerId, Guid organizationId);
Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status);
}

View File

@@ -1,8 +0,0 @@
using Bit.Core.AdminConsole.Entities;
namespace Bit.Core.Billing.Providers.Migration.Services;
public interface IOrganizationMigrator
{
Task Migrate(Guid providerId, Organization organization);
}

View File

@@ -1,10 +0,0 @@
using Bit.Core.Billing.Providers.Migration.Models;
namespace Bit.Core.Billing.Providers.Migration.Services;
public interface IProviderMigrator
{
Task Migrate(Guid providerId);
Task<ProviderMigrationResult> GetResult(Guid providerId);
}

View File

@@ -1,110 +0,0 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using System.Text.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Providers.Migration.Models;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Billing.Providers.Migration.Services.Implementations;
public class MigrationTrackerDistributedCache(
[FromKeyedServices("persistent")]
IDistributedCache distributedCache) : IMigrationTrackerCache
{
public async Task StartTracker(Provider provider) =>
await SetAsync(new ProviderMigrationTracker
{
ProviderId = provider.Id,
ProviderName = provider.Name
});
public async Task SetOrganizationIds(Guid providerId, IEnumerable<Guid> organizationIds)
{
var tracker = await GetAsync(providerId);
tracker.OrganizationIds = organizationIds.ToList();
await SetAsync(tracker);
}
public Task<ProviderMigrationTracker> GetTracker(Guid providerId) => GetAsync(providerId);
public async Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status)
{
var tracker = await GetAsync(providerId);
tracker.Progress = status;
await SetAsync(tracker);
}
public async Task StartTracker(Guid providerId, Organization organization) =>
await SetAsync(new ClientMigrationTracker
{
ProviderId = providerId,
OrganizationId = organization.Id,
OrganizationName = organization.Name
});
public Task<ClientMigrationTracker> GetTracker(Guid providerId, Guid organizationId) =>
GetAsync(providerId, organizationId);
public async Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status)
{
var tracker = await GetAsync(providerId, organizationId);
tracker.Progress = status;
await SetAsync(tracker);
}
private static string GetProviderCacheKey(Guid providerId) => $"provider_{providerId}_migration";
private static string GetClientCacheKey(Guid providerId, Guid clientId) =>
$"provider_{providerId}_client_{clientId}_migration";
private async Task<ProviderMigrationTracker> GetAsync(Guid providerId)
{
var cacheKey = GetProviderCacheKey(providerId);
var json = await distributedCache.GetStringAsync(cacheKey);
return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize<ProviderMigrationTracker>(json);
}
private async Task<ClientMigrationTracker> GetAsync(Guid providerId, Guid organizationId)
{
var cacheKey = GetClientCacheKey(providerId, organizationId);
var json = await distributedCache.GetStringAsync(cacheKey);
return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize<ClientMigrationTracker>(json);
}
private async Task SetAsync(ProviderMigrationTracker tracker)
{
var cacheKey = GetProviderCacheKey(tracker.ProviderId);
var json = JsonSerializer.Serialize(tracker);
await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions
{
SlidingExpiration = TimeSpan.FromMinutes(30)
});
}
private async Task SetAsync(ClientMigrationTracker tracker)
{
var cacheKey = GetClientCacheKey(tracker.ProviderId, tracker.OrganizationId);
var json = JsonSerializer.Serialize(tracker);
await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions
{
SlidingExpiration = TimeSpan.FromMinutes(30)
});
}
}

View File

@@ -1,331 +0,0 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Migration.Models;
using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
using Plan = Bit.Core.Models.StaticStore.Plan;
namespace Bit.Core.Billing.Providers.Migration.Services.Implementations;
public class OrganizationMigrator(
IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository,
ILogger<OrganizationMigrator> logger,
IMigrationTrackerCache migrationTrackerCache,
IOrganizationRepository organizationRepository,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter) : IOrganizationMigrator
{
private const string _cancellationComment = "Cancelled as part of provider migration to Consolidated Billing";
public async Task Migrate(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Starting migration for organization ({OrganizationID})", organization.Id);
await migrationTrackerCache.StartTracker(providerId, organization);
await CreateMigrationRecordAsync(providerId, organization);
await CancelSubscriptionAsync(providerId, organization);
await UpdateOrganizationAsync(providerId, organization);
}
#region Steps
private async Task CreateMigrationRecordAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Creating ClientOrganizationMigrationRecord for organization ({OrganizationID})", organization.Id);
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
if (migrationRecord != null)
{
logger.LogInformation(
"CB: ClientOrganizationMigrationRecord already exists for organization ({OrganizationID}), deleting record",
organization.Id);
await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord);
}
await clientOrganizationMigrationRecordRepository.CreateAsync(new ClientOrganizationMigrationRecord
{
OrganizationId = organization.Id,
ProviderId = providerId,
PlanType = organization.PlanType,
Seats = organization.Seats ?? 0,
MaxStorageGb = organization.MaxStorageGb,
GatewayCustomerId = organization.GatewayCustomerId!,
GatewaySubscriptionId = organization.GatewaySubscriptionId!,
ExpirationDate = organization.ExpirationDate,
MaxAutoscaleSeats = organization.MaxAutoscaleSeats,
Status = organization.Status
});
logger.LogInformation("CB: Created migration record for organization ({OrganizationID})", organization.Id);
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.MigrationRecordCreated);
}
private async Task CancelSubscriptionAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Cancelling subscription for organization ({OrganizationID})", organization.Id);
var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId);
if (subscription is
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.PastDue or
StripeConstants.SubscriptionStatus.Trialing
})
{
await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
new SubscriptionUpdateOptions { CancelAtPeriodEnd = false });
subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId,
new SubscriptionCancelOptions
{
CancellationDetails = new SubscriptionCancellationDetailsOptions
{
Comment = _cancellationComment
},
InvoiceNow = true,
Prorate = true,
Expand = ["latest_invoice", "test_clock"]
});
logger.LogInformation("CB: Cancelled subscription for organization ({OrganizationID})", organization.Id);
var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow;
var trialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now;
if (!trialing && subscription is { Status: StripeConstants.SubscriptionStatus.Canceled, CancellationDetails.Comment: _cancellationComment })
{
var latestInvoice = subscription.LatestInvoice;
if (latestInvoice.Status == "draft")
{
await stripeAdapter.InvoiceFinalizeInvoiceAsync(latestInvoice.Id,
new InvoiceFinalizeOptions { AutoAdvance = true });
logger.LogInformation("CB: Finalized prorated invoice for organization ({OrganizationID})", organization.Id);
}
}
}
else
{
logger.LogInformation(
"CB: Did not need to cancel subscription for organization ({OrganizationID}) as it was inactive",
organization.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.SubscriptionEnded);
}
private async Task UpdateOrganizationAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Bringing organization ({OrganizationID}) under provider management",
organization.Id);
var plan = await pricingClient.GetPlanOrThrow(organization.Plan.Contains("Teams") ? PlanType.TeamsMonthly : PlanType.EnterpriseMonthly);
ResetOrganizationPlan(organization, plan);
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.GatewaySubscriptionId = null;
organization.ExpirationDate = null;
organization.MaxAutoscaleSeats = null;
organization.Status = OrganizationStatusType.Managed;
await organizationRepository.ReplaceAsync(organization);
logger.LogInformation("CB: Brought organization ({OrganizationID}) under provider management",
organization.Id);
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.Completed);
}
#endregion
#region Reverse
private async Task RemoveMigrationRecordAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Removing migration record for organization ({OrganizationID})", organization.Id);
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
if (migrationRecord != null)
{
await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord);
logger.LogInformation(
"CB: Removed migration record for organization ({OrganizationID})",
organization.Id);
}
else
{
logger.LogInformation("CB: Did not remove migration record for organization ({OrganizationID}) as it does not exist", organization.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, ClientMigrationProgress.Reversed);
}
private async Task RecreateSubscriptionAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Recreating subscription for organization ({OrganizationID})", organization.Id);
if (!string.IsNullOrEmpty(organization.GatewaySubscriptionId))
{
if (string.IsNullOrEmpty(organization.GatewayCustomerId))
{
logger.LogError(
"CB: Cannot recreate subscription for organization ({OrganizationID}) as it does not have a Stripe customer",
organization.Id);
throw new Exception();
}
var customer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId,
new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] });
var collectionMethod =
customer.DefaultSource != null ||
customer.InvoiceSettings?.DefaultPaymentMethod != null ||
customer.Metadata.ContainsKey(Utilities.BraintreeCustomerIdKey)
? StripeConstants.CollectionMethod.ChargeAutomatically
: StripeConstants.CollectionMethod.SendInvoice;
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var items = new List<SubscriptionItemOptions>
{
new ()
{
Price = plan.PasswordManager.StripeSeatPlanId,
Quantity = organization.Seats
}
};
if (organization.MaxStorageGb.HasValue && plan.PasswordManager.BaseStorageGb.HasValue && organization.MaxStorageGb.Value > plan.PasswordManager.BaseStorageGb.Value)
{
var additionalStorage = organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb.Value;
items.Add(new SubscriptionItemOptions
{
Price = plan.PasswordManager.StripeStoragePlanId,
Quantity = additionalStorage
});
}
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
Customer = customer.Id,
CollectionMethod = collectionMethod,
DaysUntilDue = collectionMethod == StripeConstants.CollectionMethod.SendInvoice ? 30 : null,
Items = items,
Metadata = new Dictionary<string, string>
{
[organization.GatewayIdField()] = organization.Id.ToString()
},
OffSession = true,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
TrialPeriodDays = plan.TrialPeriodDays
};
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
organization.GatewaySubscriptionId = subscription.Id;
await organizationRepository.ReplaceAsync(organization);
logger.LogInformation("CB: Recreated subscription for organization ({OrganizationID})", organization.Id);
}
else
{
logger.LogInformation(
"CB: Did not recreate subscription for organization ({OrganizationID}) as it already exists",
organization.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.RecreatedSubscription);
}
private async Task ReverseOrganizationUpdateAsync(Guid providerId, Organization organization)
{
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
if (migrationRecord == null)
{
logger.LogError(
"CB: Cannot reverse migration for organization ({OrganizationID}) as it does not have a migration record",
organization.Id);
throw new Exception();
}
var plan = await pricingClient.GetPlanOrThrow(migrationRecord.PlanType);
ResetOrganizationPlan(organization, plan);
organization.MaxStorageGb = migrationRecord.MaxStorageGb;
organization.ExpirationDate = migrationRecord.ExpirationDate;
organization.MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats;
organization.Status = migrationRecord.Status;
await organizationRepository.ReplaceAsync(organization);
logger.LogInformation("CB: Reversed organization ({OrganizationID}) updates",
organization.Id);
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.ResetOrganization);
}
#endregion
#region Shared
private static void ResetOrganizationPlan(Organization organization, Plan plan)
{
organization.Plan = plan.Name;
organization.PlanType = plan.Type;
organization.MaxCollections = plan.PasswordManager.MaxCollections;
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso;
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory;
organization.UseTotp = plan.HasTotp;
organization.Use2fa = plan.Has2fa;
organization.UseApi = plan.HasApi;
organization.UseResetPassword = plan.HasResetPassword;
organization.SelfHost = plan.HasSelfHost;
organization.UsersGetPremium = plan.UsersGetPremium;
organization.UseCustomPermissions = plan.HasCustomPermissions;
organization.UseScim = plan.HasScim;
organization.UseKeyConnector = plan.HasKeyConnector;
}
#endregion
}

View File

@@ -1,436 +0,0 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Migration.Models;
using Bit.Core.Billing.Providers.Models;
using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Providers.Migration.Services.Implementations;
public class ProviderMigrator(
IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository,
IOrganizationMigrator organizationMigrator,
ILogger<ProviderMigrator> logger,
IMigrationTrackerCache migrationTrackerCache,
IOrganizationRepository organizationRepository,
IPaymentService paymentService,
IProviderBillingService providerBillingService,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderRepository providerRepository,
IProviderPlanRepository providerPlanRepository,
IStripeAdapter stripeAdapter) : IProviderMigrator
{
public async Task Migrate(Guid providerId)
{
var provider = await GetProviderAsync(providerId);
if (provider == null)
{
return;
}
logger.LogInformation("CB: Starting migration for provider ({ProviderID})", providerId);
await migrationTrackerCache.StartTracker(provider);
var organizations = await GetClientsAsync(provider.Id);
if (organizations.Count == 0)
{
logger.LogInformation("CB: Skipping migration for provider ({ProviderID}) with no clients", providerId);
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.NoClients);
return;
}
await MigrateClientsAsync(providerId, organizations);
await ConfigureTeamsPlanAsync(providerId);
await ConfigureEnterprisePlanAsync(providerId);
await SetupCustomerAsync(provider);
await SetupSubscriptionAsync(provider);
await ApplyCreditAsync(provider);
await UpdateProviderAsync(provider);
}
public async Task<ProviderMigrationResult> GetResult(Guid providerId)
{
var providerTracker = await migrationTrackerCache.GetTracker(providerId);
if (providerTracker == null)
{
return null;
}
if (providerTracker.Progress == ProviderMigrationProgress.NoClients)
{
return new ProviderMigrationResult
{
ProviderId = providerTracker.ProviderId,
ProviderName = providerTracker.ProviderName,
Result = providerTracker.Progress.ToString()
};
}
var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId =>
migrationTrackerCache.GetTracker(providerId, organizationId)));
var migrationRecordLookup = new Dictionary<Guid, ClientOrganizationMigrationRecord>();
foreach (var clientTracker in clientTrackers)
{
var migrationRecord =
await clientOrganizationMigrationRecordRepository.GetByOrganizationId(clientTracker.OrganizationId);
migrationRecordLookup.Add(clientTracker.OrganizationId, migrationRecord);
}
return new ProviderMigrationResult
{
ProviderId = providerTracker.ProviderId,
ProviderName = providerTracker.ProviderName,
Result = providerTracker.Progress.ToString(),
Clients = clientTrackers.Select(tracker =>
{
var foundMigrationRecord = migrationRecordLookup.TryGetValue(tracker.OrganizationId, out var migrationRecord);
return new ClientMigrationResult
{
OrganizationId = tracker.OrganizationId,
OrganizationName = tracker.OrganizationName,
Result = tracker.Progress.ToString(),
PreviousState = foundMigrationRecord ? new ClientPreviousState(migrationRecord) : null
};
}).ToList(),
};
}
#region Steps
private async Task MigrateClientsAsync(Guid providerId, List<Organization> organizations)
{
logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId);
var organizationIds = organizations.Select(organization => organization.Id);
await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds);
foreach (var organization in organizations)
{
var tracker = await migrationTrackerCache.GetTracker(providerId, organization.Id);
if (tracker is not { Progress: ClientMigrationProgress.Completed })
{
await organizationMigrator.Migrate(providerId, organization);
}
}
logger.LogInformation("CB: Migrated clients for provider ({ProviderID})", providerId);
await migrationTrackerCache.UpdateTrackingStatus(providerId,
ProviderMigrationProgress.ClientsMigrated);
}
private async Task ConfigureTeamsPlanAsync(Guid providerId)
{
logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId);
var organizations = await GetClientsAsync(providerId);
var teamsSeats = organizations
.Where(IsTeams)
.Sum(client => client.Seats) ?? 0;
var teamsProviderPlan = (await providerPlanRepository.GetByProviderId(providerId))
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly);
if (teamsProviderPlan == null)
{
await providerPlanRepository.CreateAsync(new ProviderPlan
{
ProviderId = providerId,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = teamsSeats,
PurchasedSeats = 0,
AllocatedSeats = teamsSeats
});
logger.LogInformation("CB: Created Teams plan for provider ({ProviderID}) with a seat minimum of {Seats}",
providerId, teamsSeats);
}
else
{
logger.LogInformation("CB: Teams plan already exists for provider ({ProviderID}), updating seat minimum", providerId);
teamsProviderPlan.SeatMinimum = teamsSeats;
teamsProviderPlan.AllocatedSeats = teamsSeats;
await providerPlanRepository.ReplaceAsync(teamsProviderPlan);
logger.LogInformation("CB: Updated Teams plan for provider ({ProviderID}) to seat minimum of {Seats}",
providerId, teamsProviderPlan.SeatMinimum);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.TeamsPlanConfigured);
}
private async Task ConfigureEnterprisePlanAsync(Guid providerId)
{
logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId);
var organizations = await GetClientsAsync(providerId);
var enterpriseSeats = organizations
.Where(IsEnterprise)
.Sum(client => client.Seats) ?? 0;
var enterpriseProviderPlan = (await providerPlanRepository.GetByProviderId(providerId))
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly);
if (enterpriseProviderPlan == null)
{
await providerPlanRepository.CreateAsync(new ProviderPlan
{
ProviderId = providerId,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = enterpriseSeats,
PurchasedSeats = 0,
AllocatedSeats = enterpriseSeats
});
logger.LogInformation("CB: Created Enterprise plan for provider ({ProviderID}) with a seat minimum of {Seats}",
providerId, enterpriseSeats);
}
else
{
logger.LogInformation("CB: Enterprise plan already exists for provider ({ProviderID}), updating seat minimum", providerId);
enterpriseProviderPlan.SeatMinimum = enterpriseSeats;
enterpriseProviderPlan.AllocatedSeats = enterpriseSeats;
await providerPlanRepository.ReplaceAsync(enterpriseProviderPlan);
logger.LogInformation("CB: Updated Enterprise plan for provider ({ProviderID}) to seat minimum of {Seats}",
providerId, enterpriseProviderPlan.SeatMinimum);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.EnterprisePlanConfigured);
}
private async Task SetupCustomerAsync(Provider provider)
{
if (string.IsNullOrEmpty(provider.GatewayCustomerId))
{
var organizations = await GetClientsAsync(provider.Id);
var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId));
if (sampleOrganization == null)
{
logger.LogInformation(
"CB: Could not find sample organization for provider ({ProviderID}) that has a Stripe customer",
provider.Id);
return;
}
var taxInfo = await paymentService.GetTaxInfoAsync(sampleOrganization);
// Create dummy payment source for legacy migration - this migrator is deprecated and will be removed
var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token");
var customer = await providerBillingService.SetupCustomer(provider, null, null);
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Coupon = StripeConstants.CouponIDs.LegacyMSPDiscount
});
provider.GatewayCustomerId = customer.Id;
await providerRepository.ReplaceAsync(provider);
logger.LogInformation("CB: Setup Stripe customer for provider ({ProviderID})", provider.Id);
}
else
{
logger.LogInformation("CB: Stripe customer already exists for provider ({ProviderID})", provider.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CustomerSetup);
}
private async Task SetupSubscriptionAsync(Provider provider)
{
if (string.IsNullOrEmpty(provider.GatewaySubscriptionId))
{
if (!string.IsNullOrEmpty(provider.GatewayCustomerId))
{
var subscription = await providerBillingService.SetupSubscription(provider);
provider.GatewaySubscriptionId = subscription.Id;
await providerRepository.ReplaceAsync(provider);
logger.LogInformation("CB: Setup Stripe subscription for provider ({ProviderID})", provider.Id);
}
else
{
logger.LogInformation(
"CB: Could not set up Stripe subscription for provider ({ProviderID}) with no Stripe customer",
provider.Id);
return;
}
}
else
{
logger.LogInformation("CB: Stripe subscription already exists for provider ({ProviderID})", provider.Id);
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
var enterpriseSeatMinimum = providerPlans
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly)?
.SeatMinimum ?? 0;
var teamsSeatMinimum = providerPlans
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)?
.SeatMinimum ?? 0;
var updateSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand(
provider,
[
(Plan: PlanType.EnterpriseMonthly, SeatsMinimum: enterpriseSeatMinimum),
(Plan: PlanType.TeamsMonthly, SeatsMinimum: teamsSeatMinimum)
]);
await providerBillingService.UpdateSeatMinimums(updateSeatMinimumsCommand);
logger.LogInformation(
"CB: Updated Stripe subscription for provider ({ProviderID}) with current seat minimums", provider.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.SubscriptionSetup);
}
private async Task ApplyCreditAsync(Provider provider)
{
var organizations = await GetClientsAsync(provider.Id);
var organizationCustomers =
await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId)));
var organizationCancellationCredit = organizationCustomers.Sum(customer => customer.Balance);
if (organizationCancellationCredit != 0)
{
await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId,
new CustomerBalanceTransactionCreateOptions
{
Amount = organizationCancellationCredit,
Currency = "USD",
Description = "Unused, prorated time for client organization subscriptions."
});
}
var migrationRecords = await Task.WhenAll(organizations.Select(organization =>
clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id)));
var legacyOrganizationMigrationRecords = migrationRecords.Where(migrationRecord =>
migrationRecord.PlanType is
PlanType.EnterpriseAnnually2020 or
PlanType.TeamsAnnually2020);
var legacyOrganizationCredit = legacyOrganizationMigrationRecords.Sum(migrationRecord => migrationRecord.Seats) * 12 * -100;
if (legacyOrganizationCredit < 0)
{
await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId,
new CustomerBalanceTransactionCreateOptions
{
Amount = legacyOrganizationCredit,
Currency = "USD",
Description = "1 year rebate for legacy client organizations."
});
}
logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit + legacyOrganizationCredit, provider.Id);
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CreditApplied);
}
private async Task UpdateProviderAsync(Provider provider)
{
provider.Status = ProviderStatusType.Billable;
await providerRepository.ReplaceAsync(provider);
logger.LogInformation("CB: Completed migration for provider ({ProviderID})", provider.Id);
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.Completed);
}
#endregion
#region Utilities
private async Task<List<Organization>> GetClientsAsync(Guid providerId)
{
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId);
return (await Task.WhenAll(providerOrganizations.Select(providerOrganization =>
organizationRepository.GetByIdAsync(providerOrganization.OrganizationId))))
.ToList();
}
private async Task<Provider> GetProviderAsync(Guid providerId)
{
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it does not exist", providerId);
return null;
}
if (provider.Type != ProviderType.Msp)
{
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not an MSP", providerId);
return null;
}
if (provider.Status == ProviderStatusType.Created)
{
return provider;
}
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not in the 'Created' state", providerId);
return null;
}
private static bool IsEnterprise(Organization organization) => organization.Plan.Contains("Enterprise");
private static bool IsTeams(Organization organization) => organization.Plan.Contains("Teams");
#endregion
}

View File

@@ -3,6 +3,7 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
@@ -108,7 +109,7 @@ public class PremiumUserBillingService(
when subscription.Status == StripeConstants.SubscriptionStatus.Active:
{
user.Premium = true;
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd();
break;
}
}

View File

@@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Repositories;
@@ -65,7 +66,7 @@ public class RestartSubscriptionCommand(
{
organization.GatewaySubscriptionId = subscription.Id;
organization.Enabled = true;
organization.ExpirationDate = subscription.CurrentPeriodEnd;
organization.ExpirationDate = subscription.GetCurrentPeriodEnd();
organization.RevisionDate = DateTime.UtcNow;
await organizationRepository.ReplaceAsync(organization);
break;
@@ -82,7 +83,7 @@ public class RestartSubscriptionCommand(
{
user.GatewaySubscriptionId = subscription.Id;
user.Premium = true;
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd();
user.RevisionDate = DateTime.UtcNow;
await userRepository.ReplaceAsync(user);
break;

View File

@@ -140,6 +140,7 @@ public static class FeatureFlagKeys
public const string EventBasedOrganizationIntegrations = "event-based-organization-integrations";
public const string SeparateCustomRolePermissions = "pm-19917-separate-custom-role-permissions";
public const string CreateDefaultLocation = "pm-19467-create-default-location";
public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users";
public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache";
/* Auth Team */
@@ -160,6 +161,7 @@ public static class FeatureFlagKeys
public const string InlineMenuFieldQualification = "inline-menu-field-qualification";
public const string InlineMenuPositioningImprovements = "inline-menu-positioning-improvements";
public const string SSHAgent = "ssh-agent";
public const string SSHAgentV2 = "ssh-agent-v2";
public const string SSHVersionCheckQAOverride = "ssh-version-check-qa-override";
public const string GenerateIdentityFillScriptRefactor = "generate-identity-fill-script-refactor";
public const string DelayFido2PageScriptInitWithinMv2 = "delay-fido2-page-script-init-within-mv2";
@@ -192,6 +194,7 @@ public static class FeatureFlagKeys
public const string UserkeyRotationV2 = "userkey-rotation-v2";
public const string SSHKeyItemVaultItem = "ssh-key-vault-item";
public const string UserSdkForDecryption = "use-sdk-for-decryption";
public const string EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation";
public const string PM17987_BlockType0 = "pm-17987-block-type-0";
public const string ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings";
public const string UnlockWithMasterPasswordUnlockData = "pm-23246-unlock-with-master-password-unlock-data";
@@ -199,24 +202,18 @@ public static class FeatureFlagKeys
public const string LinuxBiometricsV2 = "pm-26340-linux-biometrics-v2";
public const string NoLogoutOnKdfChange = "pm-23995-no-logout-on-kdf-change";
public const string DisableType0Decryption = "pm-25174-disable-type-0-decryption";
public const string ConsolidatedSessionTimeoutComponent = "pm-26056-consolidated-session-timeout-component";
/* Mobile Team */
public const string NativeCarouselFlow = "native-carousel-flow";
public const string NativeCreateAccountFlow = "native-create-account-flow";
public const string AndroidImportLoginsFlow = "import-logins-flow";
public const string AppReviewPrompt = "app-review-prompt";
public const string AndroidMutualTls = "mutual-tls";
public const string SingleTapPasskeyCreation = "single-tap-passkey-creation";
public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication";
public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync";
public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias";
public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias";
public const string EnablePMFlightRecorder = "enable-pm-flight-recorder";
public const string MobileErrorReporting = "mobile-error-reporting";
public const string AndroidChromeAutofill = "android-chrome-autofill";
public const string UserManagedPrivilegedApps = "pm-18970-user-managed-privileged-apps";
public const string EnablePMPreloginSettings = "enable-pm-prelogin-settings";
public const string AppIntents = "app-intents";
public const string SendAccess = "pm-19394-send-access-control";
public const string CxpImportMobile = "cxp-import-mobile";
public const string CxpExportMobile = "cxp-export-mobile";
@@ -229,6 +226,7 @@ public static class FeatureFlagKeys
/* Tools Team */
public const string DesktopSendUIRefresh = "desktop-send-ui-refresh";
public const string UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators";
public const string ChromiumImporterWithABE = "pm-25855-chromium-importer-abe";
/* Vault Team */
public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge";
@@ -247,6 +245,7 @@ public static class FeatureFlagKeys
/* DIRT Team */
public const string PM22887_RiskInsightsActivityTab = "pm-22887-risk-insights-activity-tab";
public const string EventManagementForDataDogAndCrowdStrike = "event-management-for-datadog-and-crowdstrike";
public static List<string> GetAllKeys()
{

View File

@@ -57,7 +57,7 @@
<PackageReference Include="Serilog.Sinks.SyslogMessages" Version="4.0.0" />
<PackageReference Include="AspNetCoreRateLimit" Version="5.0.0" />
<PackageReference Include="Braintree" Version="5.28.0" />
<PackageReference Include="Stripe.net" Version="45.14.0" />
<PackageReference Include="Stripe.net" Version="48.5.0" />
<PackageReference Include="Otp.NET" Version="1.4.0" />
<PackageReference Include="YubicoDotNetClient" Version="1.2.0" />
<PackageReference Include="Microsoft.Extensions.Caching.StackExchangeRedis" Version="8.0.10" />

View File

@@ -3,6 +3,7 @@ using System.Text.Json;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Enums;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Identity;
@@ -21,6 +22,9 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
[MaxLength(256)]
public string Email { get; set; } = null!;
public bool EmailVerified { get; set; }
/// <summary>
/// The server-side master-password hash
/// </summary>
[MaxLength(300)]
public string? MasterPassword { get; set; }
[MaxLength(50)]
@@ -41,9 +45,30 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
/// organization membership.
/// </summary>
public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow;
/// <summary>
/// The master-password-sealed user key.
/// </summary>
public string? Key { get; set; }
/// <summary>
/// The raw public key, without a signature from the user's signature key.
/// </summary>
public string? PublicKey { get; set; }
/// <summary>
/// User key wrapped private key.
/// </summary>
public string? PrivateKey { get; set; }
/// <summary>
/// The public key, signed by the user's signature key.
/// </summary>
public string? SignedPublicKey { get; set; }
/// <summary>
/// The security version is included in the security state, but needs COSE parsing
/// </summary>
public int? SecurityVersion { get; set; }
/// <summary>
/// The security state is a signed object attesting to the version of the user's account.
/// </summary>
public string? SecurityState { get; set; }
public bool Premium { get; set; }
public DateTime? PremiumExpirationDate { get; set; }
public DateTime? RenewalReminderDate { get; set; }
@@ -180,6 +205,12 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
return Premium;
}
public int GetSecurityVersion()
{
// If no security version is set, it is version 1. The minimum initialized version is 2.
return SecurityVersion ?? 1;
}
/// <summary>
/// Serializes the C# object to the User.TwoFactorProviders property in JSON format.
/// </summary>
@@ -243,4 +274,14 @@ public class User : ITableObject<Guid>, IStorableSubscriber, IRevisable, ITwoFac
{
return MasterPassword != null;
}
public PublicKeyEncryptionKeyPairData GetPublicKeyEncryptionKeyPair()
{
if (string.IsNullOrWhiteSpace(PrivateKey) || string.IsNullOrWhiteSpace(PublicKey))
{
throw new InvalidOperationException("User public key encryption key pair is not fully initialized.");
}
return new PublicKeyEncryptionKeyPairData(PrivateKey, PublicKey, SignedPublicKey);
}
}

View File

@@ -0,0 +1,6 @@
namespace Bit.Core.Enums;
public enum PushNotificationLogOutReason : byte
{
KdfChange = 0
}

View File

@@ -107,7 +107,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable
throw new Exception("Job failed to start after 10 retries.");
}
_logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}");
_logger.LogWarning(e, "Exception while trying to schedule job: {JobName}", job.FullName);
var random = new Random();
await Task.Delay(random.Next(50, 250));
}
@@ -125,7 +125,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable
continue;
}
_logger.LogInformation($"Deleting old job with key {key}");
_logger.LogInformation("Deleting old job with key {Key}", key);
await _scheduler.DeleteJob(key);
}
@@ -138,7 +138,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable
continue;
}
_logger.LogInformation($"Unscheduling old trigger with key {key}");
_logger.LogInformation("Unscheduling old trigger with key {Key}", key);
await _scheduler.UnscheduleJob(key);
}
}

View File

@@ -0,0 +1,30 @@
using Bit.Core.Entities;
using Bit.Core.KeyManagement.Enums;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Utilities;
namespace Bit.Core.KeyManagement.Entities;
public class UserSignatureKeyPair : ITableObject<Guid>, IRevisable
{
public Guid Id { get; set; }
public Guid UserId { get; set; }
public SignatureAlgorithm SignatureAlgorithm { get; set; }
public required string VerifyingKey { get; set; }
public required string SigningKey { get; set; }
public DateTime CreationDate { get; set; } = DateTime.UtcNow;
public DateTime RevisionDate { get; set; } = DateTime.UtcNow;
public void SetNewId()
{
Id = CoreHelpers.GenerateComb();
}
public SignatureKeyPairData ToSignatureKeyPairData()
{
return new SignatureKeyPairData(SignatureAlgorithm, SigningKey, VerifyingKey);
}
}

View File

@@ -0,0 +1,9 @@
namespace Bit.Core.KeyManagement.Enums;
// <summary>
// Represents the algorithm / digital signature scheme used for a signature key pair.
// </summary>
public enum SignatureAlgorithm : byte
{
Ed25519 = 0
}

View File

@@ -1,4 +1,5 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Platform.Push;
@@ -18,17 +19,22 @@ public class ChangeKdfCommand : IChangeKdfCommand
private readonly IUserRepository _userRepository;
private readonly IdentityErrorDescriber _identityErrorDescriber;
private readonly ILogger<ChangeKdfCommand> _logger;
private readonly IFeatureService _featureService;
public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService, IUserRepository userRepository, IdentityErrorDescriber describer, ILogger<ChangeKdfCommand> logger)
public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService,
IUserRepository userRepository, IdentityErrorDescriber describer, ILogger<ChangeKdfCommand> logger,
IFeatureService featureService)
{
_userService = userService;
_pushService = pushService;
_userRepository = userRepository;
_identityErrorDescriber = describer;
_logger = logger;
_featureService = featureService;
}
public async Task<IdentityResult> ChangeKdfAsync(User user, string masterPasswordAuthenticationHash, MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData)
public async Task<IdentityResult> ChangeKdfAsync(User user, string masterPasswordAuthenticationHash,
MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData)
{
ArgumentNullException.ThrowIfNull(user);
if (!await _userService.CheckPasswordAsync(user, masterPasswordAuthenticationHash))
@@ -37,8 +43,8 @@ public class ChangeKdfCommand : IChangeKdfCommand
}
// Validate to prevent user account from becoming un-decryptable from invalid parameters
//
// Prevent a de-synced salt value from creating an un-decryptable unlock method
//
// Prevent a de-synced salt value from creating an un-decryptable unlock method
authenticationData.ValidateSaltUnchangedForUser(user);
unlockData.ValidateSaltUnchangedForUser(user);
@@ -47,12 +53,15 @@ public class ChangeKdfCommand : IChangeKdfCommand
{
throw new BadRequestException("KDF settings must be equal for authentication and unlock.");
}
var validationErrors = KdfSettingsValidator.Validate(unlockData.Kdf);
if (validationErrors.Any())
{
throw new BadRequestException("KDF settings are invalid.");
}
var logoutOnKdfChange = !_featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange);
// Update the user with the new KDF settings
// This updates the authentication data and unlock data for the user separately. Currently these still
// use shared values for KDF settings and salt.
@@ -68,7 +77,8 @@ public class ChangeKdfCommand : IChangeKdfCommand
// This entire operation MUST be atomic to prevent a user from being locked out of their account.
// Salt is ensured to be the same as unlock data, and the value stored in the account and not updated.
// KDF is ensured to be the same as unlock data above and updated below.
var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash);
var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash,
refreshStamp: logoutOnKdfChange);
if (!result.Succeeded)
{
_logger.LogWarning("Change KDF failed for user {userId}.", user.Id);
@@ -88,7 +98,17 @@ public class ChangeKdfCommand : IChangeKdfCommand
user.LastKdfChangeDate = now;
await _userRepository.ReplaceAsync(user);
await _pushService.PushLogOutAsync(user.Id);
if (logoutOnKdfChange)
{
await _pushService.PushLogOutAsync(user.Id);
}
else
{
// Clients that support the new feature flag will ignore the logout when it matches the reason and the feature flag is enabled.
await _pushService.PushLogOutAsync(user.Id, reason: PushNotificationLogOutReason.KdfChange);
await _pushService.PushSyncSettingsAsync(user.Id);
}
return IdentityResult.Success;
}
}

View File

@@ -2,6 +2,8 @@
using Bit.Core.KeyManagement.Commands.Interfaces;
using Bit.Core.KeyManagement.Kdf;
using Bit.Core.KeyManagement.Kdf.Implementations;
using Bit.Core.KeyManagement.Queries;
using Bit.Core.KeyManagement.Queries.Interfaces;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.KeyManagement;
@@ -11,6 +13,7 @@ public static class KeyManagementServiceCollectionExtensions
public static void AddKeyManagementServices(this IServiceCollection services)
{
services.AddKeyManagementCommands();
services.AddKeyManagementQueries();
services.AddSendPasswordServices();
}
@@ -19,4 +22,9 @@ public static class KeyManagementServiceCollectionExtensions
services.AddScoped<IRegenerateUserAsymmetricKeysCommand, RegenerateUserAsymmetricKeysCommand>();
services.AddScoped<IChangeKdfCommand, ChangeKdfCommand>();
}
private static void AddKeyManagementQueries(this IServiceCollection services)
{
services.AddScoped<IUserAccountKeysQuery, UserAccountKeysQuery>();
}
}

View File

@@ -0,0 +1,32 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Core.KeyManagement.Models.Data;
namespace Bit.Core.KeyManagement.Models.Api.Request;
public class SecurityStateModel
{
[StringLength(1000)]
[JsonPropertyName("securityState")]
public required string SecurityState { get; set; }
[JsonPropertyName("securityVersion")]
public required int SecurityVersion { get; set; }
public SecurityStateData ToSecurityState()
{
return new SecurityStateData
{
SecurityState = SecurityState,
SecurityVersion = SecurityVersion
};
}
public static SecurityStateModel FromSecurityStateData(SecurityStateData data)
{
return new SecurityStateModel
{
SecurityState = data.SecurityState,
SecurityVersion = data.SecurityVersion
};
}
}

View File

@@ -2,7 +2,7 @@
using Bit.Core.Enums;
using Bit.Core.Utilities;
namespace Bit.Core.KeyManagement.Models.Response;
namespace Bit.Core.KeyManagement.Models.Api.Response;
public class MasterPasswordUnlockResponseModel
{

View File

@@ -0,0 +1,48 @@
using System.Text.Json.Serialization;
using Bit.Core.KeyManagement.Models.Api.Request;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Models.Api;
namespace Bit.Core.KeyManagement.Models.Api.Response;
/// <summary>
/// This response model is used to return the asymmetric encryption keys,
/// and signature keys of an entity. This includes the private keys of the key pairs,
/// (private key, signing key), and the public keys of the key pairs (unsigned public key,
/// signed public key, verification key).
/// </summary>
public class PrivateKeysResponseModel : ResponseModel
{
// Not all accounts have signature keys, but all accounts have public encryption keys.
[JsonPropertyName("signatureKeyPair")]
public SignatureKeyPairResponseModel? SignatureKeyPair { get; set; }
[JsonPropertyName("publicKeyEncryptionKeyPair")]
public required PublicKeyEncryptionKeyPairResponseModel PublicKeyEncryptionKeyPair { get; set; }
[JsonPropertyName("securityState")]
public SecurityStateModel? SecurityState { get; set; }
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
public PrivateKeysResponseModel(UserAccountKeysData accountKeys) : base("privateKeys")
{
ArgumentNullException.ThrowIfNull(accountKeys);
PublicKeyEncryptionKeyPair = new PublicKeyEncryptionKeyPairResponseModel(accountKeys.PublicKeyEncryptionKeyPairData);
if (accountKeys.SignatureKeyPairData != null && accountKeys.SecurityStateData != null)
{
SignatureKeyPair = new SignatureKeyPairResponseModel(accountKeys.SignatureKeyPairData);
SecurityState = SecurityStateModel.FromSecurityStateData(accountKeys.SecurityStateData!);
}
}
[JsonConstructor]
public PrivateKeysResponseModel(SignatureKeyPairResponseModel? signatureKeyPair, PublicKeyEncryptionKeyPairResponseModel publicKeyEncryptionKeyPair, SecurityStateModel? securityState)
: base("privateKeys")
{
SignatureKeyPair = signatureKeyPair;
PublicKeyEncryptionKeyPair = publicKeyEncryptionKeyPair ?? throw new ArgumentNullException(nameof(publicKeyEncryptionKeyPair));
SecurityState = securityState;
}
}

View File

@@ -0,0 +1,34 @@
using System.Text.Json.Serialization;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Models.Api;
namespace Bit.Core.KeyManagement.Models.Api.Response;
public class PublicKeyEncryptionKeyPairResponseModel : ResponseModel
{
[JsonPropertyName("wrappedPrivateKey")]
public required string WrappedPrivateKey { get; set; }
[JsonPropertyName("publicKey")]
public required string PublicKey { get; set; }
[JsonPropertyName("signedPublicKey")]
public string? SignedPublicKey { get; set; }
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
public PublicKeyEncryptionKeyPairResponseModel(PublicKeyEncryptionKeyPairData keyPair)
: base("publicKeyEncryptionKeyPair")
{
WrappedPrivateKey = keyPair.WrappedPrivateKey;
PublicKey = keyPair.PublicKey;
SignedPublicKey = keyPair.SignedPublicKey;
}
[JsonConstructor]
public PublicKeyEncryptionKeyPairResponseModel(string wrappedPrivateKey, string publicKey, string? signedPublicKey)
: base("publicKeyEncryptionKeyPair")
{
WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey));
PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey));
SignedPublicKey = signedPublicKey;
}
}

View File

@@ -0,0 +1,30 @@
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Models.Api;
namespace Bit.Core.KeyManagement.Models.Api.Response;
/// <summary>
/// This response model is used to return the public keys of a user, to any other registered user or entity on the server.
/// It can contain public keys (signature/encryption), and proofs between the two. It does not contain (encrypted) private keys.
/// </summary>
public class PublicKeysResponseModel : ResponseModel
{
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
public PublicKeysResponseModel(UserAccountKeysData accountKeys)
: base("publicKeys")
{
ArgumentNullException.ThrowIfNull(accountKeys);
PublicKey = accountKeys.PublicKeyEncryptionKeyPairData.PublicKey;
if (accountKeys.SignatureKeyPairData != null)
{
SignedPublicKey = accountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey;
VerifyingKey = accountKeys.SignatureKeyPairData.VerifyingKey;
}
}
public string? VerifyingKey { get; set; }
public string? SignedPublicKey { get; set; }
public required string PublicKey { get; set; }
}

View File

@@ -0,0 +1,32 @@
using System.Text.Json.Serialization;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Models.Api;
namespace Bit.Core.KeyManagement.Models.Api.Response;
public class SignatureKeyPairResponseModel : ResponseModel
{
[JsonPropertyName("wrappedSigningKey")]
public required string WrappedSigningKey { get; set; }
[JsonPropertyName("verifyingKey")]
public required string VerifyingKey { get; set; }
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
public SignatureKeyPairResponseModel(SignatureKeyPairData signatureKeyPair)
: base("signatureKeyPair")
{
ArgumentNullException.ThrowIfNull(signatureKeyPair);
WrappedSigningKey = signatureKeyPair.WrappedSigningKey;
VerifyingKey = signatureKeyPair.VerifyingKey;
}
[JsonConstructor]
public SignatureKeyPairResponseModel(string wrappedSigningKey, string verifyingKey)
: base("signatureKeyPair")
{
WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey));
VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey));
}
}

View File

@@ -1,4 +1,4 @@
namespace Bit.Core.KeyManagement.Models.Response;
namespace Bit.Core.KeyManagement.Models.Api.Response;
public class UserDecryptionResponseModel
{

View File

@@ -1,4 +1,5 @@
using Bit.Core.Entities;
using Bit.Core.Exceptions;
namespace Bit.Core.KeyManagement.Models.Data;
@@ -12,7 +13,7 @@ public class MasterPasswordAuthenticationData
{
if (user.GetMasterPasswordSalt() != Salt)
{
throw new ArgumentException("Invalid master password salt.");
throw new BadRequestException("Invalid master password salt.");
}
}
}

View File

@@ -1,6 +1,5 @@
#nullable enable
using Bit.Core.Entities;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
namespace Bit.Core.KeyManagement.Models.Data;
@@ -14,7 +13,7 @@ public class MasterPasswordUnlockData
{
if (user.GetMasterPasswordSalt() != Salt)
{
throw new ArgumentException("Invalid master password salt.");
throw new BadRequestException("Invalid master password salt.");
}
}
}

View File

@@ -0,0 +1,20 @@
using System.Text.Json.Serialization;
namespace Bit.Core.KeyManagement.Models.Data;
public class PublicKeyEncryptionKeyPairData
{
public required string WrappedPrivateKey { get; set; }
public string? SignedPublicKey { get; set; }
public required string PublicKey { get; set; }
[JsonConstructor]
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
public PublicKeyEncryptionKeyPairData(string wrappedPrivateKey, string publicKey, string? signedPublicKey = null)
{
WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey));
PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey));
SignedPublicKey = signedPublicKey;
}
}

View File

@@ -1,6 +1,4 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable

using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities;
@@ -12,21 +10,19 @@ namespace Bit.Core.KeyManagement.Models.Data;
public class RotateUserAccountKeysData
{
// Authentication for this requests
public string OldMasterKeyAuthenticationHash { get; set; }
public required string OldMasterKeyAuthenticationHash { get; set; }
// Other keys encrypted by the userkey
public string UserKeyEncryptedAccountPrivateKey { get; set; }
public string AccountPublicKey { get; set; }
public required UserAccountKeysData AccountKeys { get; set; }
// All methods to get to the userkey
public MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; }
public IEnumerable<EmergencyAccess> EmergencyAccesses { get; set; }
public IReadOnlyList<OrganizationUser> OrganizationUsers { get; set; }
public IEnumerable<WebAuthnLoginRotateKeyData> WebAuthnKeys { get; set; }
public IEnumerable<Device> DeviceKeys { get; set; }
public required MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; }
public required IEnumerable<EmergencyAccess> EmergencyAccesses { get; set; }
public required IReadOnlyList<OrganizationUser> OrganizationUsers { get; set; }
public required IEnumerable<WebAuthnLoginRotateKeyData> WebAuthnKeys { get; set; }
public required IEnumerable<Device> DeviceKeys { get; set; }
// User vault data encrypted by the userkey
public IEnumerable<Cipher> Ciphers { get; set; }
public IEnumerable<Folder> Folders { get; set; }
public IReadOnlyList<Send> Sends { get; set; }
public required IEnumerable<Cipher> Ciphers { get; set; }
public required IEnumerable<Folder> Folders { get; set; }
public required IReadOnlyList<Send> Sends { get; set; }
}

View File

@@ -0,0 +1,10 @@

namespace Bit.Core.KeyManagement.Models.Data;
public class SecurityStateData
{
public required string SecurityState { get; set; }
// The security version is included in the security state, but needs COSE parsing,
// so this is a separate copy that can be used directly.
public required int SecurityVersion { get; set; }
}

View File

@@ -0,0 +1,21 @@

using System.Text.Json.Serialization;
using Bit.Core.KeyManagement.Enums;
namespace Bit.Core.KeyManagement.Models.Data;
public class SignatureKeyPairData
{
public required SignatureAlgorithm SignatureAlgorithm { get; set; }
public required string WrappedSigningKey { get; set; }
public required string VerifyingKey { get; set; }
[JsonConstructor]
[System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute]
public SignatureKeyPairData(SignatureAlgorithm signatureAlgorithm, string wrappedSigningKey, string verifyingKey)
{
SignatureAlgorithm = signatureAlgorithm;
WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey));
VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey));
}
}

View File

@@ -0,0 +1,9 @@
namespace Bit.Core.KeyManagement.Models.Data;
public class UserAccountKeysData
{
public required PublicKeyEncryptionKeyPairData PublicKeyEncryptionKeyPairData { get; set; }
public SignatureKeyPairData? SignatureKeyPairData { get; set; }
public SecurityStateData? SecurityStateData { get; set; }
}

View File

@@ -0,0 +1,10 @@

using Bit.Core.Entities;
using Bit.Core.KeyManagement.Models.Data;
namespace Bit.Core.KeyManagement.Queries.Interfaces;
public interface IUserAccountKeysQuery
{
Task<UserAccountKeysData> Run(User user);
}

View File

@@ -0,0 +1,35 @@

using Bit.Core.Entities;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.Queries.Interfaces;
using Bit.Core.KeyManagement.Repositories;
namespace Bit.Core.KeyManagement.Queries;
public class UserAccountKeysQuery(IUserSignatureKeyPairRepository signatureKeyPairRepository) : IUserAccountKeysQuery
{
public async Task<UserAccountKeysData> Run(User user)
{
if (user.GetSecurityVersion() < 2)
{
return new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(),
};
}
else
{
return new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(),
SignatureKeyPairData = await signatureKeyPairRepository.GetByUserIdAsync(user.Id),
SecurityStateData = new SecurityStateData
{
SecurityState = user.SecurityState!,
SecurityVersion = user.GetSecurityVersion(),
}
};
}
}
}

View File

@@ -0,0 +1,14 @@

using Bit.Core.KeyManagement.Entities;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Repositories;
namespace Bit.Core.KeyManagement.Repositories;
public interface IUserSignatureKeyPairRepository : IRepository<UserSignatureKeyPair, Guid>
{
public Task<SignatureKeyPairData?> GetByUserIdAsync(Guid userId);
public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signatureKeyPair);
public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signatureKeyPair);
}

View File

@@ -1,6 +1,11 @@
using Bit.Core.Auth.Repositories;
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.Auth.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
@@ -25,6 +30,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
private readonly IdentityErrorDescriber _identityErrorDescriber;
private readonly IWebAuthnCredentialRepository _credentialRepository;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IUserSignatureKeyPairRepository _userSignatureKeyPairRepository;
private readonly IFeatureService _featureService;
/// <summary>
/// Instantiates a new <see cref="RotateUserAccountKeysCommand"/>
@@ -36,16 +43,19 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
/// <param name="sendRepository">Provides a method to update re-encrypted send data</param>
/// <param name="emergencyAccessRepository">Provides a method to update re-encrypted emergency access data</param>
/// <param name="organizationUserRepository">Provides a method to update re-encrypted organization user data</param>
/// <param name="deviceRepository">Provides a method to update re-encrypted device keys</param>
/// <param name="passwordHasher">Hashes the new master password</param>
/// <param name="pushService">Logs out user from other devices after successful rotation</param>
/// <param name="errors">Provides a password mismatch error if master password hash validation fails</param>
/// <param name="credentialRepository">Provides a method to update re-encrypted WebAuthn keys</param>
/// <param name="userSignatureKeyPairRepository">Provides a method to update re-encrypted signature keys</param>
public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository,
ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository,
IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository,
IDeviceRepository deviceRepository,
IPasswordHasher<User> passwordHasher,
IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository,
IUserSignatureKeyPairRepository userSignatureKeyPairRepository,
IFeatureService featureService)
{
_userService = userService;
@@ -60,6 +70,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
_identityErrorDescriber = errors;
_credentialRepository = credentialRepository;
_passwordHasher = passwordHasher;
_userSignatureKeyPairRepository = userSignatureKeyPairRepository;
_featureService = featureService;
}
/// <inheritdoc />
@@ -80,50 +92,106 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
user.LastKeyRotationDate = now;
user.SecurityStamp = Guid.NewGuid().ToString();
if (
!model.MasterPasswordUnlockData.ValidateForUser(user)
)
List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions = [];
await UpdateAccountKeysAsync(model, user, saveEncryptedDataActions);
UpdateUnlockMethods(model, user, saveEncryptedDataActions);
UpdateUserData(model, user, saveEncryptedDataActions);
await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions);
await _pushService.PushLogOutAsync(user.Id);
return IdentityResult.Success;
}
public async Task RotateV2AccountKeysAsync(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
{
ValidateV2Encryption(model);
await ValidateVerifyingKeyUnchangedAsync(model, user);
saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.UpdateForKeyRotation(user.Id, model.AccountKeys.SignatureKeyPairData));
user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey;
user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState;
user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion;
}
public void UpgradeV1ToV2Keys(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
{
ValidateV2Encryption(model);
saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.SetUserSignatureKeyPair(user.Id, model.AccountKeys.SignatureKeyPairData));
user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey;
user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState;
user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion;
}
public async Task UpdateAccountKeysAsync(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
{
ValidatePublicKeyEncryptionKeyPairUnchanged(model, user);
if (IsV2EncryptionUserAsync(user))
{
throw new InvalidOperationException("The provided master password unlock data is not valid for this user.");
await RotateV2AccountKeysAsync(model, user, saveEncryptedDataActions);
}
if (
model.AccountPublicKey != user.PublicKey
)
else if (model.AccountKeys.SignatureKeyPairData != null)
{
throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric keypair is currently not supported during key rotation.");
UpgradeV1ToV2Keys(model, user, saveEncryptedDataActions);
}
else
{
if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.AesCbc256_HmacSha256_B64)
{
throw new InvalidOperationException("The provided account private key was not wrapped with AES-256-CBC-HMAC");
}
// V1 user to V1 user rotation needs to further changes, the private key was re-encrypted.
}
user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey;
user.PrivateKey = model.UserKeyEncryptedAccountPrivateKey;
user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash);
user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint;
// Private key is re-wrapped with new user key by client
user.PrivateKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey;
}
public void UpdateUserData(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
{
// The revision date has to be updated so that de-synced clients don't accidentally post over the re-encrypted data
// with an old-user key-encrypted copy
var now = DateTime.UtcNow;
List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions = new();
if (model.Ciphers.Any())
{
saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers));
var ciphersWithUpdatedDate = model.Ciphers.ToList().Select(c => { c.RevisionDate = now; return c; });
saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, ciphersWithUpdatedDate));
}
if (model.Folders.Any())
{
saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, model.Folders));
var foldersWithUpdatedDate = model.Folders.ToList().Select(f => { f.RevisionDate = now; return f; });
saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, foldersWithUpdatedDate));
}
if (model.Sends.Any())
{
saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, model.Sends));
var sendsWithUpdatedDate = model.Sends.ToList().Select(s => { s.RevisionDate = now; return s; });
saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, sendsWithUpdatedDate));
}
}
void UpdateUnlockMethods(RotateUserAccountKeysData model, User user, List<UpdateEncryptedDataForKeyRotation> saveEncryptedDataActions)
{
if (!model.MasterPasswordUnlockData.ValidateForUser(user))
{
throw new InvalidOperationException("The provided master password unlock data is not valid for this user.");
}
// Update master password authentication & unlock
user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey;
user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash);
user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint;
if (model.EmergencyAccesses.Any())
{
saveEncryptedDataActions.Add(
_emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses));
saveEncryptedDataActions.Add(_emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses));
}
if (model.OrganizationUsers.Any())
{
saveEncryptedDataActions.Add(
_organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers));
saveEncryptedDataActions.Add(_organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers));
}
if (model.WebAuthnKeys.Any())
@@ -135,9 +203,80 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
{
saveEncryptedDataActions.Add(_deviceRepository.UpdateKeysForRotationAsync(user.Id, model.DeviceKeys));
}
}
await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions);
await _pushService.PushLogOutAsync(user.Id);
return IdentityResult.Success;
private bool IsV2EncryptionUserAsync(User user)
{
// Returns whether the user is a V2 user based on the private key's encryption type.
ArgumentNullException.ThrowIfNull(user);
var isPrivateKeyEncryptionV2 = GetEncryptionType(user.PrivateKey) == EncryptionType.XChaCha20Poly1305_B64;
return isPrivateKeyEncryptionV2;
}
private async Task ValidateVerifyingKeyUnchangedAsync(RotateUserAccountKeysData model, User user)
{
var currentSignatureKeyPair = await _userSignatureKeyPairRepository.GetByUserIdAsync(user.Id) ?? throw new InvalidOperationException("User does not have a signature key pair.");
if (model.AccountKeys.SignatureKeyPairData.VerifyingKey != currentSignatureKeyPair!.VerifyingKey)
{
throw new InvalidOperationException("The provided verifying key does not match the user's current verifying key.");
}
}
private static void ValidatePublicKeyEncryptionKeyPairUnchanged(RotateUserAccountKeysData model, User user)
{
var publicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.PublicKey;
if (publicKey != user.PublicKey)
{
throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric key pair is currently not supported during key rotation.");
}
}
private static void ValidateV2Encryption(RotateUserAccountKeysData model)
{
if (model.AccountKeys.SignatureKeyPairData == null)
{
throw new InvalidOperationException("Signature key pair data is required for V2 encryption.");
}
if (GetEncryptionType(model.AccountKeys.SignatureKeyPairData.WrappedSigningKey) != EncryptionType.XChaCha20Poly1305_B64)
{
throw new InvalidOperationException("The provided signing key data is not wrapped with XChaCha20-Poly1305.");
}
if (string.IsNullOrEmpty(model.AccountKeys.SignatureKeyPairData.VerifyingKey))
{
throw new InvalidOperationException("The provided signature key pair data does not contain a valid verifying key.");
}
if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.XChaCha20Poly1305_B64)
{
throw new InvalidOperationException("The provided private key encryption key is not wrapped with XChaCha20-Poly1305.");
}
if (string.IsNullOrEmpty(model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey))
{
throw new InvalidOperationException("No signed public key provided, but the user already has a signature key pair.");
}
if (model.AccountKeys.SecurityStateData == null || string.IsNullOrEmpty(model.AccountKeys.SecurityStateData.SecurityState))
{
throw new InvalidOperationException("No signed security state provider for V2 user");
}
}
/// <summary>
/// Helper method to convert an encryption type string to an enum value.
/// </summary>
private static EncryptionType GetEncryptionType(string encString)
{
var parts = encString.Split('.');
if (parts.Length == 1)
{
throw new ArgumentException("Invalid encryption type string.");
}
if (byte.TryParse(parts[0], out var encryptionTypeNumber))
{
if (Enum.IsDefined(typeof(EncryptionType), encryptionTypeNumber))
{
return (EncryptionType)encryptionTypeNumber;
}
}
throw new ArgumentException("Invalid encryption type string.");
}
}

View File

@@ -1,6 +1,7 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.Billing.Extensions;
using Stripe;
namespace Bit.Core.Models.Business;
@@ -36,8 +37,13 @@ public class SubscriptionInfo
Status = sub.Status;
TrialStartDate = sub.TrialStart;
TrialEndDate = sub.TrialEnd;
PeriodStartDate = sub.CurrentPeriodStart;
PeriodEndDate = sub.CurrentPeriodEnd;
var currentPeriod = sub.GetCurrentPeriod();
if (currentPeriod != null)
{
var (start, end) = currentPeriod.Value;
PeriodStartDate = start;
PeriodEndDate = end;
}
CancelledDate = sub.CanceledAt;
CancelAtEndDate = sub.CancelAtPeriodEnd;
Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired";

View File

@@ -97,3 +97,9 @@ public class ProviderBankAccountVerifiedPushNotification
public Guid ProviderId { get; set; }
public Guid AdminId { get; set; }
}
public class LogOutPushNotification
{
public Guid UserId { get; set; }
public PushNotificationLogOutReason? Reason { get; set; }
}

View File

@@ -1,51 +0,0 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
namespace Bit.Core.Models.BitStripe;
// Stripe's SubscriptionListOptions model has a complex input for date filters.
// It expects a dictionary, and has lots of validation rules around what can have a value and what can't.
// To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT.
// ___
// Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model.
public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions
{
public DateTime? CurrentPeriodEndDate { get; set; }
public string CurrentPeriodEndRange { get; set; } = "lt";
public bool SelectAll { get; set; }
public new Stripe.DateRangeOptions CurrentPeriodEnd
{
get
{
return CurrentPeriodEndDate.HasValue ?
new Stripe.DateRangeOptions()
{
LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null,
GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null
} :
null;
}
}
public Stripe.SubscriptionListOptions ToStripeApiOptions()
{
var stripeApiOptions = (Stripe.SubscriptionListOptions)this;
if (SelectAll)
{
stripeApiOptions.EndingBefore = null;
stripeApiOptions.StartingAfter = null;
}
if (CurrentPeriodEndDate.HasValue)
{
stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions()
{
LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null,
GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null
};
}
return stripeApiOptions;
}
}

View File

@@ -62,7 +62,7 @@ public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISel
.ToDictionary(i => i.SponsoringOrganizationUserId);
if (!organizationSponsorshipsDict.Any())
{
_logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}");
_logger.LogInformation("No existing sponsorships to sync for organization {organizationId}", organizationId);
return;
}
var syncedSponsorships = new List<OrganizationSponsorshipData>();

View File

@@ -167,18 +167,17 @@ public interface IPushNotificationService
ExcludeCurrentContext = false,
});
Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false)
=> PushAsync(new PushNotification<UserPushNotification>
Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false,
PushNotificationLogOutReason? reason = null)
=> PushAsync(new PushNotification<LogOutPushNotification>
{
Type = PushType.LogOut,
Target = NotificationTarget.User,
TargetId = userId,
Payload = new UserPushNotification
Payload = new LogOutPushNotification
{
UserId = userId,
#pragma warning disable BWP0001 // Type or member is obsolete
Date = TimeProvider.GetUtcNow().UtcDateTime,
#pragma warning restore BWP0001 // Type or member is obsolete
Reason = reason
},
ExcludeCurrentContext = excludeCurrentContextFromPush,
});

View File

@@ -55,7 +55,7 @@ public enum PushType : byte
[NotificationInfo("not-specified", typeof(Models.UserPushNotification))]
SyncSettings = 10,
[NotificationInfo("not-specified", typeof(Models.UserPushNotification))]
[NotificationInfo("not-specified", typeof(Models.LogOutPushNotification))]
LogOut = 11,
[NotificationInfo("@bitwarden/team-tools-dev", typeof(Models.SyncSendPushNotification))]

View File

@@ -0,0 +1,28 @@
#nullable enable
using Bit.Core.Entities;
using Bit.Core.Utilities;
namespace Bit.Core.SecretsManager.Entities;
public class SecretVersion : ITableObject<Guid>
{
public Guid Id { get; set; }
public Guid SecretId { get; set; }
public string Value { get; set; } = string.Empty;
public DateTime VersionDate { get; set; }
public Guid? EditorServiceAccountId { get; set; }
public Guid? EditorOrganizationUserId { get; set; }
public void SetNewId()
{
if (Id == default(Guid))
{
Id = CoreHelpers.GenerateComb();
}
}
}

View File

@@ -3,58 +3,47 @@
using Bit.Core.Models.BitStripe;
using Stripe;
using Stripe.Tax;
namespace Bit.Core.Services;
public interface IStripeAdapter
{
Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions);
Task<Stripe.Customer> CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null);
Task<Stripe.Customer> CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null);
Task<Stripe.Customer> CustomerDeleteAsync(string id);
Task<List<PaymentMethod>> CustomerListPaymentMethods(string id, CustomerListPaymentMethodsOptions options = null);
Task<Customer> CustomerCreateAsync(CustomerCreateOptions customerCreateOptions);
Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null);
Task<Customer> CustomerGetAsync(string id, CustomerGetOptions options = null);
Task<Customer> CustomerUpdateAsync(string id, CustomerUpdateOptions options = null);
Task<Customer> CustomerDeleteAsync(string id);
Task<List<PaymentMethod>> CustomerListPaymentMethods(string id, CustomerPaymentMethodListOptions options = null);
Task<CustomerBalanceTransaction> CustomerBalanceTransactionCreate(string customerId,
CustomerBalanceTransactionCreateOptions options);
Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions);
Task<Stripe.Subscription> SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null);
/// <summary>
/// Retrieves a subscription object for a provider.
/// </summary>
/// <param name="id">The subscription ID.</param>
/// <param name="providerId">The provider ID.</param>
/// <param name="options">Additional options.</param>
/// <returns>The subscription object.</returns>
/// <exception cref="InvalidOperationException">Thrown when the subscription doesn't belong to the provider.</exception>
Task<Stripe.Subscription> ProviderSubscriptionGetAsync(string id, Guid providerId, Stripe.SubscriptionGetOptions options = null);
Task<List<Stripe.Subscription>> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions);
Task<Stripe.Subscription> SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null);
Task<Stripe.Subscription> SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null);
Task<Stripe.Invoice> InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options);
Task<Stripe.Invoice> InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options);
Task<List<Stripe.Invoice>> InvoiceListAsync(StripeInvoiceListOptions options);
Task<Stripe.Invoice> InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options);
Task<List<Stripe.Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options);
Task<Stripe.Invoice> InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options);
Task<Stripe.Invoice> InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options);
Task<Stripe.Invoice> InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options);
Task<Stripe.Invoice> InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null);
Task<Stripe.Invoice> InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null);
Task<Stripe.Invoice> InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null);
IEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options);
IAsyncEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options);
Task<Stripe.PaymentMethod> PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null);
Task<Stripe.PaymentMethod> PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null);
Task<Stripe.TaxId> TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options);
Task<Stripe.TaxId> TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null);
Task<Stripe.StripeList<Stripe.Tax.Registration>> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null);
Task<Stripe.StripeList<Stripe.Charge>> ChargeListAsync(Stripe.ChargeListOptions options);
Task<Stripe.Refund> RefundCreateAsync(Stripe.RefundCreateOptions options);
Task<Stripe.Card> CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null);
Task<Stripe.BankAccount> BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null);
Task<Stripe.BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null);
Task<Stripe.StripeList<Stripe.Price>> PriceListAsync(Stripe.PriceListOptions options = null);
Task<Subscription> SubscriptionCreateAsync(SubscriptionCreateOptions subscriptionCreateOptions);
Task<Subscription> SubscriptionGetAsync(string id, SubscriptionGetOptions options = null);
Task<Subscription> SubscriptionUpdateAsync(string id, SubscriptionUpdateOptions options = null);
Task<Subscription> SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null);
Task<Invoice> InvoiceGetAsync(string id, InvoiceGetOptions options);
Task<List<Invoice>> InvoiceListAsync(StripeInvoiceListOptions options);
Task<Invoice> InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options);
Task<List<Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options);
Task<Invoice> InvoiceUpdateAsync(string id, InvoiceUpdateOptions options);
Task<Invoice> InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options);
Task<Invoice> InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options);
Task<Invoice> InvoicePayAsync(string id, InvoicePayOptions options = null);
Task<Invoice> InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null);
Task<Invoice> InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null);
IEnumerable<PaymentMethod> PaymentMethodListAutoPaging(PaymentMethodListOptions options);
IAsyncEnumerable<PaymentMethod> PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options);
Task<PaymentMethod> PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null);
Task<PaymentMethod> PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null);
Task<TaxId> TaxIdCreateAsync(string id, TaxIdCreateOptions options);
Task<TaxId> TaxIdDeleteAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null);
Task<StripeList<Registration>> TaxRegistrationsListAsync(RegistrationListOptions options = null);
Task<StripeList<Charge>> ChargeListAsync(ChargeListOptions options);
Task<Refund> RefundCreateAsync(RefundCreateOptions options);
Task<Card> CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null);
Task<BankAccount> BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null);
Task<BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null);
Task<StripeList<Price>> PriceListAsync(PriceListOptions options = null);
Task<SetupIntent> SetupIntentCreate(SetupIntentCreateOptions options);
Task<List<SetupIntent>> SetupIntentList(SetupIntentListOptions options);
Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null);

View File

@@ -9,18 +9,18 @@ namespace Bit.Core.Services;
public class StripeAdapter : IStripeAdapter
{
private readonly Stripe.CustomerService _customerService;
private readonly Stripe.SubscriptionService _subscriptionService;
private readonly Stripe.InvoiceService _invoiceService;
private readonly Stripe.PaymentMethodService _paymentMethodService;
private readonly Stripe.TaxIdService _taxIdService;
private readonly Stripe.ChargeService _chargeService;
private readonly Stripe.RefundService _refundService;
private readonly Stripe.CardService _cardService;
private readonly Stripe.BankAccountService _bankAccountService;
private readonly Stripe.PlanService _planService;
private readonly Stripe.PriceService _priceService;
private readonly Stripe.SetupIntentService _setupIntentService;
private readonly CustomerService _customerService;
private readonly SubscriptionService _subscriptionService;
private readonly InvoiceService _invoiceService;
private readonly PaymentMethodService _paymentMethodService;
private readonly TaxIdService _taxIdService;
private readonly ChargeService _chargeService;
private readonly RefundService _refundService;
private readonly CardService _cardService;
private readonly BankAccountService _bankAccountService;
private readonly PlanService _planService;
private readonly PriceService _priceService;
private readonly SetupIntentService _setupIntentService;
private readonly Stripe.TestHelpers.TestClockService _testClockService;
private readonly CustomerBalanceTransactionService _customerBalanceTransactionService;
private readonly Stripe.Tax.RegistrationService _taxRegistrationService;
@@ -28,17 +28,17 @@ public class StripeAdapter : IStripeAdapter
public StripeAdapter()
{
_customerService = new Stripe.CustomerService();
_subscriptionService = new Stripe.SubscriptionService();
_invoiceService = new Stripe.InvoiceService();
_paymentMethodService = new Stripe.PaymentMethodService();
_taxIdService = new Stripe.TaxIdService();
_chargeService = new Stripe.ChargeService();
_refundService = new Stripe.RefundService();
_cardService = new Stripe.CardService();
_bankAccountService = new Stripe.BankAccountService();
_priceService = new Stripe.PriceService();
_planService = new Stripe.PlanService();
_customerService = new CustomerService();
_subscriptionService = new SubscriptionService();
_invoiceService = new InvoiceService();
_paymentMethodService = new PaymentMethodService();
_taxIdService = new TaxIdService();
_chargeService = new ChargeService();
_refundService = new RefundService();
_cardService = new CardService();
_bankAccountService = new BankAccountService();
_priceService = new PriceService();
_planService = new PlanService();
_setupIntentService = new SetupIntentService();
_testClockService = new Stripe.TestHelpers.TestClockService();
_customerBalanceTransactionService = new CustomerBalanceTransactionService();
@@ -46,28 +46,31 @@ public class StripeAdapter : IStripeAdapter
_calculationService = new CalculationService();
}
public Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions options)
public Task<Customer> CustomerCreateAsync(CustomerCreateOptions options)
{
return _customerService.CreateAsync(options);
}
public Task<Stripe.Customer> CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null)
public Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) =>
_customerService.DeleteDiscountAsync(customerId, options);
public Task<Customer> CustomerGetAsync(string id, CustomerGetOptions options = null)
{
return _customerService.GetAsync(id, options);
}
public Task<Stripe.Customer> CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null)
public Task<Customer> CustomerUpdateAsync(string id, CustomerUpdateOptions options = null)
{
return _customerService.UpdateAsync(id, options);
}
public Task<Stripe.Customer> CustomerDeleteAsync(string id)
public Task<Customer> CustomerDeleteAsync(string id)
{
return _customerService.DeleteAsync(id);
}
public async Task<List<PaymentMethod>> CustomerListPaymentMethods(string id,
CustomerListPaymentMethodsOptions options = null)
CustomerPaymentMethodListOptions options = null)
{
var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options);
return paymentMethods.Data;
@@ -77,12 +80,12 @@ public class StripeAdapter : IStripeAdapter
CustomerBalanceTransactionCreateOptions options)
=> await _customerBalanceTransactionService.CreateAsync(customerId, options);
public Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options)
public Task<Subscription> SubscriptionCreateAsync(SubscriptionCreateOptions options)
{
return _subscriptionService.CreateAsync(options);
}
public Task<Stripe.Subscription> SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null)
public Task<Subscription> SubscriptionGetAsync(string id, SubscriptionGetOptions options = null)
{
return _subscriptionService.GetAsync(id, options);
}
@@ -101,28 +104,23 @@ public class StripeAdapter : IStripeAdapter
throw new InvalidOperationException("Subscription does not belong to the provider.");
}
public Task<Stripe.Subscription> SubscriptionUpdateAsync(string id,
Stripe.SubscriptionUpdateOptions options = null)
public Task<Subscription> SubscriptionUpdateAsync(string id,
SubscriptionUpdateOptions options = null)
{
return _subscriptionService.UpdateAsync(id, options);
}
public Task<Stripe.Subscription> SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null)
public Task<Subscription> SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null)
{
return _subscriptionService.CancelAsync(Id, options);
}
public Task<Stripe.Invoice> InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options)
{
return _invoiceService.UpcomingAsync(options);
}
public Task<Stripe.Invoice> InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options)
public Task<Invoice> InvoiceGetAsync(string id, InvoiceGetOptions options)
{
return _invoiceService.GetAsync(id, options);
}
public async Task<List<Stripe.Invoice>> InvoiceListAsync(StripeInvoiceListOptions options)
public async Task<List<Invoice>> InvoiceListAsync(StripeInvoiceListOptions options)
{
if (!options.SelectAll)
{
@@ -131,7 +129,7 @@ public class StripeAdapter : IStripeAdapter
options.Limit = 100;
var invoices = new List<Stripe.Invoice>();
var invoices = new List<Invoice>();
await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions()))
{
@@ -146,120 +144,104 @@ public class StripeAdapter : IStripeAdapter
return _invoiceService.CreatePreviewAsync(options);
}
public async Task<List<Stripe.Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options)
public async Task<List<Invoice>> InvoiceSearchAsync(InvoiceSearchOptions options)
=> (await _invoiceService.SearchAsync(options)).Data;
public Task<Stripe.Invoice> InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options)
public Task<Invoice> InvoiceUpdateAsync(string id, InvoiceUpdateOptions options)
{
return _invoiceService.UpdateAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options)
public Task<Invoice> InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options)
{
return _invoiceService.FinalizeInvoiceAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options)
public Task<Invoice> InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options)
{
return _invoiceService.SendInvoiceAsync(id, options);
}
public Task<Stripe.Invoice> InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null)
public Task<Invoice> InvoicePayAsync(string id, InvoicePayOptions options = null)
{
return _invoiceService.PayAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null)
public Task<Invoice> InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null)
{
return _invoiceService.DeleteAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null)
public Task<Invoice> InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null)
{
return _invoiceService.VoidInvoiceAsync(id, options);
}
public IEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options)
public IEnumerable<PaymentMethod> PaymentMethodListAutoPaging(PaymentMethodListOptions options)
{
return _paymentMethodService.ListAutoPaging(options);
}
public IAsyncEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options)
public IAsyncEnumerable<PaymentMethod> PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options)
=> _paymentMethodService.ListAutoPagingAsync(options);
public Task<Stripe.PaymentMethod> PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null)
public Task<PaymentMethod> PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null)
{
return _paymentMethodService.AttachAsync(id, options);
}
public Task<Stripe.PaymentMethod> PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null)
public Task<PaymentMethod> PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null)
{
return _paymentMethodService.DetachAsync(id, options);
}
public Task<Stripe.Plan> PlanGetAsync(string id, Stripe.PlanGetOptions options = null)
public Task<Plan> PlanGetAsync(string id, PlanGetOptions options = null)
{
return _planService.GetAsync(id, options);
}
public Task<Stripe.TaxId> TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options)
public Task<TaxId> TaxIdCreateAsync(string id, TaxIdCreateOptions options)
{
return _taxIdService.CreateAsync(id, options);
}
public Task<Stripe.TaxId> TaxIdDeleteAsync(string customerId, string taxIdId,
Stripe.TaxIdDeleteOptions options = null)
public Task<TaxId> TaxIdDeleteAsync(string customerId, string taxIdId,
TaxIdDeleteOptions options = null)
{
return _taxIdService.DeleteAsync(customerId, taxIdId);
}
public Task<Stripe.StripeList<Stripe.Tax.Registration>> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null)
public Task<StripeList<Registration>> TaxRegistrationsListAsync(RegistrationListOptions options = null)
{
return _taxRegistrationService.ListAsync(options);
}
public Task<Stripe.StripeList<Stripe.Charge>> ChargeListAsync(Stripe.ChargeListOptions options)
public Task<StripeList<Charge>> ChargeListAsync(ChargeListOptions options)
{
return _chargeService.ListAsync(options);
}
public Task<Stripe.Refund> RefundCreateAsync(Stripe.RefundCreateOptions options)
public Task<Refund> RefundCreateAsync(RefundCreateOptions options)
{
return _refundService.CreateAsync(options);
}
public Task<Stripe.Card> CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null)
public Task<Card> CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null)
{
return _cardService.DeleteAsync(customerId, cardId, options);
}
public Task<Stripe.BankAccount> BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null)
public Task<BankAccount> BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null)
{
return _bankAccountService.CreateAsync(customerId, options);
}
public Task<Stripe.BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null)
public Task<BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null)
{
return _bankAccountService.DeleteAsync(customerId, bankAccount, options);
}
public async Task<List<Stripe.Subscription>> SubscriptionListAsync(StripeSubscriptionListOptions options)
{
if (!options.SelectAll)
{
return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data;
}
options.Limit = 100;
var items = new List<Stripe.Subscription>();
await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions()))
{
items.Add(i);
}
return items;
}
public async Task<Stripe.StripeList<Stripe.Price>> PriceListAsync(Stripe.PriceListOptions options = null)
public async Task<StripeList<Price>> PriceListAsync(PriceListOptions options = null)
{
return await _priceService.ListAsync(options);
}

View File

@@ -65,19 +65,20 @@ public class StripePaymentService : IPaymentService
bool applySponsorship)
{
var existingPlan = await _pricingClient.GetPlanOrThrow(org.PlanType);
var sponsoredPlan = sponsorship?.PlanSponsorshipType != null ?
Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) :
null;
var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship);
var sponsoredPlan = sponsorship?.PlanSponsorshipType != null
? Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)
: null;
var subscriptionUpdate =
new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship);
await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, true);
var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId);
org.ExpirationDate = sub.CurrentPeriodEnd;
org.ExpirationDate = sub.GetCurrentPeriodEnd();
if (sponsorship is not null)
{
sponsorship.ValidUntil = sub.CurrentPeriodEnd;
sponsorship.ValidUntil = sub.GetCurrentPeriodEnd();
}
}
@@ -100,7 +101,8 @@ public class StripePaymentService : IPaymentService
if (sub.Status == SubscriptionStatuses.Canceled)
{
throw new BadRequestException("You do not have an active subscription. Reinstate your subscription to make changes.");
throw new BadRequestException(
"You do not have an active subscription. Reinstate your subscription to make changes.");
}
var existingCoupon = sub.Customer.Discount?.Coupon?.Id;
@@ -191,24 +193,24 @@ public class StripePaymentService : IPaymentService
throw;
}
}
else if (!invoice.Paid)
else if (invoice.Status != StripeConstants.InvoiceStatus.Paid)
{
// Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h
invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId);
paymentIntentClientSecret = null;
}
}
finally
{
// Change back the subscription collection method and/or days until due
if (collectionMethod != "send_invoice" || daysUntilDue == null)
{
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions
{
CollectionMethod = collectionMethod,
DaysUntilDue = daysUntilDue,
});
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id,
new SubscriptionUpdateOptions
{
CollectionMethod = collectionMethod,
DaysUntilDue = daysUntilDue,
});
}
var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId);
@@ -218,9 +220,15 @@ public class StripePaymentService : IPaymentService
if (!string.IsNullOrEmpty(existingCoupon) && string.IsNullOrEmpty(newCoupon))
{
// Re-add the lost coupon due to the update.
await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, new CustomerUpdateOptions
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions
{
Coupon = existingCoupon
Discounts =
[
new SubscriptionDiscountOptions
{
Coupon = existingCoupon
}
]
});
}
}
@@ -352,7 +360,7 @@ public class StripePaymentService : IPaymentService
{
var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card";
var hasDefaultValidSource = customer.DefaultSource != null &&
(customer.DefaultSource is Card || customer.DefaultSource is BankAccount);
(customer.DefaultSource is Card || customer.DefaultSource is BankAccount);
if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource)
{
cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id;
@@ -365,12 +373,11 @@ public class StripePaymentService : IPaymentService
}
catch
{
await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions
{
AutoAdvance = false
});
await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id,
new InvoiceFinalizeOptions { AutoAdvance = false });
await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id);
}
throw new BadRequestException("No payment method is available.");
}
}
@@ -381,14 +388,9 @@ public class StripePaymentService : IPaymentService
{
// Finalize the invoice (from Draft) w/o auto-advance so we
// can attempt payment manually.
invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions
{
AutoAdvance = false,
});
var invoicePayOptions = new InvoicePayOptions
{
PaymentMethod = cardPaymentMethodId,
};
invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id,
new InvoiceFinalizeOptions { AutoAdvance = false, });
var invoicePayOptions = new InvoicePayOptions { PaymentMethod = cardPaymentMethodId, };
if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false)
{
invoicePayOptions.PaidOutOfBand = true;
@@ -403,13 +405,15 @@ public class StripePaymentService : IPaymentService
SubmitForSettlement = true,
PayPal = new Braintree.TransactionOptionsPayPalRequest
{
CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}"
CustomField =
$"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}"
}
},
CustomFields = new Dictionary<string, string>
{
[subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
[subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
[subscriber.BraintreeCloudRegionField()] =
_globalSettings.BaseServiceUri.CloudRegion
}
});
@@ -442,9 +446,9 @@ public class StripePaymentService : IPaymentService
{
// SCA required, get intent client secret
var invoiceGetOptions = new InvoiceGetOptions();
invoiceGetOptions.AddExpand("payment_intent");
invoiceGetOptions.AddExpand("confirmation_secret");
invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions);
paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret;
paymentIntentClientSecret = invoice?.ConfirmationSecret?.ClientSecret;
}
else
{
@@ -458,6 +462,7 @@ public class StripePaymentService : IPaymentService
{
await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id);
}
if (invoice != null)
{
if (invoice.Status == "paid")
@@ -479,10 +484,8 @@ public class StripePaymentService : IPaymentService
// Assumption: Customer balance should now be $0, otherwise payment would not have failed.
if (customer.Balance == 0)
{
await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Balance = invoice.StartingBalance
});
await _stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { Balance = invoice.StartingBalance });
}
}
}
@@ -496,6 +499,7 @@ public class StripePaymentService : IPaymentService
// Let the caller perform any subscription change cleanup
throw;
}
return paymentIntentClientSecret;
}
@@ -526,10 +530,10 @@ public class StripePaymentService : IPaymentService
try
{
var canceledSub = endOfPeriod ?
await _stripeAdapter.SubscriptionUpdateAsync(sub.Id,
new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) :
await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions());
var canceledSub = endOfPeriod
? await _stripeAdapter.SubscriptionUpdateAsync(sub.Id,
new SubscriptionUpdateOptions { CancelAtPeriodEnd = true })
: await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions());
if (!canceledSub.CanceledAt.HasValue)
{
throw new GatewayException("Unable to cancel subscription.");
@@ -580,7 +584,7 @@ public class StripePaymentService : IPaymentService
{
Customer customer = null;
var customerExists = subscriber.Gateway == GatewayType.Stripe &&
!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId);
!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId);
if (customerExists)
{
customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId);
@@ -595,10 +599,10 @@ public class StripePaymentService : IPaymentService
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = customer.Id;
}
await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Balance = customer.Balance - (long)(creditAmount * 100)
});
await _stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { Balance = customer.Balance - (long)(creditAmount * 100) });
return !customerExists;
}
@@ -630,50 +634,45 @@ public class StripePaymentService : IPaymentService
{
var subscriptionInfo = new SubscriptionInfo();
if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
{
var customerGetOptions = new CustomerGetOptions();
customerGetOptions.AddExpand("discount.coupon.applies_to");
var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions);
if (customer.Discount != null)
{
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(customer.Discount);
}
}
if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId))
if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
{
return subscriptionInfo;
}
var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions
var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer", "discounts", "test_clock"] });
subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(subscription);
var discount = subscription.Customer.Discount ?? subscription.Discounts.FirstOrDefault();
if (discount != null)
{
Expand = ["test_clock"]
});
if (sub != null)
{
subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub);
var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(sub);
if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue)
{
subscriptionInfo.Subscription.SuspensionDate = suspensionDate;
subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate;
}
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(discount);
}
if (sub is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(subscription);
if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue)
{
subscriptionInfo.Subscription.SuspensionDate = suspensionDate;
subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate;
}
if (subscription is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
{
return subscriptionInfo;
}
try
{
var upcomingInvoiceOptions = new UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId };
var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync(upcomingInvoiceOptions);
var invoiceCreatePreviewOptions = new InvoiceCreatePreviewOptions
{
Customer = subscriber.GatewayCustomerId,
Subscription = subscriber.GatewaySubscriptionId
};
var upcomingInvoice = await _stripeAdapter.InvoiceCreatePreviewAsync(invoiceCreatePreviewOptions);
if (upcomingInvoice != null)
{
@@ -682,7 +681,12 @@ public class StripePaymentService : IPaymentService
}
catch (StripeException ex)
{
_logger.LogWarning(ex, "Encountered an unexpected Stripe error");
_logger.LogWarning(
ex,
"Failed to retrieve upcoming invoice for customer {CustomerId}, subscription {SubscriptionId}. Error Code: {ErrorCode}",
subscriber.GatewayCustomerId,
subscriber.GatewaySubscriptionId,
ex.StripeError?.Code);
}
return subscriptionInfo;
@@ -788,7 +792,11 @@ public class StripePaymentService : IPaymentService
if (taxInfo.TaxIdType == StripeConstants.TaxIdType.SpanishNIF)
{
await _stripeAdapter.TaxIdCreateAsync(customer.Id,
new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, Value = $"ES{taxInfo.TaxIdNumber}" });
new TaxIdCreateOptions
{
Type = StripeConstants.TaxIdType.EUVAT,
Value = $"ES{taxInfo.TaxIdNumber}"
});
}
}
catch (StripeException e)
@@ -829,7 +837,8 @@ public class StripePaymentService : IPaymentService
await HasSecretsManagerStandaloneAsync(gatewayCustomerId: organization.GatewayCustomerId,
organizationHasSecretsManager: organization.UseSecretsManager);
private async Task<bool> HasSecretsManagerStandaloneAsync(string gatewayCustomerId, bool organizationHasSecretsManager)
private async Task<bool> HasSecretsManagerStandaloneAsync(string gatewayCustomerId,
bool organizationHasSecretsManager)
{
if (string.IsNullOrEmpty(gatewayCustomerId))
{
@@ -894,26 +903,14 @@ public class StripePaymentService : IPaymentService
{
var options = new InvoiceCreatePreviewOptions
{
AutomaticTax = new InvoiceAutomaticTaxOptions
{
Enabled = true,
},
AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true, },
Currency = "usd",
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items =
[
new()
{
Quantity = 1,
Plan = StripeConstants.Prices.PremiumAnnually
},
new()
{
Quantity = parameters.PasswordManager.AdditionalStorage,
Plan = "storage-gb-annually"
}
new InvoiceSubscriptionDetailsItemOptions { Quantity = 1, Plan = StripeConstants.Prices.PremiumAnnually },
new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.PasswordManager.AdditionalStorage, Plan = StripeConstants.Prices.StoragePlanPersonal }
]
},
CustomerDetails = new InvoiceCustomerDetailsOptions
@@ -940,12 +937,9 @@ public class StripePaymentService : IPaymentService
throw new BadRequestException("billingPreviewInvalidTaxIdError");
}
options.CustomerDetails.TaxIds = [
new InvoiceCustomerDetailsTaxIdOptions
{
Type = taxIdType,
Value = parameters.TaxInformation.TaxId
}
options.CustomerDetails.TaxIds =
[
new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId }
];
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
@@ -964,7 +958,7 @@ public class StripePaymentService : IPaymentService
if (gatewayCustomer.Discount != null)
{
options.Coupon = gatewayCustomer.Discount.Coupon.Id;
options.Discounts = [new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id }];
}
}
@@ -972,24 +966,31 @@ public class StripePaymentService : IPaymentService
{
var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId);
if (gatewaySubscription?.Discount != null)
if (gatewaySubscription?.Discounts is { Count: > 0 })
{
options.Coupon ??= gatewaySubscription.Discount.Coupon.Id;
options.Discounts = gatewaySubscription.Discounts.Select(x => new InvoiceDiscountOptions { Coupon = x.Coupon.Id }).ToList();
}
}
if (options.Discounts is { Count: > 0 })
{
options.Discounts = options.Discounts.DistinctBy(invoiceDiscountOptions => invoiceDiscountOptions.Coupon).ToList();
}
try
{
var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options);
var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount);
var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
: 0M;
var result = new PreviewInvoiceResponseModel(
effectiveTaxRate,
invoice.TotalExcludingTax.ToMajor() ?? 0,
invoice.Tax.ToMajor() ?? 0,
tax.ToMajor(),
invoice.Total.ToMajor());
return result;
}
@@ -1003,7 +1004,8 @@ public class StripePaymentService : IPaymentService
parameters.TaxInformation.Country);
throw new BadRequestException("billingPreviewInvalidTaxIdError");
default:
_logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
_logger.LogError(e,
"Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
parameters.TaxInformation.TaxId,
parameters.TaxInformation.Country);
throw new BadRequestException("billingPreviewInvoiceError");
@@ -1101,12 +1103,9 @@ public class StripePaymentService : IPaymentService
throw new BadRequestException("billingTaxIdTypeInferenceError");
}
options.CustomerDetails.TaxIds = [
new InvoiceCustomerDetailsTaxIdOptions
{
Type = taxIdType,
Value = parameters.TaxInformation.TaxId
}
options.CustomerDetails.TaxIds =
[
new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId }
];
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
@@ -1127,7 +1126,10 @@ public class StripePaymentService : IPaymentService
if (gatewayCustomer.Discount != null)
{
options.Coupon = gatewayCustomer.Discount.Coupon.Id;
options.Discounts =
[
new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id }
];
}
}
@@ -1135,9 +1137,10 @@ public class StripePaymentService : IPaymentService
{
var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId);
if (gatewaySubscription?.Discount != null)
if (gatewaySubscription?.Discounts != null)
{
options.Coupon ??= gatewaySubscription.Discount.Coupon.Id;
options.Discounts = gatewaySubscription.Discounts
.Select(discount => new InvoiceDiscountOptions { Coupon = discount.Coupon.Id }).ToList();
}
}
@@ -1152,14 +1155,16 @@ public class StripePaymentService : IPaymentService
{
var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options);
var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount);
var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0
? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor()
: 0M;
var result = new PreviewInvoiceResponseModel(
effectiveTaxRate,
invoice.TotalExcludingTax.ToMajor() ?? 0,
invoice.Tax.ToMajor() ?? 0,
tax.ToMajor(),
invoice.Total.ToMajor());
return result;
}
@@ -1173,7 +1178,8 @@ public class StripePaymentService : IPaymentService
parameters.TaxInformation.Country);
throw new BadRequestException("billingPreviewInvalidTaxIdError");
default:
_logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
_logger.LogError(e,
"Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.",
parameters.TaxInformation.TaxId,
parameters.TaxInformation.Country);
throw new BadRequestException("billingPreviewInvoiceError");
@@ -1207,7 +1213,9 @@ public class StripePaymentService : IPaymentService
braintreeCustomer.DefaultPaymentMethod);
}
}
catch (Braintree.Exceptions.NotFoundException) { }
catch (Braintree.Exceptions.NotFoundException)
{
}
}
if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card")
@@ -1246,12 +1254,15 @@ public class StripePaymentService : IPaymentService
{
customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options);
}
catch (StripeException) { }
catch (StripeException)
{
}
return customer;
}
private async Task<IEnumerable<BillingHistoryInfo.BillingTransaction>> GetBillingTransactionsAsync(ISubscriber subscriber, int? limit = null)
private async Task<IEnumerable<BillingHistoryInfo.BillingTransaction>> GetBillingTransactionsAsync(
ISubscriber subscriber, int? limit = null)
{
var transactions = subscriber switch
{

View File

@@ -17,6 +17,6 @@ public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute
var logger = context.HttpContext.RequestServices
.GetRequiredService<ILogger<LoggingExceptionHandlerFilterAttribute>>();
logger.LogError(0, exception, exception.Message);
logger.LogError(0, exception, "Unhandled exception");
}
}