1
0
mirror of https://github.com/bitwarden/server synced 2025-12-26 05:03:18 +00:00

Merge branch 'main' of https://github.com/bitwarden/server into vault/pm-25957/sharing-cipher-to-org

This commit is contained in:
Nick Krantz
2025-10-08 10:56:59 -05:00
126 changed files with 25603 additions and 1549 deletions

View File

@@ -484,7 +484,7 @@ jobs:
uses: bitwarden/gh-actions/azure-logout@main
- name: Trigger self-host build
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
script: |
@@ -525,7 +525,7 @@ jobs:
uses: bitwarden/gh-actions/azure-logout@main
- name: Trigger k8s deploy
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
script: |

View File

@@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<Version>2025.9.2</Version>
<Version>2025.10.0</Version>
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings>

View File

@@ -12,7 +12,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Context;
@@ -90,7 +90,7 @@ public class ProviderService : IProviderService
_providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand;
}
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null)
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress)
{
var owner = await _userService.GetUserByIdAsync(ownerUserId);
if (owner == null)
@@ -115,21 +115,7 @@ public class ProviderService : IProviderService
throw new BadRequestException("Invalid owner.");
}
if (taxInfo == null || string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode))
{
throw new BadRequestException("Both address and postal code are required to set up your provider.");
}
if (tokenizedPaymentSource is not
{
Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal,
Token: not null and not ""
})
{
throw new BadRequestException("A payment method is required to set up your provider.");
}
var customer = await _providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
var customer = await _providerBillingService.SetupCustomer(provider, paymentMethod, billingAddress);
provider.GatewayCustomerId = customer.Id;
var subscription = await _providerBillingService.SetupSubscription(provider);
provider.GatewaySubscriptionId = subscription.Id;

View File

@@ -14,6 +14,7 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models;
@@ -21,10 +22,8 @@ using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
@@ -38,6 +37,9 @@ using Subscription = Stripe.Subscription;
namespace Bit.Commercial.Core.Billing.Providers.Services;
using static Constants;
using static StripeConstants;
public class ProviderBillingService(
IBraintreeGateway braintreeGateway,
IEventService eventService,
@@ -51,8 +53,7 @@ public class ProviderBillingService(
IProviderUserRepository providerUserRepository,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
ITaxService taxService)
ISubscriberService subscriberService)
: IProviderBillingService
{
public async Task AddExistingOrganization(
@@ -61,10 +62,7 @@ public class ProviderBillingService(
string key)
{
await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
new SubscriptionUpdateOptions
{
CancelAtPeriodEnd = false
});
new SubscriptionUpdateOptions { CancelAtPeriodEnd = false });
var subscription =
await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId,
@@ -83,7 +81,7 @@ public class ProviderBillingService(
var wasTrialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now;
if (!wasTrialing && subscription.LatestInvoice.Status == StripeConstants.InvoiceStatus.Draft)
if (!wasTrialing && subscription.LatestInvoice.Status == InvoiceStatus.Draft)
{
await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId,
new InvoiceFinalizeOptions { AutoAdvance = true });
@@ -184,16 +182,8 @@ public class ProviderBillingService(
{
Items =
[
new SubscriptionItemOptions
{
Price = newPriceId,
Quantity = oldSubscriptionItem!.Quantity
},
new SubscriptionItemOptions
{
Id = oldSubscriptionItem.Id,
Deleted = true
}
new SubscriptionItemOptions { Price = newPriceId, Quantity = oldSubscriptionItem!.Quantity },
new SubscriptionItemOptions { Id = oldSubscriptionItem.Id, Deleted = true }
]
};
@@ -202,7 +192,8 @@ public class ProviderBillingService(
// Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId)
// 1. Retrieve PlanType and PlanName for ProviderPlan
// 2. Assign PlanType & PlanName to Organization
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId);
var providerOrganizations =
await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId);
var newPlan = await pricingClient.GetPlanOrThrow(newPlanType);
@@ -213,6 +204,7 @@ public class ProviderBillingService(
{
throw new ConflictException($"Organization '{providerOrganization.Id}' not found.");
}
organization.PlanType = newPlanType;
organization.Plan = newPlan.Name;
await organizationRepository.ReplaceAsync(organization);
@@ -228,15 +220,15 @@ public class ProviderBillingService(
if (!string.IsNullOrEmpty(organization.GatewayCustomerId))
{
logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId));
logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id,
nameof(organization.GatewayCustomerId));
return;
}
var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions
{
Expand = ["tax", "tax_ids"]
});
var providerCustomer =
await subscriberService.GetCustomerOrThrow(provider,
new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();
@@ -269,23 +261,18 @@ public class ProviderBillingService(
}
]
},
Metadata = new Dictionary<string, string>
{
{ "region", globalSettings.BaseServiceUri.CloudRegion }
},
TaxIdData = providerTaxId == null ? null :
[
new CustomerTaxIdDataOptions
{
Type = providerTaxId.Type,
Value = providerTaxId.Value
}
]
Metadata = new Dictionary<string, string> { { "region", globalSettings.BaseServiceUri.CloudRegion } },
TaxIdData = providerTaxId == null
? null
:
[
new CustomerTaxIdDataOptions { Type = providerTaxId.Type, Value = providerTaxId.Value }
]
};
if (providerCustomer.Address is not { Country: Constants.CountryAbbreviations.UnitedStates })
if (providerCustomer.Address is not { Country: CountryAbbreviations.UnitedStates })
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
customerCreateOptions.TaxExempt = TaxExempt.Reverse;
}
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
@@ -347,9 +334,9 @@ public class ProviderBillingService(
.Where(pair => pair.subscription is
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.Trialing or
StripeConstants.SubscriptionStatus.PastDue
SubscriptionStatus.Active or
SubscriptionStatus.Trialing or
SubscriptionStatus.PastDue
}).ToList();
if (active.Count == 0)
@@ -474,37 +461,27 @@ public class ProviderBillingService(
// Below the limit to above the limit
(currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) ||
// Above the limit to further above the limit
(currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > currentlyAssignedSeatTotal);
(currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal > currentlyAssignedSeatTotal);
}
public async Task<Customer> SetupCustomer(
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource)
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
ArgumentNullException.ThrowIfNull(tokenizedPaymentSource);
if (taxInfo is not
{
BillingAddressCountry: not null and not "",
BillingAddressPostalCode: not null and not ""
})
{
logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id);
throw new BillingException();
}
var options = new CustomerCreateOptions
{
Address = new AddressOptions
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode,
Line1 = taxInfo.BillingAddressLine1,
Line2 = taxInfo.BillingAddressLine2,
City = taxInfo.BillingAddressCity,
State = taxInfo.BillingAddressState
Country = billingAddress.Country,
PostalCode = billingAddress.PostalCode,
Line1 = billingAddress.Line1,
Line2 = billingAddress.Line2,
City = billingAddress.City,
State = billingAddress.State
},
Coupon = !string.IsNullOrEmpty(provider.DiscountId) ? provider.DiscountId : null,
Description = provider.DisplayBusinessName(),
Email = provider.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
@@ -520,93 +497,61 @@ public class ProviderBillingService(
}
]
},
Metadata = new Dictionary<string, string>
{
{ "region", globalSettings.BaseServiceUri.CloudRegion }
}
Metadata = new Dictionary<string, string> { { "region", globalSettings.BaseServiceUri.CloudRegion } },
TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None
};
if (taxInfo.BillingAddressCountry is not Constants.CountryAbbreviations.UnitedStates)
if (billingAddress.TaxId != null)
{
options.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber))
{
var taxIdType = taxService.GetStripeTaxCode(
taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber);
if (taxIdType == null)
{
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber);
throw new BadRequestException("billingTaxIdTypeInferenceError");
}
options.TaxIdData =
[
new CustomerTaxIdDataOptions { Type = taxIdType, Value = taxInfo.TaxIdNumber }
new CustomerTaxIdDataOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value }
];
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
if (billingAddress.TaxId.Code == TaxIdType.SpanishNIF)
{
options.TaxIdData.Add(new CustomerTaxIdDataOptions
{
Type = StripeConstants.TaxIdType.EUVAT,
Value = $"ES{taxInfo.TaxIdNumber}"
Type = TaxIdType.EUVAT,
Value = $"ES{billingAddress.TaxId.Value}"
});
}
}
if (!string.IsNullOrEmpty(provider.DiscountId))
{
options.Coupon = provider.DiscountId;
}
var braintreeCustomerId = "";
if (tokenizedPaymentSource is not
{
Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal,
Token: not null and not ""
})
{
logger.LogError("Cannot create customer for provider ({ProviderID}) with invalid payment method", provider.Id);
throw new BillingException();
}
var (type, token) = tokenizedPaymentSource;
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (type)
switch (paymentMethod.Type)
{
case PaymentMethodType.BankAccount:
case TokenizablePaymentMethodType.BankAccount:
{
var setupIntent =
(await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = token }))
(await stripeAdapter.SetupIntentList(new SetupIntentListOptions
{
PaymentMethod = paymentMethod.Token
}))
.FirstOrDefault();
if (setupIntent == null)
{
logger.LogError("Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", provider.Id);
logger.LogError(
"Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account",
provider.Id);
throw new BillingException();
}
await setupIntentCache.Set(provider.Id, setupIntent.Id);
break;
}
case PaymentMethodType.Card:
case TokenizablePaymentMethodType.Card:
{
options.PaymentMethod = token;
options.InvoiceSettings.DefaultPaymentMethod = token;
options.PaymentMethod = paymentMethod.Token;
options.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token;
break;
}
case PaymentMethodType.PayPal:
case TokenizablePaymentMethodType.PayPal:
{
braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, token);
braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, paymentMethod.Token);
options.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId;
break;
}
@@ -616,8 +561,7 @@ public class ProviderBillingService(
{
return await stripeAdapter.CustomerCreateAsync(options);
}
catch (StripeException stripeException) when (stripeException.StripeError?.Code ==
StripeConstants.ErrorCodes.TaxIdInvalid)
catch (StripeException stripeException) when (stripeException.StripeError?.Code == ErrorCodes.TaxIdInvalid)
{
await Revert();
throw new BadRequestException(
@@ -632,9 +576,9 @@ public class ProviderBillingService(
async Task Revert()
{
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (tokenizedPaymentSource.Type)
switch (paymentMethod.Type)
{
case PaymentMethodType.BankAccount:
case TokenizablePaymentMethodType.BankAccount:
{
var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id);
await stripeAdapter.SetupIntentCancel(setupIntentId,
@@ -642,7 +586,7 @@ public class ProviderBillingService(
await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id);
break;
}
case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId):
case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId):
{
await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId);
break;
@@ -661,9 +605,10 @@ public class ProviderBillingService(
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
if (providerPlans == null || providerPlans.Count == 0)
if (providerPlans.Count == 0)
{
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id);
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans",
provider.Id);
throw new BillingException();
}
@@ -676,7 +621,9 @@ public class ProviderBillingService(
if (!providerPlan.IsConfigured())
{
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan", provider.Id, plan.Name);
logger.LogError(
"Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan",
provider.Id, plan.Name);
throw new BillingException();
}
@@ -692,16 +639,14 @@ public class ProviderBillingService(
var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id);
var setupIntent = !string.IsNullOrEmpty(setupIntentId)
? await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions
{
Expand = ["payment_method"]
})
? await stripeAdapter.SetupIntentGet(setupIntentId,
new SetupIntentGetOptions { Expand = ["payment_method"] })
: null;
var usePaymentMethod =
!string.IsNullOrEmpty(customer.InvoiceSettings?.DefaultPaymentMethodId) ||
(customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true) ||
(setupIntent?.IsUnverifiedBankAccount() == true);
customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true ||
setupIntent?.IsUnverifiedBankAccount() == true;
int? trialPeriodDays = provider.Type switch
{
@@ -712,30 +657,28 @@ public class ProviderBillingService(
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
CollectionMethod = usePaymentMethod ?
StripeConstants.CollectionMethod.ChargeAutomatically : StripeConstants.CollectionMethod.SendInvoice,
CollectionMethod =
usePaymentMethod
? CollectionMethod.ChargeAutomatically
: CollectionMethod.SendInvoice,
Customer = customer.Id,
DaysUntilDue = usePaymentMethod ? null : 30,
Items = subscriptionItemOptionsList,
Metadata = new Dictionary<string, string>
{
{ "providerId", provider.Id.ToString() }
},
Metadata = new Dictionary<string, string> { { "providerId", provider.Id.ToString() } },
OffSession = true,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
TrialPeriodDays = trialPeriodDays
ProrationBehavior = ProrationBehavior.CreateProrations,
TrialPeriodDays = trialPeriodDays,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
};
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
try
{
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (subscription is
{
Status: StripeConstants.SubscriptionStatus.Active or StripeConstants.SubscriptionStatus.Trialing
Status: SubscriptionStatus.Active or SubscriptionStatus.Trialing
})
{
return subscription;
@@ -749,9 +692,11 @@ public class ProviderBillingService(
throw new BillingException();
}
catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid)
catch (StripeException stripeException) when (stripeException.StripeError?.Code ==
ErrorCodes.CustomerTaxLocationInvalid)
{
throw new BadRequestException("Your location wasn't recognized. Please ensure your country and postal code are valid.");
throw new BadRequestException(
"Your location wasn't recognized. Please ensure your country and postal code are valid.");
}
}
@@ -765,7 +710,7 @@ public class ProviderBillingService(
subscriberService.UpdateTaxInformation(provider, taxInformation));
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically });
new SubscriptionUpdateOptions { CollectionMethod = CollectionMethod.ChargeAutomatically });
}
public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command)
@@ -865,13 +810,9 @@ public class ProviderBillingService(
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions
{
Items = [
new SubscriptionItemOptions
{
Id = item.Id,
Price = priceId,
Quantity = newlySubscribedSeats
}
Items =
[
new SubscriptionItemOptions { Id = item.Id, Price = priceId, Quantity = newlySubscribedSeats }
]
});
@@ -894,7 +835,8 @@ public class ProviderBillingService(
var plan = await pricingClient.GetPlanOrThrow(planType);
return providerOrganizations
.Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed)
.Where(providerOrganization => providerOrganization.Plan == plan.Name &&
providerOrganization.Status == OrganizationStatusType.Managed)
.Sum(providerOrganization => providerOrganization.Seats ?? 0);
}

View File

@@ -1,10 +1,13 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Commands.ServiceAccounts.Interfaces;
using Bit.Core.SecretsManager.Entities;
using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Services;
namespace Bit.Commercial.Core.SecretsManager.Commands.ServiceAccounts;
@@ -13,15 +16,21 @@ public class CreateServiceAccountCommand : ICreateServiceAccountCommand
private readonly IAccessPolicyRepository _accessPolicyRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IServiceAccountRepository _serviceAccountRepository;
private readonly IEventService _eventService;
private readonly ICurrentContext _currentContext;
public CreateServiceAccountCommand(
IAccessPolicyRepository accessPolicyRepository,
IOrganizationUserRepository organizationUserRepository,
IServiceAccountRepository serviceAccountRepository)
IServiceAccountRepository serviceAccountRepository,
IEventService eventService,
ICurrentContext currentContext)
{
_accessPolicyRepository = accessPolicyRepository;
_organizationUserRepository = organizationUserRepository;
_serviceAccountRepository = serviceAccountRepository;
_eventService = eventService;
_currentContext = currentContext;
}
public async Task<ServiceAccount> CreateAsync(ServiceAccount serviceAccount, Guid userId)
@@ -38,6 +47,7 @@ public class CreateServiceAccountCommand : ICreateServiceAccountCommand
Write = true,
};
await _accessPolicyRepository.CreateManyAsync(new List<BaseAccessPolicy> { accessPolicy });
await _eventService.LogServiceAccountPeopleEventAsync(user.Id, accessPolicy, EventType.ServiceAccount_UserAdded, _currentContext.IdentityClientType);
return createdServiceAccount;
}
}

View File

@@ -9,7 +9,7 @@ using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Context;
@@ -41,7 +41,7 @@ public class ProviderServiceTests
public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider<ProviderService> sutProvider)
{
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null));
() => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null, null));
Assert.Contains("Invalid owner.", exception.Message);
}
@@ -53,83 +53,12 @@ public class ProviderServiceTests
userService.GetUserByIdAsync(user.Id).Returns(user);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null));
() => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null, null));
Assert.Contains("Invalid token.", exception.Message);
}
[Theory, BitAutoData]
public async Task CompleteSetupAsync_InvalidTaxInfo_ThrowsBadRequestException(
User user,
Provider provider,
string key,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource,
[ProviderUser] ProviderUser providerUser,
SutProvider<ProviderService> sutProvider)
{
providerUser.ProviderId = provider.Id;
providerUser.UserId = user.Id;
var userService = sutProvider.GetDependency<IUserService>();
userService.GetUserByIdAsync(user.Id).Returns(user);
var providerUserRepository = sutProvider.GetDependency<IProviderUserRepository>();
providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser);
var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName");
var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
sutProvider.GetDependency<IDataProtectionProvider>().CreateProtector("ProviderServiceDataProtector")
.Returns(protector);
sutProvider.Create();
var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
taxInfo.BillingAddressCountry = null;
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource));
Assert.Equal("Both address and postal code are required to set up your provider.", exception.Message);
}
[Theory, BitAutoData]
public async Task CompleteSetupAsync_InvalidTokenizedPaymentSource_ThrowsBadRequestException(
User user,
Provider provider,
string key,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource,
[ProviderUser] ProviderUser providerUser,
SutProvider<ProviderService> sutProvider)
{
providerUser.ProviderId = provider.Id;
providerUser.UserId = user.Id;
var userService = sutProvider.GetDependency<IUserService>();
userService.GetUserByIdAsync(user.Id).Returns(user);
var providerUserRepository = sutProvider.GetDependency<IProviderUserRepository>();
providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser);
var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName");
var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
sutProvider.GetDependency<IDataProtectionProvider>().CreateProtector("ProviderServiceDataProtector")
.Returns(protector);
sutProvider.Create();
var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay };
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource));
Assert.Equal("A payment method is required to set up your provider.", exception.Message);
}
[Theory, BitAutoData]
public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource,
public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TokenizedPaymentMethod tokenizedPaymentMethod, BillingAddress billingAddress,
[ProviderUser] ProviderUser providerUser,
SutProvider<ProviderService> sutProvider)
{
@@ -149,7 +78,7 @@ public class ProviderServiceTests
var providerBillingService = sutProvider.GetDependency<IProviderBillingService>();
var customer = new Customer { Id = "customer_id" };
providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource).Returns(customer);
providerBillingService.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress).Returns(customer);
var subscription = new Subscription { Id = "subscription_id" };
providerBillingService.SetupSubscription(provider).Returns(subscription);
@@ -158,7 +87,7 @@ public class ProviderServiceTests
var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource);
await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, tokenizedPaymentMethod, billingAddress);
await sutProvider.GetDependency<IProviderRepository>().Received().UpsertAsync(Arg.Is<Provider>(
p =>

View File

@@ -1,5 +1,4 @@
using System.Globalization;
using System.Net;
using Bit.Commercial.Core.Billing.Providers.Models;
using Bit.Commercial.Core.Billing.Providers.Services;
using Bit.Core.AdminConsole.Entities;
@@ -10,18 +9,16 @@ using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models;
using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
@@ -895,118 +892,53 @@ public class ProviderBillingServiceTests
#region SetupCustomer
[Theory, BitAutoData]
public async Task SetupCustomer_MissingCountry_ContactSupport(
public async Task SetupCustomer_NullPaymentMethod_ThrowsNullReferenceException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource)
BillingAddress billingAddress)
{
taxInfo.BillingAddressCountry = null;
await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
}
[Theory, BitAutoData]
public async Task SetupCustomer_MissingPostalCode_ContactSupport(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource)
{
taxInfo.BillingAddressCountry = null;
await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
}
[Theory, BitAutoData]
public async Task SetupCustomer_NullPaymentSource_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
await Assert.ThrowsAsync<ArgumentNullException>(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, null));
}
[Theory, BitAutoData]
public async Task SetupCustomer_InvalidRequiredPaymentMethod_ThrowsBillingException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay };
await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
await Assert.ThrowsAsync<NullReferenceException>(() =>
sutProvider.Sut.SetupCustomer(provider, null, billingAddress));
}
[Theory, BitAutoData]
public async Task SetupCustomer_WithBankAccount_Error_Reverts(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
BillingAddress billingAddress)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
billingAddress.Country = "AD";
billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token");
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" };
stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options =>
options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([
options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([
new SetupIntent { Id = "setup_intent_id" }
]);
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == billingAddress.City &&
o.Address.State == billingAddress.State &&
o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber))
o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Throws<StripeException>();
sutProvider.GetDependency<ISetupIntentCache>().GetSetupIntentIdForSubscriber(provider.Id).Returns("setup_intent_id");
await Assert.ThrowsAsync<StripeException>(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress));
await sutProvider.GetDependency<ISetupIntentCache>().Received(1).Set(provider.Id, "setup_intent_id");
@@ -1020,45 +952,37 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithPayPal_Error_Reverts(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
BillingAddress billingAddress)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
billingAddress.Country = "AD";
billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" };
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token");
sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token)
sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token)
.Returns("braintree_customer_id");
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == billingAddress.City &&
o.Address.State == billingAddress.State &&
o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" &&
o.Metadata["btCustomerId"] == "braintree_customer_id" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber))
o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Throws<StripeException>();
await Assert.ThrowsAsync<StripeException>(() =>
sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress));
await sutProvider.GetDependency<IBraintreeGateway>().Customer.Received(1).DeleteAsync("braintree_customer_id");
}
@@ -1067,17 +991,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithBankAccount_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
BillingAddress billingAddress)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
billingAddress.Country = "AD";
billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1087,31 +1005,30 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
};
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token");
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" };
stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options =>
options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([
options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([
new SetupIntent { Id = "setup_intent_id" }
]);
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == billingAddress.City &&
o.Address.State == billingAddress.State &&
o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber))
o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual);
@@ -1122,17 +1039,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithPayPal_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
BillingAddress billingAddress)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
billingAddress.Country = "AD";
billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1142,30 +1053,29 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
};
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token");
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" };
sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token)
sutProvider.GetDependency<ISubscriberService>().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token)
.Returns("braintree_customer_id");
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == billingAddress.City &&
o.Address.State == billingAddress.State &&
o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" &&
o.Metadata["btCustomerId"] == "braintree_customer_id" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber))
o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual);
}
@@ -1174,17 +1084,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithCard_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
BillingAddress billingAddress)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
billingAddress.Country = "AD";
billingAddress.TaxId = new TaxID("es_nif", "12345678Z");
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1194,28 +1098,26 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
};
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token");
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" };
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == billingAddress.City &&
o.Address.State == billingAddress.State &&
o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail &&
o.PaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber))
o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value))
.Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual);
}
@@ -1224,17 +1126,11 @@ public class ProviderBillingServiceTests
public async Task SetupCustomer_WithCard_ReverseCharge_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
BillingAddress billingAddress)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
billingAddress.Country = "FR"; // Non-US country to trigger reverse charge
billingAddress.TaxId = new TaxID("fr_siren", "123456789");
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
@@ -1244,55 +1140,51 @@ public class ProviderBillingServiceTests
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
};
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token");
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" };
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Address.Country == billingAddress.Country &&
o.Address.PostalCode == billingAddress.PostalCode &&
o.Address.Line1 == billingAddress.Line1 &&
o.Address.Line2 == billingAddress.Line2 &&
o.Address.City == billingAddress.City &&
o.Address.State == billingAddress.State &&
o.Description == provider.DisplayBusinessName() &&
o.Email == provider.BillingEmail &&
o.PaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber &&
o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code &&
o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value &&
o.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress);
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData]
public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid(
public async Task SetupCustomer_WithInvalidTaxId_ThrowsBadRequestException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource)
BillingAddress billingAddress)
{
provider.Name = "MSP";
billingAddress.Country = "AD";
billingAddress.TaxId = new TaxID("es_nif", "invalid_tax_id");
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" };
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns((string)null);
stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>())
.Throws(new StripeException("Invalid tax ID") { StripeError = new StripeError { Code = "tax_id_invalid" } });
var actual = await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource));
await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress));
Assert.IsType<BadRequestException>(actual);
Assert.Equal("billingTaxIdTypeInferenceError", actual.Message);
Assert.Equal("Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid.", actual.Message);
}
#endregion

View File

@@ -53,6 +53,7 @@ services:
- ./.data/postgres/log:/var/log/postgresql
profiles:
- postgres
- ef
mysql:
image: mysql:8.0
@@ -69,6 +70,7 @@ services:
- mysql_dev_data:/var/lib/mysql
profiles:
- mysql
- ef
mariadb:
image: mariadb:10
@@ -76,13 +78,13 @@ services:
- 4306:3306
environment:
MARIADB_USER: maria
MARIADB_PASSWORD: ${MARIADB_ROOT_PASSWORD}
MARIADB_DATABASE: vault_dev
MARIADB_RANDOM_ROOT_PASSWORD: "true"
volumes:
- mariadb_dev_data:/var/lib/mysql
profiles:
- mariadb
- ef
idp:
image: kenchan0130/simplesamlphp:1.19.8
@@ -99,7 +101,7 @@ services:
- idp
rabbitmq:
image: rabbitmq:4.1.0-management
image: rabbitmq:4.1.3-management
container_name: rabbitmq
ports:
- "5672:5672"
@@ -153,5 +155,6 @@ volumes:
mssql_dev_data:
postgres_dev_data:
mysql_dev_data:
mariadb_dev_data:
rabbitmq_data:
redis_data:

View File

@@ -70,7 +70,7 @@ Foreach ($item in @(
@($mysql, "MySQL", "MySqlMigrations", "mySql", 2),
# MariaDB shares the MySQL connection string in the server config so they are mutually exclusive in that context.
# However they can still be run independently for integration tests.
@($mariadb, "MariaDB", "MySqlMigrations", "mySql", 3)
@($mariadb, "MariaDB", "MySqlMigrations", "mySql", 4)
)) {
if (!$item[0] -and !$all) {
continue

View File

@@ -30,6 +30,8 @@ public class EventsController : Controller
private readonly ICurrentContext _currentContext;
private readonly ISecretRepository _secretRepository;
private readonly IProjectRepository _projectRepository;
private readonly IServiceAccountRepository _serviceAccountRepository;
public EventsController(
IUserService userService,
@@ -39,7 +41,8 @@ public class EventsController : Controller
IEventRepository eventRepository,
ICurrentContext currentContext,
ISecretRepository secretRepository,
IProjectRepository projectRepository)
IProjectRepository projectRepository,
IServiceAccountRepository serviceAccountRepository)
{
_userService = userService;
_cipherRepository = cipherRepository;
@@ -49,6 +52,7 @@ public class EventsController : Controller
_currentContext = currentContext;
_secretRepository = secretRepository;
_projectRepository = projectRepository;
_serviceAccountRepository = serviceAccountRepository;
}
[HttpGet("")]
@@ -184,6 +188,57 @@ public class EventsController : Controller
return new ListResponseModel<EventResponseModel>(responses, result.ContinuationToken);
}
[HttpGet("~/organization/{orgId}/service-account/{id}/events")]
public async Task<ListResponseModel<EventResponseModel>> GetServiceAccounts(
Guid orgId,
Guid id,
[FromQuery] DateTime? start = null,
[FromQuery] DateTime? end = null,
[FromQuery] string continuationToken = null)
{
if (id == Guid.Empty || orgId == Guid.Empty)
{
throw new NotFoundException();
}
var serviceAccount = await GetServiceAccount(id, orgId);
var org = _currentContext.GetOrganization(orgId);
if (org == null || !await _currentContext.AccessEventLogs(org.Id))
{
throw new NotFoundException();
}
var (fromDate, toDate) = ApiHelpers.GetDateRange(start, end);
var result = await _eventRepository.GetManyByOrganizationServiceAccountAsync(
serviceAccount.OrganizationId,
serviceAccount.Id,
fromDate,
toDate,
new PageOptions { ContinuationToken = continuationToken });
var responses = result.Data.Select(e => new EventResponseModel(e));
return new ListResponseModel<EventResponseModel>(responses, result.ContinuationToken);
}
[ApiExplorerSettings(IgnoreApi = true)]
private async Task<ServiceAccount> GetServiceAccount(Guid serviceAccountId, Guid orgId)
{
var serviceAccount = await _serviceAccountRepository.GetByIdAsync(serviceAccountId);
if (serviceAccount != null)
{
return serviceAccount;
}
var fallbackServiceAccount = new ServiceAccount
{
Id = serviceAccountId,
OrganizationId = orgId
};
return fallbackServiceAccount;
}
[HttpGet("~/organizations/{orgId}/users/{id}/events")]
public async Task<ListResponseModel<EventResponseModel>> GetOrganizationUser(string orgId, string id,
[FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null)

View File

@@ -7,7 +7,6 @@ using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Services;
using Bit.Core.Settings;
using Microsoft.AspNetCore.Authorization;
@@ -93,22 +92,12 @@ public class ProvidersController : Controller
var userId = _userService.GetProperUserId(User).Value;
var taxInfo = new TaxInfo
{
BillingAddressCountry = model.TaxInfo.Country,
BillingAddressPostalCode = model.TaxInfo.PostalCode,
TaxIdNumber = model.TaxInfo.TaxId,
BillingAddressLine1 = model.TaxInfo.Line1,
BillingAddressLine2 = model.TaxInfo.Line2,
BillingAddressCity = model.TaxInfo.City,
BillingAddressState = model.TaxInfo.State
};
var tokenizedPaymentSource = model.PaymentSource?.ToDomain();
var paymentMethod = model.PaymentMethod.ToDomain();
var billingAddress = model.BillingAddress.ToDomain();
var response =
await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key,
taxInfo, tokenizedPaymentSource);
paymentMethod, billingAddress);
return new ProviderResponseModel(response);
}

View File

@@ -1,7 +1,4 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using System.Text.Json;
using System.Text.Json;
using Bit.Api.AdminConsole.Models.Response.Organizations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
@@ -18,25 +15,58 @@ using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.AdminConsole.Controllers;
[RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)]
[Route("organizations/{organizationId:guid}/integrations/slack")]
[Route("organizations")]
[Authorize("Application")]
public class SlackIntegrationController(
ICurrentContext currentContext,
IOrganizationIntegrationRepository integrationRepository,
ISlackService slackService) : Controller
ISlackService slackService,
TimeProvider timeProvider) : Controller
{
[HttpGet("redirect")]
[HttpGet("{organizationId:guid}/integrations/slack/redirect")]
public async Task<IActionResult> RedirectAsync(Guid organizationId)
{
if (!await currentContext.OrganizationOwner(organizationId))
{
throw new NotFoundException();
}
string callbackUrl = Url.RouteUrl(
nameof(CreateAsync),
new { organizationId },
currentContext.HttpContext.Request.Scheme);
var redirectUrl = slackService.GetRedirectUrl(callbackUrl);
string? callbackUrl = Url.RouteUrl(
routeName: nameof(CreateAsync),
values: null,
protocol: currentContext.HttpContext.Request.Scheme,
host: currentContext.HttpContext.Request.Host.ToUriComponent()
);
if (string.IsNullOrEmpty(callbackUrl))
{
throw new BadRequestException("Unable to build callback Url");
}
var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId);
var integration = integrations.FirstOrDefault(i => i.Type == IntegrationType.Slack);
if (integration is null)
{
// No slack integration exists, create Initiated version
integration = await integrationRepository.CreateAsync(new OrganizationIntegration
{
OrganizationId = organizationId,
Type = IntegrationType.Slack,
Configuration = null,
});
}
else if (integration.Configuration is not null)
{
// A Completed (fully configured) Slack integration already exists, throw to prevent overriding
throw new BadRequestException("There already exists a Slack integration for this organization");
} // An Initiated slack integration exits, re-use it and kick off a new OAuth flow
var state = IntegrationOAuthState.FromIntegration(integration, timeProvider);
var redirectUrl = slackService.GetRedirectUrl(
callbackUrl: callbackUrl,
state: state.ToString()
);
if (string.IsNullOrEmpty(redirectUrl))
{
@@ -46,23 +76,42 @@ public class SlackIntegrationController(
return Redirect(redirectUrl);
}
[HttpGet("create", Name = nameof(CreateAsync))]
public async Task<IActionResult> CreateAsync(Guid organizationId, [FromQuery] string code)
[HttpGet("integrations/slack/create", Name = nameof(CreateAsync))]
[AllowAnonymous]
public async Task<IActionResult> CreateAsync([FromQuery] string code, [FromQuery] string state)
{
if (!await currentContext.OrganizationOwner(organizationId))
var oAuthState = IntegrationOAuthState.FromString(state: state, timeProvider: timeProvider);
if (oAuthState is null)
{
throw new NotFoundException();
}
if (string.IsNullOrEmpty(code))
// Fetch existing Initiated record
var integration = await integrationRepository.GetByIdAsync(oAuthState.IntegrationId);
if (integration is null ||
integration.Type != IntegrationType.Slack ||
integration.Configuration is not null)
{
throw new BadRequestException("Missing code from Slack.");
throw new NotFoundException();
}
string callbackUrl = Url.RouteUrl(
nameof(CreateAsync),
new { organizationId },
currentContext.HttpContext.Request.Scheme);
// Verify Organization matches hash
if (!oAuthState.ValidateOrg(integration.OrganizationId))
{
throw new NotFoundException();
}
// Fetch token from Slack and store to DB
string? callbackUrl = Url.RouteUrl(
routeName: nameof(CreateAsync),
values: null,
protocol: currentContext.HttpContext.Request.Scheme,
host: currentContext.HttpContext.Request.Host.ToUriComponent()
);
if (string.IsNullOrEmpty(callbackUrl))
{
throw new BadRequestException("Unable to build callback Url");
}
var token = await slackService.ObtainTokenViaOAuth(code, callbackUrl);
if (string.IsNullOrEmpty(token))
@@ -70,14 +119,10 @@ public class SlackIntegrationController(
throw new BadRequestException("Invalid response from Slack.");
}
var integration = await integrationRepository.CreateAsync(new OrganizationIntegration
{
OrganizationId = organizationId,
Type = IntegrationType.Slack,
Configuration = JsonSerializer.Serialize(new SlackIntegration(token)),
});
var location = $"/organizations/{organizationId}/integrations/{integration.Id}";
integration.Configuration = JsonSerializer.Serialize(new SlackIntegration(token));
await integrationRepository.UpsertAsync(integration);
var location = $"/organizations/{integration.OrganizationId}/integrations/{integration.Id}";
return Created(location, new OrganizationIntegrationResponseModel(integration));
}
}

View File

@@ -3,8 +3,7 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Models.Request;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Utilities;
@@ -28,8 +27,9 @@ public class ProviderSetupRequestModel
[Required]
public string Key { get; set; }
[Required]
public ExpandedTaxInfoUpdateRequestModel TaxInfo { get; set; }
public TokenizedPaymentSourceRequestBody PaymentSource { get; set; }
public MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; }
[Required]
public BillingAddressRequest BillingAddress { get; set; }
public virtual Provider ToProvider(Provider provider)
{

View File

@@ -35,6 +35,7 @@ public class EventResponseModel : ResponseModel
SecretId = ev.SecretId;
ProjectId = ev.ProjectId;
ServiceAccountId = ev.ServiceAccountId;
GrantedServiceAccountId = ev.GrantedServiceAccountId;
}
public EventType Type { get; set; }
@@ -58,4 +59,5 @@ public class EventResponseModel : ResponseModel
public Guid? SecretId { get; set; }
public Guid? ProjectId { get; set; }
public Guid? ServiceAccountId { get; set; }
public Guid? GrantedServiceAccountId { get; set; }
}

View File

@@ -2,8 +2,6 @@
using Bit.Core.Enums;
using Bit.Core.Models.Api;
#nullable enable
namespace Bit.Api.AdminConsole.Models.Response.Organizations;
public class OrganizationIntegrationResponseModel : ResponseModel
@@ -21,4 +19,29 @@ public class OrganizationIntegrationResponseModel : ResponseModel
public Guid Id { get; set; }
public IntegrationType Type { get; set; }
public string? Configuration { get; set; }
public OrganizationIntegrationStatus Status => Type switch
{
// Not yet implemented, shouldn't be present, NotApplicable
IntegrationType.CloudBillingSync => OrganizationIntegrationStatus.NotApplicable,
IntegrationType.Scim => OrganizationIntegrationStatus.NotApplicable,
// Webhook is allowed to be null. If it's present, it's Completed
IntegrationType.Webhook => OrganizationIntegrationStatus.Completed,
// If present and the configuration is null, OAuth has been initiated, and we are
// waiting on the return call
IntegrationType.Slack => string.IsNullOrWhiteSpace(Configuration)
? OrganizationIntegrationStatus.Initiated
: OrganizationIntegrationStatus.Completed,
// HEC and Datadog should only be allowed to be created non-null.
// If they are null, they are Invalid
IntegrationType.Hec => string.IsNullOrWhiteSpace(Configuration)
? OrganizationIntegrationStatus.Invalid
: OrganizationIntegrationStatus.Completed,
IntegrationType.Datadog => string.IsNullOrWhiteSpace(Configuration)
? OrganizationIntegrationStatus.Invalid
: OrganizationIntegrationStatus.Completed,
};
}

View File

@@ -121,7 +121,7 @@ public class SsoConfigurationDataRequest : IValidatableObject
new[] { nameof(IdpEntityId) });
}
if (!Uri.IsWellFormedUriString(IdpEntityId, UriKind.Absolute) && string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl))
if (string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl))
{
yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlValidationError"),
new[] { nameof(IdpSingleSignOnServiceUrl) });
@@ -139,6 +139,7 @@ public class SsoConfigurationDataRequest : IValidatableObject
new[] { nameof(IdpSingleLogoutServiceUrl) });
}
// TODO: On server, make public certificate required for SAML2 SSO: https://bitwarden.atlassian.net/browse/PM-26028
if (!string.IsNullOrWhiteSpace(IdpX509PublicCert))
{
// Validate the certificate is in a valid format

View File

@@ -1,16 +1,8 @@
#nullable enable
using System.Diagnostics;
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Providers.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Context;
using Bit.Core.Repositories;
using Bit.Core.Services;
@@ -28,10 +20,8 @@ public class OrganizationBillingController(
IOrganizationBillingService organizationBillingService,
IOrganizationRepository organizationRepository,
IPaymentService paymentService,
IPricingClient pricingClient,
ISubscriberService subscriberService,
IPaymentHistoryService paymentHistoryService,
IUserService userService) : BaseBillingController
IPaymentHistoryService paymentHistoryService) : BaseBillingController
{
[HttpGet("metadata")]
public async Task<IResult> GetMetadataAsync([FromRoute] Guid organizationId)
@@ -264,71 +254,6 @@ public class OrganizationBillingController(
return TypedResults.Ok();
}
[HttpPost("restart-subscription")]
public async Task<IResult> RestartSubscriptionAsync([FromRoute] Guid organizationId,
[FromBody] OrganizationCreateRequestModel model)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
if (!await currentContext.EditPaymentMethods(organizationId))
{
return Error.Unauthorized();
}
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
return Error.NotFound();
}
var existingPlan = organization.PlanType;
var organizationSignup = model.ToOrganizationSignup(user);
var sale = OrganizationSale.From(organization, organizationSignup);
var plan = await pricingClient.GetPlanOrThrow(model.PlanType);
sale.Organization.PlanType = plan.Type;
sale.Organization.Plan = plan.Name;
sale.SubscriptionSetup.SkipTrial = true;
if (existingPlan == PlanType.Free && organization.GatewaySubscriptionId is not null)
{
sale.Organization.UseTotp = plan.HasTotp;
sale.Organization.UseGroups = plan.HasGroups;
sale.Organization.UseDirectory = plan.HasDirectory;
sale.Organization.SelfHost = plan.HasSelfHost;
sale.Organization.UsersGetPremium = plan.UsersGetPremium;
sale.Organization.UseEvents = plan.HasEvents;
sale.Organization.Use2fa = plan.Has2fa;
sale.Organization.UseApi = plan.HasApi;
sale.Organization.UsePolicies = plan.HasPolicies;
sale.Organization.UseSso = plan.HasSso;
sale.Organization.UseResetPassword = plan.HasResetPassword;
sale.Organization.UseKeyConnector = plan.HasKeyConnector ? organization.UseKeyConnector : false;
sale.Organization.UseScim = plan.HasScim;
sale.Organization.UseCustomPermissions = plan.HasCustomPermissions;
sale.Organization.UseOrganizationDomains = plan.HasOrganizationDomains;
sale.Organization.MaxCollections = plan.PasswordManager.MaxCollections;
}
if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken))
{
return Error.BadRequest("A payment method is required to restart the subscription.");
}
var org = await organizationRepository.GetByIdAsync(organizationId);
Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine.");
var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken);
var taxInformation = TaxInformation.From(organizationSignup.TaxInfo);
await organizationBillingService.Finalize(sale);
var updatedOrg = await organizationRepository.GetByIdAsync(organizationId);
if (updatedOrg != null)
{
await organizationBillingService.UpdatePaymentMethod(updatedOrg, paymentSource, taxInformation);
}
return TypedResults.Ok();
}
[HttpPost("setup-business-unit")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> SetupBusinessUnitAsync(

View File

@@ -1,33 +1,73 @@
using Bit.Api.Billing.Models.Requests;
using Bit.Core.Billing.Tax.Commands;
using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Models.Requests.Tax;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Organizations.Commands;
using Bit.Core.Billing.Premium.Commands;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ModelBinding;
namespace Bit.Api.Billing.Controllers;
[Authorize("Application")]
[Route("tax")]
[Route("billing/tax")]
public class TaxController(
IPreviewTaxAmountCommand previewTaxAmountCommand) : BaseBillingController
IPreviewOrganizationTaxCommand previewOrganizationTaxCommand,
IPreviewPremiumTaxCommand previewPremiumTaxCommand) : BaseBillingController
{
[HttpPost("preview-amount/organization-trial")]
public async Task<IResult> PreviewTaxAmountForOrganizationTrialAsync(
[FromBody] PreviewTaxAmountForOrganizationTrialRequestBody requestBody)
[HttpPost("organizations/subscriptions/purchase")]
public async Task<IResult> PreviewOrganizationSubscriptionPurchaseTaxAsync(
[FromBody] PreviewOrganizationSubscriptionPurchaseTaxRequest request)
{
var parameters = new OrganizationTrialParameters
var (purchase, billingAddress) = request.ToDomain();
var result = await previewOrganizationTaxCommand.Run(purchase, billingAddress);
return Handle(result.Map(pair => new
{
PlanType = requestBody.PlanType,
ProductType = requestBody.ProductType,
TaxInformation = new OrganizationTrialParameters.TaxInformationDTO
{
Country = requestBody.TaxInformation.Country,
PostalCode = requestBody.TaxInformation.PostalCode,
TaxId = requestBody.TaxInformation.TaxId
}
};
pair.Tax,
pair.Total
}));
}
var result = await previewTaxAmountCommand.Run(parameters);
[HttpPost("organizations/{organizationId:guid}/subscription/plan-change")]
[InjectOrganization]
public async Task<IResult> PreviewOrganizationSubscriptionPlanChangeTaxAsync(
[BindNever] Organization organization,
[FromBody] PreviewOrganizationSubscriptionPlanChangeTaxRequest request)
{
var (planChange, billingAddress) = request.ToDomain();
var result = await previewOrganizationTaxCommand.Run(organization, planChange, billingAddress);
return Handle(result.Map(pair => new
{
pair.Tax,
pair.Total
}));
}
return Handle(result);
[HttpPut("organizations/{organizationId:guid}/subscription/update")]
[InjectOrganization]
public async Task<IResult> PreviewOrganizationSubscriptionUpdateTaxAsync(
[BindNever] Organization organization,
[FromBody] PreviewOrganizationSubscriptionUpdateTaxRequest request)
{
var update = request.ToDomain();
var result = await previewOrganizationTaxCommand.Run(organization, update);
return Handle(result.Map(pair => new
{
pair.Tax,
pair.Total
}));
}
[HttpPost("premium/subscriptions/purchase")]
public async Task<IResult> PreviewPremiumSubscriptionPurchaseTaxAsync(
[FromBody] PreviewPremiumSubscriptionPurchaseTaxRequest request)
{
var (purchase, billingAddress) = request.ToDomain();
var result = await previewPremiumTaxCommand.Run(purchase, billingAddress);
return Handle(result.Map(pair => new
{
pair.Tax,
pair.Total
}));
}
}

View File

@@ -1,5 +1,4 @@
#nullable enable
using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Billing.Models.Requests.Premium;
using Bit.Core;
@@ -67,7 +66,7 @@ public class AccountBillingVNextController(
}
[HttpPost("subscription")]
[RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)]
[RequireFeature(FeatureFlagKeys.PM24996ImplementUpgradeFromFreeDialog)]
[InjectUser]
public async Task<IResult> CreateSubscriptionAsync(
[BindNever] User user,

View File

@@ -2,11 +2,14 @@
using Bit.Api.AdminConsole.Authorization.Requirements;
using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Billing.Models.Requests.Subscriptions;
using Bit.Api.Billing.Models.Requirements;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Organizations.Queries;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Subscriptions.Commands;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
@@ -24,6 +27,7 @@ public class OrganizationBillingVNextController(
IGetCreditQuery getCreditQuery,
IGetOrganizationWarningsQuery getOrganizationWarningsQuery,
IGetPaymentMethodQuery getPaymentMethodQuery,
IRestartSubscriptionCommand restartSubscriptionCommand,
IUpdateBillingAddressCommand updateBillingAddressCommand,
IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController
{
@@ -95,6 +99,20 @@ public class OrganizationBillingVNextController(
return Handle(result);
}
[Authorize<ManageOrganizationBillingRequirement>]
[HttpPost("subscription/restart")]
[InjectOrganization]
public async Task<IResult> RestartSubscriptionAsync(
[BindNever] Organization organization,
[FromBody] RestartSubscriptionRequest request)
{
var (paymentMethod, billingAddress) = request.ToDomain();
var result = await updatePaymentMethodCommand.Run(organization, paymentMethod, null)
.AndThenAsync(_ => updateBillingAddressCommand.Run(organization, billingAddress))
.AndThenAsync(_ => restartSubscriptionCommand.Run(organization));
return Handle(result);
}
[Authorize<MemberOrProviderRequirement>]
[HttpGet("warnings")]
[InjectOrganization]

View File

@@ -21,7 +21,7 @@ public class SelfHostedAccountBillingController(
ICreatePremiumSelfHostedSubscriptionCommand createPremiumSelfHostedSubscriptionCommand) : BaseBillingController
{
[HttpPost("license")]
[RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)]
[RequireFeature(FeatureFlagKeys.PM24996ImplementUpgradeFromFreeDialog)]
[InjectUser]
public async Task<IResult> UploadLicenseAsync(
[BindNever] User user,

View File

@@ -0,0 +1,31 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Organizations;
public record OrganizationSubscriptionPlanChangeRequest : IValidatableObject
{
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public ProductTierType Tier { get; set; }
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public PlanCadenceType Cadence { get; set; }
public OrganizationSubscriptionPlanChange ToDomain() => new()
{
Tier = Tier,
Cadence = Cadence
};
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (Tier == ProductTierType.Families && Cadence == PlanCadenceType.Monthly)
{
yield return new ValidationResult("Monthly billing cadence is not available for the Families plan.");
}
}
}

View File

@@ -0,0 +1,84 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Organizations;
public record OrganizationSubscriptionPurchaseRequest : IValidatableObject
{
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public ProductTierType Tier { get; set; }
[Required]
[JsonConverter(typeof(JsonStringEnumConverter))]
public PlanCadenceType Cadence { get; set; }
[Required]
public required PasswordManagerPurchaseSelections PasswordManager { get; set; }
public SecretsManagerPurchaseSelections? SecretsManager { get; set; }
public OrganizationSubscriptionPurchase ToDomain() => new()
{
Tier = Tier,
Cadence = Cadence,
PasswordManager = new OrganizationSubscriptionPurchase.PasswordManagerSelections
{
Seats = PasswordManager.Seats,
AdditionalStorage = PasswordManager.AdditionalStorage,
Sponsored = PasswordManager.Sponsored
},
SecretsManager = SecretsManager != null ? new OrganizationSubscriptionPurchase.SecretsManagerSelections
{
Seats = SecretsManager.Seats,
AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts,
Standalone = SecretsManager.Standalone
} : null
};
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (Tier != ProductTierType.Families)
{
yield break;
}
if (Cadence == PlanCadenceType.Monthly)
{
yield return new ValidationResult("Monthly cadence is not available on the Families plan.");
}
if (SecretsManager != null)
{
yield return new ValidationResult("Secrets Manager is not available on the Families plan.");
}
}
public record PasswordManagerPurchaseSelections
{
[Required]
[Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")]
public int Seats { get; set; }
[Required]
[Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")]
public int AdditionalStorage { get; set; }
public bool Sponsored { get; set; } = false;
}
public record SecretsManagerPurchaseSelections
{
[Required]
[Range(1, 100000, ErrorMessage = "Secrets Manager seats must be between 1 and 100,000")]
public int Seats { get; set; }
[Required]
[Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")]
public int AdditionalServiceAccounts { get; set; }
public bool Standalone { get; set; } = false;
}
}

View File

@@ -0,0 +1,48 @@
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Organizations;
public record OrganizationSubscriptionUpdateRequest
{
public PasswordManagerUpdateSelections? PasswordManager { get; set; }
public SecretsManagerUpdateSelections? SecretsManager { get; set; }
public OrganizationSubscriptionUpdate ToDomain() => new()
{
PasswordManager =
PasswordManager != null
? new OrganizationSubscriptionUpdate.PasswordManagerSelections
{
Seats = PasswordManager.Seats,
AdditionalStorage = PasswordManager.AdditionalStorage
}
: null,
SecretsManager =
SecretsManager != null
? new OrganizationSubscriptionUpdate.SecretsManagerSelections
{
Seats = SecretsManager.Seats,
AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts
}
: null
};
public record PasswordManagerUpdateSelections
{
[Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")]
public int? Seats { get; set; }
[Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")]
public int? AdditionalStorage { get; set; }
}
public record SecretsManagerUpdateSelections
{
[Range(0, 100000, ErrorMessage = "Secrets Manager seats must be between 0 and 100,000")]
public int? Seats { get; set; }
[Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")]
public int? AdditionalServiceAccounts { get; set; }
}
}

View File

@@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment;

View File

@@ -1,5 +1,4 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Attributes;
using Bit.Core.Billing.Payment.Models;
@@ -14,12 +13,9 @@ public class MinimalTokenizedPaymentMethodRequest
[Required]
public required string Token { get; set; }
public TokenizedPaymentMethod ToDomain()
public TokenizedPaymentMethod ToDomain() => new()
{
return new TokenizedPaymentMethod
{
Type = TokenizablePaymentMethodTypeExtensions.From(Type),
Token = Token
};
}
Type = TokenizablePaymentMethodTypeExtensions.From(Type),
Token = Token
};
}

View File

@@ -1,31 +1,15 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Attributes;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment;
public class TokenizedPaymentMethodRequest
public class TokenizedPaymentMethodRequest : MinimalTokenizedPaymentMethodRequest
{
[Required]
[PaymentMethodTypeValidation]
public required string Type { get; set; }
[Required]
public required string Token { get; set; }
public MinimalBillingAddressRequest? BillingAddress { get; set; }
public (TokenizedPaymentMethod, BillingAddress?) ToDomain()
public new (TokenizedPaymentMethod, BillingAddress?) ToDomain()
{
var paymentMethod = new TokenizedPaymentMethod
{
Type = TokenizablePaymentMethodTypeExtensions.From(Type),
Token = Token
};
var paymentMethod = base.ToDomain();
var billingAddress = BillingAddress?.ToDomain();
return (paymentMethod, billingAddress);
}
}

View File

@@ -1,5 +1,4 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Payment.Models;

View File

@@ -1,27 +0,0 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Enums;
namespace Bit.Api.Billing.Models.Requests;
public class PreviewTaxAmountForOrganizationTrialRequestBody
{
[Required]
public PlanType PlanType { get; set; }
[Required]
public ProductType ProductType { get; set; }
[Required] public TaxInformationDTO TaxInformation { get; set; } = null!;
public class TaxInformationDTO
{
[Required]
public string Country { get; set; } = null!;
[Required]
public string PostalCode { get; set; } = null!;
public string? TaxId { get; set; }
}
}

View File

@@ -0,0 +1,16 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Subscriptions;
public class RestartSubscriptionRequest
{
[Required]
public required MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; }
[Required]
public required CheckoutBillingAddressRequest BillingAddress { get; set; }
public (TokenizedPaymentMethod, BillingAddress) ToDomain()
=> (PaymentMethod.ToDomain(), BillingAddress.ToDomain());
}

View File

@@ -0,0 +1,19 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Organizations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public record PreviewOrganizationSubscriptionPlanChangeTaxRequest
{
[Required]
public required OrganizationSubscriptionPlanChangeRequest Plan { get; set; }
[Required]
public required CheckoutBillingAddressRequest BillingAddress { get; set; }
public (OrganizationSubscriptionPlanChange, BillingAddress) ToDomain() =>
(Plan.ToDomain(), BillingAddress.ToDomain());
}

View File

@@ -0,0 +1,19 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Organizations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public record PreviewOrganizationSubscriptionPurchaseTaxRequest
{
[Required]
public required OrganizationSubscriptionPurchaseRequest Purchase { get; set; }
[Required]
public required CheckoutBillingAddressRequest BillingAddress { get; set; }
public (OrganizationSubscriptionPurchase, BillingAddress) ToDomain() =>
(Purchase.ToDomain(), BillingAddress.ToDomain());
}

View File

@@ -0,0 +1,11 @@
using Bit.Api.Billing.Models.Requests.Organizations;
using Bit.Core.Billing.Organizations.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public class PreviewOrganizationSubscriptionUpdateTaxRequest
{
public required OrganizationSubscriptionUpdateRequest Update { get; set; }
public OrganizationSubscriptionUpdate ToDomain() => Update.ToDomain();
}

View File

@@ -0,0 +1,17 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Tax;
public record PreviewPremiumSubscriptionPurchaseTaxRequest
{
[Required]
[Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB.")]
public short AdditionalStorage { get; set; }
[Required]
public required MinimalBillingAddressRequest BillingAddress { get; set; }
public (short, BillingAddress) ToDomain() => (AdditionalStorage, BillingAddress.ToDomain());
}

View File

@@ -29,6 +29,7 @@ public class AccessPoliciesController : Controller
private readonly IServiceAccountRepository _serviceAccountRepository;
private readonly IUpdateServiceAccountGrantedPoliciesCommand _updateServiceAccountGrantedPoliciesCommand;
private readonly IUserService _userService;
private readonly IEventService _eventService;
private readonly IProjectServiceAccountsAccessPoliciesUpdatesQuery
_projectServiceAccountsAccessPoliciesUpdatesQuery;
private readonly IUpdateProjectServiceAccountsAccessPoliciesCommand
@@ -47,7 +48,8 @@ public class AccessPoliciesController : Controller
IServiceAccountGrantedPolicyUpdatesQuery serviceAccountGrantedPolicyUpdatesQuery,
IProjectServiceAccountsAccessPoliciesUpdatesQuery projectServiceAccountsAccessPoliciesUpdatesQuery,
IUpdateServiceAccountGrantedPoliciesCommand updateServiceAccountGrantedPoliciesCommand,
IUpdateProjectServiceAccountsAccessPoliciesCommand updateProjectServiceAccountsAccessPoliciesCommand)
IUpdateProjectServiceAccountsAccessPoliciesCommand updateProjectServiceAccountsAccessPoliciesCommand,
IEventService eventService)
{
_authorizationService = authorizationService;
_userService = userService;
@@ -61,6 +63,7 @@ public class AccessPoliciesController : Controller
_serviceAccountGrantedPolicyUpdatesQuery = serviceAccountGrantedPolicyUpdatesQuery;
_projectServiceAccountsAccessPoliciesUpdatesQuery = projectServiceAccountsAccessPoliciesUpdatesQuery;
_updateProjectServiceAccountsAccessPoliciesCommand = updateProjectServiceAccountsAccessPoliciesCommand;
_eventService = eventService;
}
[HttpGet("/organizations/{id}/access-policies/people/potential-grantees")]
@@ -186,7 +189,9 @@ public class AccessPoliciesController : Controller
}
var userId = _userService.GetProperUserId(User)!.Value;
var currentPolicies = await _accessPolicyRepository.GetPeoplePoliciesByGrantedServiceAccountIdAsync(peopleAccessPolicies.Id, userId);
var results = await _accessPolicyRepository.ReplaceServiceAccountPeopleAsync(peopleAccessPolicies, userId);
await LogAccessPolicyServiceAccountChanges(currentPolicies, results, userId);
return new ServiceAccountPeopleAccessPoliciesResponseModel(results, userId);
}
@@ -336,4 +341,39 @@ public class AccessPoliciesController : Controller
userId, accessClient);
return new ServiceAccountGrantedPoliciesPermissionDetailsResponseModel(results);
}
public async Task LogAccessPolicyServiceAccountChanges(IEnumerable<BaseAccessPolicy> currentPolicies, IEnumerable<BaseAccessPolicy> updatedPolicies, Guid userId)
{
foreach (var current in currentPolicies.OfType<GroupServiceAccountAccessPolicy>())
{
if (!updatedPolicies.Any(r => r.Id == current.Id))
{
await _eventService.LogServiceAccountGroupEventAsync(userId, current, EventType.ServiceAccount_GroupRemoved, _currentContext.IdentityClientType);
}
}
foreach (var policy in updatedPolicies.OfType<GroupServiceAccountAccessPolicy>())
{
if (!currentPolicies.Any(e => e.Id == policy.Id))
{
await _eventService.LogServiceAccountGroupEventAsync(userId, policy, EventType.ServiceAccount_GroupAdded, _currentContext.IdentityClientType);
}
}
foreach (var current in currentPolicies.OfType<UserServiceAccountAccessPolicy>())
{
if (!updatedPolicies.Any(r => r.Id == current.Id))
{
await _eventService.LogServiceAccountPeopleEventAsync(userId, current, EventType.ServiceAccount_UserRemoved, _currentContext.IdentityClientType);
}
}
foreach (var policy in updatedPolicies.OfType<UserServiceAccountAccessPolicy>())
{
if (!currentPolicies.Any(e => e.Id == policy.Id))
{
await _eventService.LogServiceAccountPeopleEventAsync(userId, policy, EventType.ServiceAccount_UserAdded, _currentContext.IdentityClientType);
}
}
}
}

View File

@@ -42,6 +42,8 @@ public class ServiceAccountsController : Controller
private readonly IDeleteServiceAccountsCommand _deleteServiceAccountsCommand;
private readonly IRevokeAccessTokensCommand _revokeAccessTokensCommand;
private readonly IPricingClient _pricingClient;
private readonly IEventService _eventService;
private readonly IOrganizationUserRepository _organizationUserRepository;
public ServiceAccountsController(
ICurrentContext currentContext,
@@ -58,7 +60,9 @@ public class ServiceAccountsController : Controller
IUpdateServiceAccountCommand updateServiceAccountCommand,
IDeleteServiceAccountsCommand deleteServiceAccountsCommand,
IRevokeAccessTokensCommand revokeAccessTokensCommand,
IPricingClient pricingClient)
IPricingClient pricingClient,
IEventService eventService,
IOrganizationUserRepository organizationUserRepository)
{
_currentContext = currentContext;
_userService = userService;
@@ -75,6 +79,8 @@ public class ServiceAccountsController : Controller
_pricingClient = pricingClient;
_createAccessTokenCommand = createAccessTokenCommand;
_updateSecretsManagerSubscriptionCommand = updateSecretsManagerSubscriptionCommand;
_eventService = eventService;
_organizationUserRepository = organizationUserRepository;
}
[HttpGet("/organizations/{organizationId}/service-accounts")]
@@ -139,8 +145,15 @@ public class ServiceAccountsController : Controller
}
var userId = _userService.GetProperUserId(User).Value;
var result =
await _createServiceAccountCommand.CreateAsync(createRequest.ToServiceAccount(organizationId), userId);
await _createServiceAccountCommand.CreateAsync(serviceAccount, userId);
if (result != null)
{
await _eventService.LogServiceAccountEventAsync(userId, [serviceAccount], EventType.ServiceAccount_Created, _currentContext.IdentityClientType);
}
return new ServiceAccountResponseModel(result);
}
@@ -197,6 +210,9 @@ public class ServiceAccountsController : Controller
}
await _deleteServiceAccountsCommand.DeleteServiceAccounts(serviceAccountsToDelete);
var userId = _userService.GetProperUserId(User)!.Value;
await _eventService.LogServiceAccountEventAsync(userId, serviceAccountsToDelete, EventType.ServiceAccount_Deleted, _currentContext.IdentityClientType);
var responses = results.Select(r => new BulkDeleteResponseModel(r.ServiceAccount.Id, r.Error));
return new ListResponseModel<BulkDeleteResponseModel>(responses);
}

View File

@@ -34,6 +34,7 @@ public class Event : ITableObject<Guid>, IEvent
SecretId = e.SecretId;
ProjectId = e.ProjectId;
ServiceAccountId = e.ServiceAccountId;
GrantedServiceAccountId = e.GrantedServiceAccountId;
}
public Guid Id { get; set; }
@@ -59,7 +60,7 @@ public class Event : ITableObject<Guid>, IEvent
public Guid? SecretId { get; set; }
public Guid? ProjectId { get; set; }
public Guid? ServiceAccountId { get; set; }
public Guid? GrantedServiceAccountId { get; set; }
public void SetNewId()
{
Id = CoreHelpers.GenerateComb();

View File

@@ -70,8 +70,8 @@ public enum EventType : int
Organization_EnabledKeyConnector = 1606,
Organization_DisabledKeyConnector = 1607,
Organization_SponsorshipsSynced = 1608,
[Obsolete("Use other specific Organization_CollectionManagement events instead")]
Organization_CollectionManagement_Updated = 1609, // TODO: Will be removed in PM-25315
[Obsolete("Kept for historical data. Use specific Organization_CollectionManagement events instead.")]
Organization_CollectionManagement_Updated = 1609,
Organization_CollectionManagement_LimitCollectionCreationEnabled = 1610,
Organization_CollectionManagement_LimitCollectionCreationDisabled = 1611,
Organization_CollectionManagement_LimitCollectionDeletionEnabled = 1612,
@@ -109,4 +109,11 @@ public enum EventType : int
Project_Created = 2201,
Project_Edited = 2202,
Project_Deleted = 2203,
ServiceAccount_UserAdded = 2300,
ServiceAccount_UserRemoved = 2301,
ServiceAccount_GroupAdded = 2302,
ServiceAccount_GroupRemoved = 2303,
ServiceAccount_Created = 2304,
ServiceAccount_Deleted = 2305,
}

View File

@@ -0,0 +1,10 @@
namespace Bit.Api.AdminConsole.Models.Response.Organizations;
public enum OrganizationIntegrationStatus : int
{
NotApplicable,
Invalid,
Initiated,
InProgress,
Completed
}

View File

@@ -20,6 +20,7 @@ public enum PolicyType : byte
RestrictedItemTypesPolicy = 15,
UriMatchDefaults = 16,
AutotypeDefaultSetting = 17,
AutomaticUserConfirmation = 18,
}
public static class PolicyTypeExtensions
@@ -50,6 +51,7 @@ public static class PolicyTypeExtensions
PolicyType.RestrictedItemTypesPolicy => "Restricted item types",
PolicyType.UriMatchDefaults => "URI match defaults",
PolicyType.AutotypeDefaultSetting => "Autotype default setting",
PolicyType.AutomaticUserConfirmation => "Automatically confirm invited users",
};
}
}

View File

@@ -5,4 +5,6 @@ public interface IEventListenerConfiguration
public string EventQueueName { get; }
public string EventSubscriptionName { get; }
public string EventTopicName { get; }
public int EventPrefetchCount { get; }
public int EventMaxConcurrentCalls { get; }
}

View File

@@ -10,6 +10,8 @@ public interface IIntegrationListenerConfiguration : IEventListenerConfiguration
public string IntegrationSubscriptionName { get; }
public string IntegrationTopicName { get; }
public int MaxRetries { get; }
public int IntegrationPrefetchCount { get; }
public int IntegrationMaxConcurrentCalls { get; }
public string RoutingKey
{

View File

@@ -0,0 +1,71 @@
using System.Security.Cryptography;
using System.Text;
using Bit.Core.AdminConsole.Entities;
namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations;
public class IntegrationOAuthState
{
private const int _orgHashLength = 12;
private static readonly TimeSpan _maxAge = TimeSpan.FromMinutes(20);
public Guid IntegrationId { get; }
private DateTimeOffset Issued { get; }
private string OrganizationIdHash { get; }
private IntegrationOAuthState(Guid integrationId, string organizationIdHash, DateTimeOffset issued)
{
IntegrationId = integrationId;
OrganizationIdHash = organizationIdHash;
Issued = issued;
}
public static IntegrationOAuthState FromIntegration(OrganizationIntegration integration, TimeProvider timeProvider)
{
var integrationId = integration.Id;
var issuedUtc = timeProvider.GetUtcNow();
var organizationIdHash = ComputeOrgHash(integration.OrganizationId, issuedUtc.ToUnixTimeSeconds());
return new IntegrationOAuthState(integrationId, organizationIdHash, issuedUtc);
}
public static IntegrationOAuthState? FromString(string state, TimeProvider timeProvider)
{
if (string.IsNullOrWhiteSpace(state)) return null;
var parts = state.Split('.');
if (parts.Length != 3) return null;
// Verify timestamp
if (!long.TryParse(parts[2], out var unixSeconds)) return null;
var issuedUtc = DateTimeOffset.FromUnixTimeSeconds(unixSeconds);
var now = timeProvider.GetUtcNow();
var age = now - issuedUtc;
if (age > _maxAge) return null;
// Parse integration id and store org
if (!Guid.TryParse(parts[0], out var integrationId)) return null;
var organizationIdHash = parts[1];
return new IntegrationOAuthState(integrationId, organizationIdHash, issuedUtc);
}
public bool ValidateOrg(Guid orgId)
{
var expected = ComputeOrgHash(orgId, Issued.ToUnixTimeSeconds());
return expected == OrganizationIdHash;
}
public override string ToString()
{
return $"{IntegrationId}.{OrganizationIdHash}.{Issued.ToUnixTimeSeconds()}";
}
private static string ComputeOrgHash(Guid orgId, long timestamp)
{
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes($"{orgId:N}:{timestamp}"));
return Convert.ToHexString(bytes)[.._orgHashLength];
}
}

View File

@@ -25,4 +25,24 @@ public abstract class ListenerConfiguration
{
get => _globalSettings.EventLogging.AzureServiceBus.IntegrationTopicName;
}
public int EventPrefetchCount
{
get => _globalSettings.EventLogging.AzureServiceBus.DefaultPrefetchCount;
}
public int EventMaxConcurrentCalls
{
get => _globalSettings.EventLogging.AzureServiceBus.DefaultMaxConcurrentCalls;
}
public int IntegrationPrefetchCount
{
get => _globalSettings.EventLogging.AzureServiceBus.DefaultPrefetchCount;
}
public int IntegrationMaxConcurrentCalls
{
get => _globalSettings.EventLogging.AzureServiceBus.DefaultMaxConcurrentCalls;
}
}

View File

@@ -39,4 +39,5 @@ public class EventMessage : IEvent
public Guid? SecretId { get; set; }
public Guid? ProjectId { get; set; }
public Guid? ServiceAccountId { get; set; }
public Guid? GrantedServiceAccountId { get; set; }
}

View File

@@ -37,6 +37,7 @@ public class AzureEvent : ITableEntity
public Guid? SecretId { get; set; }
public Guid? ProjectId { get; set; }
public Guid? ServiceAccountId { get; set; }
public Guid? GrantedServiceAccountId { get; set; }
public EventTableEntity ToEventTableEntity()
{
@@ -68,6 +69,7 @@ public class AzureEvent : ITableEntity
SecretId = SecretId,
ServiceAccountId = ServiceAccountId,
ProjectId = ProjectId,
GrantedServiceAccountId = GrantedServiceAccountId
};
}
}
@@ -99,6 +101,7 @@ public class EventTableEntity : IEvent
SecretId = e.SecretId;
ProjectId = e.ProjectId;
ServiceAccountId = e.ServiceAccountId;
GrantedServiceAccountId = e.GrantedServiceAccountId;
}
public string PartitionKey { get; set; }
@@ -127,6 +130,7 @@ public class EventTableEntity : IEvent
public Guid? SecretId { get; set; }
public Guid? ProjectId { get; set; }
public Guid? ServiceAccountId { get; set; }
public Guid? GrantedServiceAccountId { get; set; }
public AzureEvent ToAzureEvent()
{
@@ -157,7 +161,8 @@ public class EventTableEntity : IEvent
DomainName = DomainName,
SecretId = SecretId,
ProjectId = ProjectId,
ServiceAccountId = ServiceAccountId
ServiceAccountId = ServiceAccountId,
GrantedServiceAccountId = GrantedServiceAccountId
};
}
@@ -232,6 +237,15 @@ public class EventTableEntity : IEvent
});
}
if (e.GrantedServiceAccountId.HasValue)
{
entities.Add(new EventTableEntity(e)
{
PartitionKey = pKey,
RowKey = $"GrantedServiceAccountId={e.GrantedServiceAccountId}__Date={dateKey}__Uniquifier={uniquifier}"
});
}
return entities;
}

View File

@@ -28,4 +28,5 @@ public interface IEvent
Guid? SecretId { get; set; }
Guid? ProjectId { get; set; }
Guid? ServiceAccountId { get; set; }
Guid? GrantedServiceAccountId { get; set; }
}

View File

@@ -27,6 +27,7 @@ public interface IEventRepository
DateTime startDate, DateTime endDate, PageOptions pageOptions);
Task<PagedResult<IEvent>> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate,
PageOptions pageOptions);
Task CreateAsync(IEvent e);
Task CreateManyAsync(IEnumerable<IEvent> e);
Task<PagedResult<IEvent>> GetManyByOrganizationServiceAccountAsync(Guid organizationId, Guid serviceAccountId,

View File

@@ -77,12 +77,18 @@ public class EventRepository : IEventRepository
return await GetManyAsync(partitionKey, $"CipherId={cipher.Id}__Date={{0}}", startDate, endDate, pageOptions);
}
public async Task<PagedResult<IEvent>> GetManyByOrganizationServiceAccountAsync(Guid organizationId,
Guid serviceAccountId, DateTime startDate, DateTime endDate, PageOptions pageOptions)
public async Task<PagedResult<IEvent>> GetManyByOrganizationServiceAccountAsync(
Guid organizationId,
Guid serviceAccountId,
DateTime startDate,
DateTime endDate,
PageOptions pageOptions)
{
return await GetManyServiceAccountAsync(
$"OrganizationId={organizationId}",
serviceAccountId.ToString(),
startDate, endDate, pageOptions);
return await GetManyAsync($"OrganizationId={organizationId}",
$"ServiceAccountId={serviceAccountId}__Date={{0}}", startDate, endDate, pageOptions);
}
public async Task CreateAsync(IEvent e)
@@ -141,6 +147,40 @@ public class EventRepository : IEventRepository
}
}
public async Task<PagedResult<IEvent>> GetManyServiceAccountAsync(
string partitionKey,
string serviceAccountId,
DateTime startDate,
DateTime endDate,
PageOptions pageOptions)
{
var start = CoreHelpers.DateTimeToTableStorageKey(startDate);
var end = CoreHelpers.DateTimeToTableStorageKey(endDate);
var filter = MakeFilterForServiceAccount(partitionKey, serviceAccountId, startDate, endDate);
var result = new PagedResult<IEvent>();
var query = _tableClient.QueryAsync<AzureEvent>(filter, pageOptions.PageSize);
await using (var enumerator = query.AsPages(pageOptions.ContinuationToken,
pageOptions.PageSize).GetAsyncEnumerator())
{
if (await enumerator.MoveNextAsync())
{
result.ContinuationToken = enumerator.Current.ContinuationToken;
var events = enumerator.Current.Values
.Select(e => e.ToEventTableEntity())
.ToList();
events = events.OrderByDescending(e => e.Date).ToList();
result.Data.AddRange(events);
}
}
return result;
}
public async Task<PagedResult<IEvent>> GetManyAsync(string partitionKey, string rowKey,
DateTime startDate, DateTime endDate, PageOptions pageOptions)
{
@@ -172,4 +212,27 @@ public class EventRepository : IEventRepository
{
return $"PartitionKey eq '{partitionKey}' and RowKey le '{rowStart}' and RowKey ge '{rowEnd}'";
}
private string MakeFilterForServiceAccount(
string partitionKey,
string machineAccountId,
DateTime startDate,
DateTime endDate)
{
var start = CoreHelpers.DateTimeToTableStorageKey(startDate);
var end = CoreHelpers.DateTimeToTableStorageKey(endDate);
var rowKey1Start = $"ServiceAccountId={machineAccountId}__Date={start}";
var rowKey1End = $"ServiceAccountId={machineAccountId}__Date={end}";
var rowKey2Start = $"GrantedServiceAccountId={machineAccountId}__Date={start}";
var rowKey2End = $"GrantedServiceAccountId={machineAccountId}__Date={end}";
var left = $"PartitionKey eq '{partitionKey}' and RowKey le '{rowKey1Start}' and RowKey ge '{rowKey1End}'";
var right = $"PartitionKey eq '{partitionKey}' and RowKey le '{rowKey2Start}' and RowKey ge '{rowKey2End}'";
return $"({left}) or ({right})";
}
}

View File

@@ -4,6 +4,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Interfaces;
using Bit.Core.Auth.Identity;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.SecretsManager.Entities;
@@ -37,4 +38,7 @@ public interface IEventService
Task LogServiceAccountSecretsEventAsync(Guid serviceAccountId, IEnumerable<Secret> secrets, EventType type, DateTime? date = null);
Task LogUserProjectsEventAsync(Guid userId, IEnumerable<Project> projects, EventType type, DateTime? date = null);
Task LogServiceAccountProjectsEventAsync(Guid serviceAccountId, IEnumerable<Project> projects, EventType type, DateTime? date = null);
Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null);
Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null);
Task LogServiceAccountEventAsync(Guid userId, List<ServiceAccount> serviceAccount, EventType type, IdentityClientType identityClientType, DateTime? date = null);
}

View File

@@ -3,7 +3,7 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Entities;
using Bit.Core.Models.Business;
@@ -11,8 +11,7 @@ namespace Bit.Core.AdminConsole.Services;
public interface IProviderService
{
Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource = null);
Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress);
Task UpdateAsync(Provider provider, bool updateBilling = false);
Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite);

View File

@@ -5,7 +5,7 @@ public interface ISlackService
Task<string> GetChannelIdAsync(string token, string channelName);
Task<List<string>> GetChannelIdsAsync(string token, List<string> channelNames);
Task<string> GetDmChannelByEmailAsync(string token, string email);
string GetRedirectUrl(string redirectUrl);
string GetRedirectUrl(string callbackUrl, string state);
Task<string> ObtainTokenViaOAuth(string code, string redirectUrl);
Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId);
}

View File

@@ -14,13 +14,14 @@ public class AzureServiceBusEventListenerService<TConfiguration> : EventLoggingL
TConfiguration configuration,
IEventMessageHandler handler,
IAzureServiceBusService serviceBusService,
ServiceBusProcessorOptions serviceBusOptions,
ILoggerFactory loggerFactory)
: base(handler, CreateLogger(loggerFactory, configuration))
{
_processor = serviceBusService.CreateProcessor(
topicName: configuration.EventTopicName,
subscriptionName: configuration.EventSubscriptionName,
new ServiceBusProcessorOptions());
options: serviceBusOptions);
}
protected override async Task ExecuteAsync(CancellationToken cancellationToken)

View File

@@ -18,6 +18,7 @@ public class AzureServiceBusIntegrationListenerService<TConfiguration> : Backgro
TConfiguration configuration,
IIntegrationHandler handler,
IAzureServiceBusService serviceBusService,
ServiceBusProcessorOptions serviceBusOptions,
ILoggerFactory loggerFactory)
{
_handler = handler;
@@ -29,7 +30,7 @@ public class AzureServiceBusIntegrationListenerService<TConfiguration> : Backgro
_processor = _serviceBusService.CreateProcessor(
topicName: configuration.IntegrationTopicName,
subscriptionName: configuration.IntegrationSubscriptionName,
options: new ServiceBusProcessorOptions());
options: serviceBusOptions);
}
protected override async Task ExecuteAsync(CancellationToken cancellationToken)

View File

@@ -19,6 +19,7 @@ public class SlackService(
private readonly string _slackApiBaseUrl = globalSettings.Slack.ApiBaseUrl;
public const string HttpClientName = "SlackServiceHttpClient";
private const string _slackOAuthBaseUri = "https://slack.com/oauth/v2/authorize";
public async Task<string> GetChannelIdAsync(string token, string channelName)
{
@@ -73,9 +74,18 @@ public class SlackService(
return await OpenDmChannel(token, userId);
}
public string GetRedirectUrl(string redirectUrl)
public string GetRedirectUrl(string callbackUrl, string state)
{
return $"https://slack.com/oauth/v2/authorize?client_id={_clientId}&scope={_scopes}&redirect_uri={redirectUrl}";
var builder = new UriBuilder(_slackOAuthBaseUri);
var query = HttpUtility.ParseQueryString(builder.Query);
query["client_id"] = _clientId;
query["scope"] = _scopes;
query["redirect_uri"] = callbackUrl;
query["state"] = state;
builder.Query = query.ToString();
return builder.ToString();
}
public async Task<string> ObtainTokenViaOAuth(string code, string redirectUrl)

View File

@@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Interfaces;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Identity;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -516,6 +517,135 @@ public class EventService : IEventService
await _eventWriteService.CreateManyAsync(eventMessages);
}
public async Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var eventMessages = new List<IEvent>();
var orgUser = await _organizationUserRepository.GetByIdAsync((Guid)policy.OrganizationUserId);
if (!CanUseEvents(orgAbilities, orgUser.OrganizationId))
{
return;
}
var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType);
if (actingUserId is null && serviceAccountId is null)
{
return;
}
if (policy.OrganizationUserId != null)
{
var e = new EventMessage(_currentContext)
{
OrganizationId = orgUser.OrganizationId,
Type = type,
GrantedServiceAccountId = policy.GrantedServiceAccountId,
ServiceAccountId = serviceAccountId,
UserId = policy.OrganizationUserId,
ActingUserId = actingUserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
eventMessages.Add(e);
await _eventWriteService.CreateManyAsync(eventMessages);
}
}
public async Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var eventMessages = new List<IEvent>();
if (!CanUseEvents(orgAbilities, policy.Group.OrganizationId))
{
return;
}
var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType);
if (actingUserId is null && serviceAccountId is null)
{
return;
}
if (policy.GroupId != null)
{
var e = new EventMessage(_currentContext)
{
OrganizationId = policy.Group.OrganizationId,
Type = type,
GrantedServiceAccountId = policy.GrantedServiceAccountId,
ServiceAccountId = serviceAccountId,
GroupId = policy.GroupId,
ActingUserId = actingUserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
eventMessages.Add(e);
await _eventWriteService.CreateManyAsync(eventMessages);
}
}
public async Task LogServiceAccountEventAsync(Guid userId, List<ServiceAccount> serviceAccounts, EventType type, IdentityClientType identityClientType, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var eventMessages = new List<IEvent>();
foreach (var serviceAccount in serviceAccounts)
{
if (!CanUseEvents(orgAbilities, serviceAccount.OrganizationId))
{
continue;
}
var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType);
if (actingUserId is null && serviceAccountId is null)
{
continue;
}
if (serviceAccount != null)
{
var e = new EventMessage(_currentContext)
{
OrganizationId = serviceAccount.OrganizationId,
Type = type,
GrantedServiceAccountId = serviceAccount.Id,
ServiceAccountId = serviceAccountId,
ActingUserId = actingUserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
eventMessages.Add(e);
}
}
if (eventMessages.Any())
{
await _eventWriteService.CreateManyAsync(eventMessages);
}
}
private (Guid? actingUserId, Guid? serviceAccountId) MapIdentityClientType(
Guid userId, IdentityClientType identityClientType)
{
if (identityClientType == IdentityClientType.Organization)
{
return (null, null);
}
return identityClientType switch
{
IdentityClientType.User => (userId, null),
IdentityClientType.ServiceAccount => (null, userId),
_ => throw new InvalidOperationException("Unknown identity client type.")
};
}
private async Task<Guid?> GetProviderIdAsync(Guid? orgId)
{
if (_currentContext == null || !orgId.HasValue)

View File

@@ -1,6 +1,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Interfaces;
using Bit.Core.Auth.Identity;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.SecretsManager.Entities;
@@ -139,4 +140,19 @@ public class NoopEventService : IEventService
{
return Task.FromResult(0);
}
public Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null)
{
return Task.FromResult(0);
}
public Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null)
{
return Task.FromResult(0);
}
public Task LogServiceAccountEventAsync(Guid userId, List<ServiceAccount> serviceAccount, EventType type, IdentityClientType identityClientType, DateTime? date = null)
{
return Task.FromResult(0);
}
}

View File

@@ -3,7 +3,7 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Entities;
using Bit.Core.Models.Business;
@@ -11,7 +11,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations;
public class NoopProviderService : IProviderService
{
public Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) => throw new NotImplementedException();
public Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) => throw new NotImplementedException();
public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException();

View File

@@ -19,7 +19,7 @@ public class NoopSlackService : ISlackService
return Task.FromResult(string.Empty);
}
public string GetRedirectUrl(string redirectUrl)
public string GetRedirectUrl(string callbackUrl, string state)
{
return string.Empty;
}

View File

@@ -1,5 +1,4 @@
#nullable enable
using OneOf;
using OneOf;
namespace Bit.Core.Billing.Commands;
@@ -20,18 +19,38 @@ public record Unhandled(Exception? Exception = null, string Response = "Somethin
/// </remarks>
/// </summary>
/// <typeparam name="T">The successful result type of the operation.</typeparam>
public class BillingCommandResult<T> : OneOfBase<T, BadRequest, Conflict, Unhandled>
public class BillingCommandResult<T>(OneOf<T, BadRequest, Conflict, Unhandled> input)
: OneOfBase<T, BadRequest, Conflict, Unhandled>(input)
{
private BillingCommandResult(OneOf<T, BadRequest, Conflict, Unhandled> input) : base(input) { }
public static implicit operator BillingCommandResult<T>(T output) => new(output);
public static implicit operator BillingCommandResult<T>(BadRequest badRequest) => new(badRequest);
public static implicit operator BillingCommandResult<T>(Conflict conflict) => new(conflict);
public static implicit operator BillingCommandResult<T>(Unhandled unhandled) => new(unhandled);
public BillingCommandResult<TResult> Map<TResult>(Func<T, TResult> f)
=> Match(
value => new BillingCommandResult<TResult>(f(value)),
badRequest => new BillingCommandResult<TResult>(badRequest),
conflict => new BillingCommandResult<TResult>(conflict),
unhandled => new BillingCommandResult<TResult>(unhandled));
public Task TapAsync(Func<T, Task> f) => Match(
f,
_ => Task.CompletedTask,
_ => Task.CompletedTask,
_ => Task.CompletedTask);
}
public static class BillingCommandResultExtensions
{
public static async Task<BillingCommandResult<TResult>> AndThenAsync<T, TResult>(
this Task<BillingCommandResult<T>> task, Func<T, Task<BillingCommandResult<TResult>>> binder)
{
var result = await task;
return await result.Match(
binder,
badRequest => Task.FromResult(new BillingCommandResult<TResult>(badRequest)),
conflict => Task.FromResult(new BillingCommandResult<TResult>(conflict)),
unhandled => Task.FromResult(new BillingCommandResult<TResult>(unhandled)));
}
}

View File

@@ -0,0 +1,7 @@
namespace Bit.Core.Billing.Enums;
public enum PlanCadenceType
{
Annually,
Monthly
}

View File

@@ -9,7 +9,7 @@ using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Tax.Commands;
using Bit.Core.Billing.Subscriptions.Commands;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
@@ -28,11 +28,12 @@ public static class ServiceCollectionExtensions
services.AddTransient<ISubscriberService, SubscriberService>();
services.AddLicenseServices();
services.AddPricingClient();
services.AddTransient<IPreviewTaxAmountCommand, PreviewTaxAmountCommand>();
services.AddPaymentOperations();
services.AddOrganizationLicenseCommandsQueries();
services.AddPremiumCommands();
services.AddTransient<IGetOrganizationWarningsQuery, GetOrganizationWarningsQuery>();
services.AddTransient<IRestartSubscriptionCommand, RestartSubscriptionCommand>();
services.AddTransient<IPreviewOrganizationTaxCommand, PreviewOrganizationTaxCommand>();
}
private static void AddOrganizationLicenseCommandsQueries(this IServiceCollection services)
@@ -46,5 +47,6 @@ public static class ServiceCollectionExtensions
{
services.AddScoped<ICreatePremiumCloudHostedSubscriptionCommand, CreatePremiumCloudHostedSubscriptionCommand>();
services.AddScoped<ICreatePremiumSelfHostedSubscriptionCommand, CreatePremiumSelfHostedSubscriptionCommand>();
services.AddTransient<IPreviewPremiumTaxCommand, PreviewPremiumTaxCommand>();
}
}

View File

@@ -0,0 +1,383 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using OneOf;
using Stripe;
namespace Bit.Core.Billing.Organizations.Commands;
using static Core.Constants;
using static StripeConstants;
public interface IPreviewOrganizationTaxCommand
{
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
OrganizationSubscriptionPurchase purchase,
BillingAddress billingAddress);
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionPlanChange planChange,
BillingAddress billingAddress);
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionUpdate update);
}
public class PreviewOrganizationTaxCommand(
ILogger<PreviewOrganizationTaxCommand> logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter)
: BaseBillingCommand<PreviewOrganizationTaxCommand>(logger), IPreviewOrganizationTaxCommand
{
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
OrganizationSubscriptionPurchase purchase,
BillingAddress billingAddress)
=> HandleAsync<(decimal, decimal)>(async () =>
{
var plan = await pricingClient.GetPlanOrThrow(purchase.PlanType);
var options = GetBaseOptions(billingAddress, purchase.Tier != ProductTierType.Families);
var items = new List<InvoiceSubscriptionDetailsItemOptions>();
switch (purchase)
{
case { PasswordManager.Sponsored: true }:
var sponsoredPlan = StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise);
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = sponsoredPlan.StripePlanId,
Quantity = 1
});
break;
case { SecretsManager.Standalone: true }:
items.AddRange([
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.PasswordManager.StripeSeatPlanId,
Quantity = purchase.PasswordManager.Seats
},
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = purchase.SecretsManager.Seats
}
]);
options.Coupon = CouponIDs.SecretsManagerStandalone;
break;
default:
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.HasNonSeatBasedPasswordManagerPlan()
? plan.PasswordManager.StripePlanId
: plan.PasswordManager.StripeSeatPlanId,
Quantity = purchase.PasswordManager.Seats
});
if (purchase.PasswordManager.AdditionalStorage > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.PasswordManager.StripeStoragePlanId,
Quantity = purchase.PasswordManager.AdditionalStorage
});
}
if (purchase.SecretsManager is { Seats: > 0 })
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = purchase.SecretsManager.Seats
});
if (purchase.SecretsManager.AdditionalServiceAccounts > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeServiceAccountPlanId,
Quantity = purchase.SecretsManager.AdditionalServiceAccounts
});
}
}
break;
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
});
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionPlanChange planChange,
BillingAddress billingAddress)
=> HandleAsync<(decimal, decimal)>(async () =>
{
if (organization.PlanType.GetProductTier() == ProductTierType.Free)
{
var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families);
var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType);
var items = new List<InvoiceSubscriptionDetailsItemOptions>
{
new ()
{
Price = newPlan.HasNonSeatBasedPasswordManagerPlan()
? newPlan.PasswordManager.StripePlanId
: newPlan.PasswordManager.StripeSeatPlanId,
Quantity = 2
}
};
if (organization.UseSecretsManager && planChange.Tier != ProductTierType.Families)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.SecretsManager.StripeSeatPlanId,
Quantity = 2
});
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
}
else
{
if (organization is not
{
GatewayCustomerId: not null,
GatewaySubscriptionId: not null
})
{
return new BadRequest("Organization does not have a subscription.");
}
var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families);
var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer"] });
if (subscription.Customer.Discount != null)
{
options.Coupon = subscription.Customer.Discount.Coupon.Id;
}
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType);
var subscriptionItemsByPriceId =
subscription.Items.ToDictionary(subscriptionItem => subscriptionItem.Price.Id);
var items = new List<InvoiceSubscriptionDetailsItemOptions>();
var passwordManagerSeats = subscriptionItemsByPriceId[
currentPlan.HasNonSeatBasedPasswordManagerPlan()
? currentPlan.PasswordManager.StripePlanId
: currentPlan.PasswordManager.StripeSeatPlanId];
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.HasNonSeatBasedPasswordManagerPlan()
? newPlan.PasswordManager.StripePlanId
: newPlan.PasswordManager.StripeSeatPlanId,
Quantity = passwordManagerSeats.Quantity
});
var hasStorage =
subscriptionItemsByPriceId.TryGetValue(newPlan.PasswordManager.StripeStoragePlanId,
out var storage);
if (hasStorage && storage != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.PasswordManager.StripeStoragePlanId,
Quantity = storage.Quantity
});
}
var hasSecretsManagerSeats = subscriptionItemsByPriceId.TryGetValue(
newPlan.SecretsManager.StripeSeatPlanId,
out var secretsManagerSeats);
if (hasSecretsManagerSeats && secretsManagerSeats != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.SecretsManager.StripeSeatPlanId,
Quantity = secretsManagerSeats.Quantity
});
var hasServiceAccounts =
subscriptionItemsByPriceId.TryGetValue(newPlan.SecretsManager.StripeServiceAccountPlanId,
out var serviceAccounts);
if (hasServiceAccounts && serviceAccounts != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = newPlan.SecretsManager.StripeServiceAccountPlanId,
Quantity = serviceAccounts.Quantity
});
}
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
}
});
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
Organization organization,
OrganizationSubscriptionUpdate update)
=> HandleAsync<(decimal, decimal)>(async () =>
{
if (organization is not
{
GatewayCustomerId: not null,
GatewaySubscriptionId: not null
})
{
return new BadRequest("Organization does not have a subscription.");
}
var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer.tax_ids"] });
var options = GetBaseOptions(subscription.Customer,
organization.GetProductUsageType() == ProductUsageType.Business);
if (subscription.Customer.Discount != null)
{
options.Coupon = subscription.Customer.Discount.Coupon.Id;
}
var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var items = new List<InvoiceSubscriptionDetailsItemOptions>();
if (update.PasswordManager?.Seats != null)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.HasNonSeatBasedPasswordManagerPlan()
? currentPlan.PasswordManager.StripePlanId
: currentPlan.PasswordManager.StripeSeatPlanId,
Quantity = update.PasswordManager.Seats
});
}
if (update.PasswordManager?.AdditionalStorage is > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.PasswordManager.StripeStoragePlanId,
Quantity = update.PasswordManager.AdditionalStorage
});
}
if (update.SecretsManager?.Seats is > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.SecretsManager.StripeSeatPlanId,
Quantity = update.SecretsManager.Seats
});
if (update.SecretsManager.AdditionalServiceAccounts is > 0)
{
items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = currentPlan.SecretsManager.StripeServiceAccountPlanId,
Quantity = update.SecretsManager.AdditionalServiceAccounts
});
}
}
options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items };
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
});
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
Convert.ToDecimal(invoice.Tax) / 100,
Convert.ToDecimal(invoice.Total) / 100);
private static InvoiceCreatePreviewOptions GetBaseOptions(
OneOf<Customer, BillingAddress> addressChoice,
bool businessUse)
{
var country = addressChoice.Match(
customer => customer.Address.Country,
billingAddress => billingAddress.Country
);
var postalCode = addressChoice.Match(
customer => customer.Address.PostalCode,
billingAddress => billingAddress.PostalCode);
var options = new InvoiceCreatePreviewOptions
{
AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true },
Currency = "usd",
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions { Country = country, PostalCode = postalCode },
TaxExempt = businessUse && country != CountryAbbreviations.UnitedStates
? TaxExempt.Reverse
: TaxExempt.None
}
};
var taxId = addressChoice.Match(
customer =>
{
var taxId = customer.TaxIds?.FirstOrDefault();
return taxId != null ? new TaxID(taxId.Type, taxId.Value) : null;
},
billingAddress => billingAddress.TaxId);
if (taxId == null)
{
return options;
}
options.CustomerDetails.TaxIds =
[
new InvoiceCustomerDetailsTaxIdOptions { Type = taxId.Code, Value = taxId.Value }
];
if (taxId.Code == TaxIdType.SpanishNIF)
{
options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions
{
Type = TaxIdType.EUVAT,
Value = $"ES{taxId.Value}"
});
}
return options;
}
}

View File

@@ -0,0 +1,23 @@
using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Organizations.Models;
public record OrganizationSubscriptionPlanChange
{
public ProductTierType Tier { get; init; }
public PlanCadenceType Cadence { get; init; }
public PlanType PlanType =>
// ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
Tier switch
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => Cadence == PlanCadenceType.Monthly
? PlanType.TeamsMonthly
: PlanType.TeamsAnnually,
ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly
? PlanType.EnterpriseMonthly
: PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException("Cannot change an Organization subscription to a tier that isn't Families, Teams or Enterprise.")
};
}

View File

@@ -0,0 +1,39 @@
using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Organizations.Models;
public record OrganizationSubscriptionPurchase
{
public ProductTierType Tier { get; init; }
public PlanCadenceType Cadence { get; init; }
public required PasswordManagerSelections PasswordManager { get; init; }
public SecretsManagerSelections? SecretsManager { get; init; }
public PlanType PlanType =>
// ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
Tier switch
{
ProductTierType.Families => PlanType.FamiliesAnnually,
ProductTierType.Teams => Cadence == PlanCadenceType.Monthly
? PlanType.TeamsMonthly
: PlanType.TeamsAnnually,
ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly
? PlanType.EnterpriseMonthly
: PlanType.EnterpriseAnnually,
_ => throw new InvalidOperationException("Cannot purchase an Organization subscription that isn't Families, Teams or Enterprise.")
};
public record PasswordManagerSelections
{
public int Seats { get; init; }
public int AdditionalStorage { get; init; }
public bool Sponsored { get; init; }
}
public record SecretsManagerSelections
{
public int Seats { get; init; }
public int AdditionalServiceAccounts { get; init; }
public bool Standalone { get; init; }
}
}

View File

@@ -0,0 +1,19 @@
namespace Bit.Core.Billing.Organizations.Models;
public record OrganizationSubscriptionUpdate
{
public PasswordManagerSelections? PasswordManager { get; init; }
public SecretsManagerSelections? SecretsManager { get; init; }
public record PasswordManagerSelections
{
public int? Seats { get; init; }
public int? AdditionalStorage { get; init; }
}
public record SecretsManagerSelections
{
public int? Seats { get; init; }
public int? AdditionalServiceAccounts { get; init; }
}
}

View File

@@ -0,0 +1,65 @@
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Premium.Commands;
using static StripeConstants;
public interface IPreviewPremiumTaxCommand
{
Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
int additionalStorage,
BillingAddress billingAddress);
}
public class PreviewPremiumTaxCommand(
ILogger<PreviewPremiumTaxCommand> logger,
IStripeAdapter stripeAdapter) : BaseBillingCommand<PreviewPremiumTaxCommand>(logger), IPreviewPremiumTaxCommand
{
public Task<BillingCommandResult<(decimal Tax, decimal Total)>> Run(
int additionalStorage,
BillingAddress billingAddress)
=> HandleAsync<(decimal, decimal)>(async () =>
{
var options = new InvoiceCreatePreviewOptions
{
AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true },
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions
{
Country = billingAddress.Country,
PostalCode = billingAddress.PostalCode
}
},
Currency = "usd",
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items =
[
new InvoiceSubscriptionDetailsItemOptions { Price = Prices.PremiumAnnually, Quantity = 1 }
]
}
};
if (additionalStorage > 0)
{
options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = Prices.StoragePlanPersonal,
Quantity = additionalStorage
});
}
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return GetAmounts(invoice);
});
private static (decimal, decimal) GetAmounts(Invoice invoice) => (
Convert.ToDecimal(invoice.Tax) / 100,
Convert.ToDecimal(invoice.Total) / 100);
}

View File

@@ -258,7 +258,7 @@ public class ProviderMigrator(
// Create dummy payment source for legacy migration - this migrator is deprecated and will be removed
var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token");
var customer = await providerBillingService.SetupCustomer(provider, taxInfo, dummyPaymentSource);
var customer = await providerBillingService.SetupCustomer(provider, null, null);
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{

View File

@@ -5,10 +5,10 @@ using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Models.Business;
using Stripe;
namespace Bit.Core.Billing.Providers.Services;
@@ -79,16 +79,16 @@ public interface IProviderBillingService
int seatAdjustment);
/// <summary>
/// For use during the provider setup process, this method creates a Stripe <see cref="Stripe.Customer"/> for the specified <paramref name="provider"/> utilizing the provided <paramref name="taxInfo"/>.
/// For use during the provider setup process, this method creates a Stripe <see cref="Stripe.Customer"/> for the specified <paramref name="provider"/> utilizing the provided <paramref name="paymentMethod"/> and <paramref name="billingAddress"/>.
/// </summary>
/// <param name="provider">The <see cref="Provider"/> to create a Stripe customer for.</param>
/// <param name="taxInfo">The <see cref="TaxInfo"/> to use for calculating the customer's automatic tax.</param>
/// <param name="tokenizedPaymentSource">The <see cref="TokenizedPaymentSource"/> (ex. Credit Card) to attach to the customer.</param>
/// <param name="paymentMethod">The <see cref="TokenizedPaymentMethod"/> (e.g., Credit Card, Bank Account, or PayPal) to attach to the customer.</param>
/// <param name="billingAddress">The <see cref="BillingAddress"/> containing the customer's billing information including address and tax ID details.</param>
/// <returns>The newly created <see cref="Stripe.Customer"/> for the <paramref name="provider"/>.</returns>
Task<Customer> SetupCustomer(
Provider provider,
TaxInfo taxInfo,
TokenizedPaymentSource tokenizedPaymentSource);
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress);
/// <summary>
/// For use during the provider setup process, this method starts a Stripe <see cref="Stripe.Subscription"/> for the given <paramref name="provider"/>.

View File

@@ -0,0 +1,92 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using OneOf.Types;
using Stripe;
namespace Bit.Core.Billing.Subscriptions.Commands;
using static StripeConstants;
public interface IRestartSubscriptionCommand
{
Task<BillingCommandResult<None>> Run(
ISubscriber subscriber);
}
public class RestartSubscriptionCommand(
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserRepository userRepository) : IRestartSubscriptionCommand
{
public async Task<BillingCommandResult<None>> Run(
ISubscriber subscriber)
{
var existingSubscription = await subscriberService.GetSubscription(subscriber);
if (existingSubscription is not { Status: SubscriptionStatus.Canceled })
{
return new BadRequest("Cannot restart a subscription that is not canceled.");
}
var options = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
CollectionMethod = CollectionMethod.ChargeAutomatically,
Customer = existingSubscription.CustomerId,
Items = existingSubscription.Items.Select(subscriptionItem => new SubscriptionItemOptions
{
Price = subscriptionItem.Price.Id,
Quantity = subscriptionItem.Quantity
}).ToList(),
Metadata = existingSubscription.Metadata,
OffSession = true,
TrialPeriodDays = 0
};
var subscription = await stripeAdapter.SubscriptionCreateAsync(options);
await EnableAsync(subscriber, subscription);
return new None();
}
private async Task EnableAsync(ISubscriber subscriber, Subscription subscription)
{
switch (subscriber)
{
case Organization organization:
{
organization.GatewaySubscriptionId = subscription.Id;
organization.Enabled = true;
organization.ExpirationDate = subscription.CurrentPeriodEnd;
organization.RevisionDate = DateTime.UtcNow;
await organizationRepository.ReplaceAsync(organization);
break;
}
case Provider provider:
{
provider.GatewaySubscriptionId = subscription.Id;
provider.Enabled = true;
provider.RevisionDate = DateTime.UtcNow;
await providerRepository.ReplaceAsync(provider);
break;
}
case User user:
{
user.GatewaySubscriptionId = subscription.Id;
user.Premium = true;
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
user.RevisionDate = DateTime.UtcNow;
await userRepository.ReplaceAsync(user);
break;
}
}
}
}

View File

@@ -1,136 +0,0 @@
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Tax.Commands;
public interface IPreviewTaxAmountCommand
{
Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters);
}
public class PreviewTaxAmountCommand(
ILogger<PreviewTaxAmountCommand> logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
ITaxService taxService) : BaseBillingCommand<PreviewTaxAmountCommand>(logger), IPreviewTaxAmountCommand
{
protected override Conflict DefaultConflict
=> new("We had a problem calculating your tax obligation. Please contact support for assistance.");
public Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters)
=> HandleAsync<decimal>(async () =>
{
var (planType, productType, taxInformation) = parameters;
var plan = await pricingClient.GetPlanOrThrow(planType);
var options = new InvoiceCreatePreviewOptions
{
Currency = "usd",
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions
{
Country = taxInformation.Country,
PostalCode = taxInformation.PostalCode
}
},
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items =
[
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.HasNonSeatBasedPasswordManagerPlan()
? plan.PasswordManager.StripePlanId
: plan.PasswordManager.StripeSeatPlanId,
Quantity = 1
}
]
}
};
if (productType == ProductType.SecretsManager)
{
options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = 1
});
options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone;
}
if (!string.IsNullOrEmpty(taxInformation.TaxId))
{
var taxIdType = taxService.GetStripeTaxCode(
taxInformation.Country,
taxInformation.TaxId);
if (string.IsNullOrEmpty(taxIdType))
{
return new BadRequest(
"We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance.");
}
options.CustomerDetails.TaxIds =
[
new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = taxInformation.TaxId }
];
if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
{
options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions
{
Type = StripeConstants.TaxIdType.EUVAT,
Value = $"ES{parameters.TaxInformation.TaxId}"
});
}
}
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
if (parameters.PlanType.IsBusinessProductTierType() &&
parameters.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates)
{
options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return Convert.ToDecimal(invoice.Tax) / 100;
});
}
#region Command Parameters
public record OrganizationTrialParameters
{
public required PlanType PlanType { get; set; }
public required ProductType ProductType { get; set; }
public required TaxInformationDTO TaxInformation { get; set; }
public void Deconstruct(
out PlanType planType,
out ProductType productType,
out TaxInformationDTO taxInformation)
{
planType = PlanType;
productType = ProductType;
taxInformation = TaxInformation;
}
public record TaxInformationDTO
{
public required string Country { get; set; }
public required string PostalCode { get; set; }
public string? TaxId { get; set; }
}
}
#endregion

View File

@@ -70,6 +70,17 @@ public static class Constants
/// </summary>
public const string UnitedStates = "US";
}
/// <summary>
/// Constants for our browser extensions IDs
/// </summary>
public static class BrowserExtensions
{
public const string ChromeId = "chrome-extension://nngceckbapebfimnlniiiahkandclblb/";
public const string EdgeId = "chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh/";
public const string OperaId = "chrome-extension://ccnckbpmaceehanjmeomladnmlffdjgn/";
}
}
public static class AuthConstants
@@ -124,8 +135,6 @@ public static class AuthenticationSchemes
public static class FeatureFlagKeys
{
/* Admin Console Team */
public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint";
public const string LimitItemDeletion = "pm-15493-restrict-item-deletion-to-can-manage-permission";
public const string PolicyRequirements = "pm-14439-policy-requirements";
public const string ScimInviteUserOptimization = "pm-16811-optimize-invite-user-flow-to-fail-fast";
public const string EventBasedOrganizationIntegrations = "event-based-organization-integrations";
@@ -169,10 +178,11 @@ public static class FeatureFlagKeys
public const string PM17772_AdminInitiatedSponsorships = "pm-17772-admin-initiated-sponsorships";
public const string UsePricingService = "use-pricing-service";
public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates";
public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout";
public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover";
public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings";
public const string PM23385_UseNewPremiumFlow = "pm-23385-use-new-premium-flow";
public const string PM24996ImplementUpgradeFromFreeDialog = "pm-24996-implement-upgrade-from-free-dialog";
public const string PM24032_NewNavigationPremiumUpgradeButton = "pm-24032-new-navigation-premium-upgrade-button";
public const string PM23713_PremiumBadgeOpensNewPremiumUpgradeDialog = "pm-23713-premium-badge-opens-new-premium-upgrade-dialog";
/* Key Management Team */
public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair";
@@ -222,7 +232,6 @@ public static class FeatureFlagKeys
/* Vault Team */
public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge";
public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form";
public const string SecurityTasks = "security-tasks";
public const string CipherKeyEncryption = "cipher-key-encryption";
public const string DesktopCipherForms = "pm-18520-desktop-cipher-forms";
public const string PM19941MigrateCipherDomainToSdk = "pm-19941-migrate-cipher-domain-to-sdk";

View File

@@ -21,8 +21,8 @@
<ItemGroup>
<PackageReference Include="AspNetCoreRateLimit.Redis" Version="2.0.0" />
<PackageReference Include="AWSSDK.SimpleEmail" Version="4.0.0.20" />
<PackageReference Include="AWSSDK.SQS" Version="4.0.1.1" />
<PackageReference Include="AWSSDK.SimpleEmail" Version="4.0.1.3" />
<PackageReference Include="AWSSDK.SQS" Version="4.0.1.5" />
<PackageReference Include="Azure.Data.Tables" Version="12.9.0" />
<PackageReference Include="Azure.Extensions.AspNetCore.DataProtection.Blobs" Version="1.3.4" />
<PackageReference Include="Microsoft.AspNetCore.DataProtection" Version="8.0.10" />
@@ -34,7 +34,7 @@
<PackageReference Include="DnsClient" Version="1.8.0" />
<PackageReference Include="Fido2.AspNet" Version="3.0.1" />
<PackageReference Include="Handlebars.Net" Version="2.1.6" />
<PackageReference Include="MailKit" Version="4.13.0" />
<PackageReference Include="MailKit" Version="4.14.0" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.10" />
<PackageReference Include="Microsoft.Azure.Cosmos" Version="3.52.0" />
<PackageReference Include="Microsoft.Azure.NotificationHubs" Version="4.2.0" />

View File

@@ -315,7 +315,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs
throw new BadRequestException($"Cannot set max Secrets Manager seat autoscaling below current Secrets Manager seat count.");
}
if (plan.SecretsManager.MaxSeats.HasValue && update.MaxAutoscaleSmSeats.Value > plan.SecretsManager.MaxSeats)
if (plan.SecretsManager.MaxSeats.HasValue && plan.SecretsManager.MaxSeats.Value > 0 && update.MaxAutoscaleSmSeats.Value > plan.SecretsManager.MaxSeats)
{
throw new BadRequestException(string.Concat(
$"Your plan has a Secrets Manager seat limit of {plan.SecretsManager.MaxSeats}, ",

View File

@@ -389,7 +389,7 @@
<value>If SAML Binding Type is set to artifact, identity provider resolution service URL is required.</value>
</data>
<data name="IdpSingleSignOnServiceUrlValidationError" xml:space="preserve">
<value>If Identity Provider Entity ID is not a URL, single sign on service URL is required.</value>
<value>Single sign on service URL is required.</value>
</data>
<data name="InvalidSchemeConfigurationError" xml:space="preserve">
<value>The configured authentication scheme is not valid: "{0}"</value>

View File

@@ -26,6 +26,7 @@ using Bit.Core.Vault.Models.Data;
using Core.Auth.Enums;
using HandlebarsDotNet;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
@@ -39,6 +40,7 @@ public class HandlebarsMailService : IMailService
private readonly IMailDeliveryService _mailDeliveryService;
private readonly IMailEnqueuingService _mailEnqueuingService;
private readonly IDistributedCache _distributedCache;
private readonly ILogger<HandlebarsMailService> _logger;
private readonly Dictionary<string, HandlebarsTemplate<object, object>> _templateCache = new();
private bool _registeredHelpersAndPartials = false;
@@ -47,12 +49,14 @@ public class HandlebarsMailService : IMailService
GlobalSettings globalSettings,
IMailDeliveryService mailDeliveryService,
IMailEnqueuingService mailEnqueuingService,
IDistributedCache distributedCache)
IDistributedCache distributedCache,
ILogger<HandlebarsMailService> logger)
{
_globalSettings = globalSettings;
_mailDeliveryService = mailDeliveryService;
_mailEnqueuingService = mailEnqueuingService;
_distributedCache = distributedCache;
_logger = logger;
}
public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token)
@@ -708,6 +712,12 @@ public class HandlebarsMailService : IMailService
private async Task<string?> ReadSourceAsync(string templateName)
{
var diskSource = await ReadSourceFromDiskAsync(templateName);
if (!string.IsNullOrWhiteSpace(diskSource))
{
return diskSource;
}
var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly;
var fullTemplateName = $"{Namespace}.{templateName}.hbs";
if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName))
@@ -721,6 +731,42 @@ public class HandlebarsMailService : IMailService
}
}
private async Task<string?> ReadSourceFromDiskAsync(string templateName)
{
if (!_globalSettings.SelfHosted)
{
return null;
}
try
{
var templateFileSuffix = ".html";
if (templateName.EndsWith(".txt"))
{
templateFileSuffix = ".txt";
}
else if (!templateName.EndsWith(".html"))
{
// unexpected suffix
return null;
}
var suffixPosition = templateName.LastIndexOf(templateFileSuffix);
var templateNameNoSuffix = templateName.Substring(0, suffixPosition);
var templatePathNoSuffix = templateNameNoSuffix.Replace(".", "/");
var diskPath = $"{_globalSettings.MailTemplateDirectory}/{templatePathNoSuffix}{templateFileSuffix}.hbs";
var directory = Path.GetDirectoryName(diskPath);
if (Directory.Exists(directory) && File.Exists(diskPath))
{
var fileContents = await File.ReadAllTextAsync(diskPath);
return fileContents;
}
}
catch (Exception e)
{
_logger.LogError(e, "Failed to read mail template from disk.");
}
return null;
}
private async Task RegisterHelpersAndPartialsAsync()
{
if (_registeredHelpersAndPartials)

View File

@@ -3,6 +3,7 @@
using Bit.Core.Models.BitStripe;
using Stripe;
using Stripe.Tax;
namespace Bit.Core.Services;
@@ -23,6 +24,7 @@ public class StripeAdapter : IStripeAdapter
private readonly Stripe.TestHelpers.TestClockService _testClockService;
private readonly CustomerBalanceTransactionService _customerBalanceTransactionService;
private readonly Stripe.Tax.RegistrationService _taxRegistrationService;
private readonly CalculationService _calculationService;
public StripeAdapter()
{
@@ -41,6 +43,7 @@ public class StripeAdapter : IStripeAdapter
_testClockService = new Stripe.TestHelpers.TestClockService();
_customerBalanceTransactionService = new CustomerBalanceTransactionService();
_taxRegistrationService = new Stripe.Tax.RegistrationService();
_calculationService = new CalculationService();
}
public Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions options)

View File

@@ -8,6 +8,7 @@ namespace Bit.Core.Settings;
public class GlobalSettings : IGlobalSettings
{
private string _mailTemplateDirectory;
private string _logDirectory;
private string _licenseDirectory;
@@ -37,6 +38,11 @@ public class GlobalSettings : IGlobalSettings
get => BuildDirectory(_licenseDirectory, "/core/licenses");
set => _licenseDirectory = value;
}
public virtual string MailTemplateDirectory
{
get => BuildDirectory(_mailTemplateDirectory, "/mail-templates");
set => _mailTemplateDirectory = value;
}
public string LicenseCertificatePassword { get; set; }
public virtual string PushRelayBaseUri { get; set; }
public virtual string InternalIdentityKey { get; set; }
@@ -97,6 +103,7 @@ public class GlobalSettings : IGlobalSettings
/// </summary>
public virtual string SendDefaultHashKey { get; set; }
public virtual string PricingUri { get; set; }
public virtual Fido2Settings Fido2 { get; set; } = new Fido2Settings();
public string BuildExternalUri(string explicitValue, string name)
{
@@ -301,6 +308,9 @@ public class GlobalSettings : IGlobalSettings
private string _eventTopicName;
private string _integrationTopicName;
public virtual int DefaultMaxConcurrentCalls { get; set; } = 1;
public virtual int DefaultPrefetchCount { get; set; } = 0;
public virtual string EventRepositorySubscriptionName { get; set; } = "events-write-subscription";
public virtual string SlackEventSubscriptionName { get; set; } = "events-slack-subscription";
public virtual string SlackIntegrationSubscriptionName { get; set; } = "integration-slack-subscription";
@@ -763,4 +773,9 @@ public class GlobalSettings : IGlobalSettings
{
public string VapidPublicKey { get; set; }
}
public class Fido2Settings
{
public HashSet<string> Origins { get; set; }
}
}

View File

@@ -230,6 +230,8 @@ public class EventRepository : Repository<Event, Guid>, IEventRepository
eventsTable.Columns.Add(serviceAccountIdColumn);
var projectIdColumn = new DataColumn(nameof(e.ProjectId), typeof(Guid));
eventsTable.Columns.Add(projectIdColumn);
var grantedServiceAccountIdColumn = new DataColumn(nameof(e.GrantedServiceAccountId), typeof(Guid));
eventsTable.Columns.Add(grantedServiceAccountIdColumn);
foreach (DataColumn col in eventsTable.Columns)
{
@@ -263,6 +265,7 @@ public class EventRepository : Repository<Event, Guid>, IEventRepository
row[secretIdColumn] = ev.SecretId.HasValue ? ev.SecretId.Value : DBNull.Value;
row[serviceAccountIdColumn] = ev.ServiceAccountId.HasValue ? ev.ServiceAccountId.Value : DBNull.Value;
row[projectIdColumn] = ev.ProjectId.HasValue ? ev.ProjectId.Value : DBNull.Value;
row[grantedServiceAccountIdColumn] = ev.GrantedServiceAccountId.HasValue ? ev.GrantedServiceAccountId.Value : DBNull.Value;
eventsTable.Rows.Add(row);
}

View File

@@ -12,9 +12,16 @@ public class EventEntityTypeConfiguration : IEntityTypeConfiguration<Event>
.Property(e => e.Id)
.ValueGeneratedNever();
builder
.HasIndex(e => new { e.Date, e.OrganizationId, e.ActingUserId, e.CipherId })
.IsClustered(false);
builder.HasKey(e => e.Id)
.IsClustered();
var index = builder.HasIndex(e => new { e.Date, e.OrganizationId, e.ActingUserId, e.CipherId })
.IsClustered(false)
.HasDatabaseName("IX_Event_DateOrganizationIdUserId");
SqlServerIndexBuilderExtensions.IncludeProperties(
index,
e => new { e.ServiceAccountId, e.GrantedServiceAccountId });
builder.ToTable(nameof(Event));
}

View File

@@ -30,7 +30,7 @@ public class EventReadPageByOrganizationIdServiceAccountIdQuery : IQuery<Event>
(_beforeDate != null || e.Date <= _endDate) &&
(_beforeDate == null || e.Date < _beforeDate.Value) &&
e.OrganizationId == _organizationId &&
e.ServiceAccountId == _serviceAccountId
(e.ServiceAccountId == _serviceAccountId || e.GrantedServiceAccountId == _serviceAccountId)
orderby e.Date descending
select e;
return q.Skip(0).Take(_pageOptions.PageSize);

View File

@@ -0,0 +1,48 @@
using Bit.Core.Models.Data;
using Bit.Core.SecretsManager.Entities;
using Event = Bit.Infrastructure.EntityFramework.Models.Event;
namespace Bit.Infrastructure.EntityFramework.Repositories.Queries;
public class EventReadPageByServiceAccountQuery : IQuery<Event>
{
private readonly ServiceAccount _serviceAccount;
private readonly DateTime _startDate;
private readonly DateTime _endDate;
private readonly DateTime? _beforeDate;
private readonly PageOptions _pageOptions;
public EventReadPageByServiceAccountQuery(ServiceAccount serviceAccount, DateTime startDate, DateTime endDate, PageOptions pageOptions)
{
_serviceAccount = serviceAccount;
_startDate = startDate;
_endDate = endDate;
_beforeDate = null;
_pageOptions = pageOptions;
}
public EventReadPageByServiceAccountQuery(ServiceAccount serviceAccount, DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions)
{
_serviceAccount = serviceAccount;
_startDate = startDate;
_endDate = endDate;
_beforeDate = beforeDate;
_pageOptions = pageOptions;
}
public IQueryable<Event> Run(DatabaseContext dbContext)
{
var q = from e in dbContext.Events
where e.Date >= _startDate &&
(_beforeDate == null || e.Date < _beforeDate.Value) &&
(
(_serviceAccount.OrganizationId == Guid.Empty && !e.OrganizationId.HasValue) ||
(_serviceAccount.OrganizationId != Guid.Empty && e.OrganizationId == _serviceAccount.OrganizationId)
) &&
e.GrantedServiceAccountId == _serviceAccount.Id
orderby e.Date descending
select e;
return q.Take(_pageOptions.PageSize);
}
}

View File

@@ -283,6 +283,9 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
var transaction = await dbContext.Database.BeginTransactionAsync();
MigrateDefaultUserCollectionsToShared(dbContext, [user.Id]);
await dbContext.SaveChangesAsync();
dbContext.WebAuthnCredentials.RemoveRange(dbContext.WebAuthnCredentials.Where(w => w.UserId == user.Id));
dbContext.Ciphers.RemoveRange(dbContext.Ciphers.Where(c => c.UserId == user.Id));
dbContext.Folders.RemoveRange(dbContext.Folders.Where(f => f.UserId == user.Id));
@@ -314,8 +317,8 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
var mappedUser = Mapper.Map<User>(user);
dbContext.Users.Remove(mappedUser);
await transaction.CommitAsync();
await dbContext.SaveChangesAsync();
await transaction.CommitAsync();
}
}
@@ -329,21 +332,30 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
var targetIds = users.Select(u => u.Id).ToList();
MigrateDefaultUserCollectionsToShared(dbContext, targetIds);
await dbContext.SaveChangesAsync();
await dbContext.WebAuthnCredentials.Where(wa => targetIds.Contains(wa.UserId)).ExecuteDeleteAsync();
await dbContext.Ciphers.Where(c => targetIds.Contains(c.UserId ?? default)).ExecuteDeleteAsync();
await dbContext.Folders.Where(f => targetIds.Contains(f.UserId)).ExecuteDeleteAsync();
await dbContext.AuthRequests.Where(a => targetIds.Contains(a.UserId)).ExecuteDeleteAsync();
await dbContext.Devices.Where(d => targetIds.Contains(d.UserId)).ExecuteDeleteAsync();
var collectionUsers = from cu in dbContext.CollectionUsers
join ou in dbContext.OrganizationUsers on cu.OrganizationUserId equals ou.Id
where targetIds.Contains(ou.UserId ?? default)
select cu;
dbContext.CollectionUsers.RemoveRange(collectionUsers);
var groupUsers = from gu in dbContext.GroupUsers
join ou in dbContext.OrganizationUsers on gu.OrganizationUserId equals ou.Id
where targetIds.Contains(ou.UserId ?? default)
select gu;
dbContext.GroupUsers.RemoveRange(groupUsers);
await dbContext.CollectionUsers
.Join(dbContext.OrganizationUsers,
cu => cu.OrganizationUserId,
ou => ou.Id,
(cu, ou) => new { CollectionUser = cu, OrganizationUser = ou })
.Where((joined) => targetIds.Contains(joined.OrganizationUser.UserId ?? default))
.Select(joined => joined.CollectionUser)
.ExecuteDeleteAsync();
await dbContext.GroupUsers
.Join(dbContext.OrganizationUsers,
gu => gu.OrganizationUserId,
ou => ou.Id,
(gu, ou) => new { GroupUser = gu, OrganizationUser = ou })
.Where(joined => targetIds.Contains(joined.OrganizationUser.UserId ?? default))
.Select(joined => joined.GroupUser)
.ExecuteDeleteAsync();
await dbContext.UserProjectAccessPolicy.Where(ap => targetIds.Contains(ap.OrganizationUser.UserId ?? default)).ExecuteDeleteAsync();
await dbContext.UserServiceAccountAccessPolicy.Where(ap => targetIds.Contains(ap.OrganizationUser.UserId ?? default)).ExecuteDeleteAsync();
await dbContext.OrganizationUsers.Where(ou => targetIds.Contains(ou.UserId ?? default)).ExecuteDeleteAsync();
@@ -354,15 +366,29 @@ public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserR
await dbContext.NotificationStatuses.Where(ns => targetIds.Contains(ns.UserId)).ExecuteDeleteAsync();
await dbContext.Notifications.Where(n => targetIds.Contains(n.UserId ?? default)).ExecuteDeleteAsync();
foreach (var u in users)
{
var mappedUser = Mapper.Map<User>(u);
dbContext.Users.Remove(mappedUser);
}
await dbContext.Users.Where(u => targetIds.Contains(u.Id)).ExecuteDeleteAsync();
await transaction.CommitAsync();
await dbContext.SaveChangesAsync();
await transaction.CommitAsync();
}
}
private static void MigrateDefaultUserCollectionsToShared(DatabaseContext dbContext, IEnumerable<Guid> userIds)
{
var defaultCollections = (from c in dbContext.Collections
join cu in dbContext.CollectionUsers on c.Id equals cu.CollectionId
join ou in dbContext.OrganizationUsers on cu.OrganizationUserId equals ou.Id
join u in dbContext.Users on ou.UserId equals u.Id
where userIds.Contains(ou.UserId!.Value)
&& c.Type == Core.Enums.CollectionType.DefaultUserCollection
select new { Collection = c, UserEmail = u.Email })
.ToList();
foreach (var item in defaultCollections)
{
item.Collection.Type = Core.Enums.CollectionType.SharedCollection;
item.Collection.DefaultUserCollectionEmail = item.Collection.DefaultUserCollectionEmail ?? item.UserEmail;
item.Collection.RevisionDate = DateTime.UtcNow;
}
}
}

View File

@@ -6,6 +6,8 @@ using System.Reflection;
using System.Security.Claims;
using System.Security.Cryptography.X509Certificates;
using AspNetCoreRateLimit;
using Azure.Messaging.ServiceBus;
using Bit.Core;
using Bit.Core.AdminConsole.AbilitiesCache;
using Bit.Core.AdminConsole.Models.Business.Tokenables;
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
@@ -694,8 +696,23 @@ public static class ServiceCollectionExtensions
{
options.ServerDomain = new Uri(globalSettings.BaseServiceUri.Vault).Host;
options.ServerName = "Bitwarden";
options.Origins = new HashSet<string> { globalSettings.BaseServiceUri.Vault, };
options.TimestampDriftTolerance = 300000;
if (globalSettings.Fido2?.Origins?.Any() == true)
{
options.Origins = new HashSet<string>(globalSettings.Fido2.Origins);
}
else
{
// Default to allowing the vault domain and chromium browser extension IDs
options.Origins = new HashSet<string> {
globalSettings.BaseServiceUri.Vault,
Constants.BrowserExtensions.ChromeId,
Constants.BrowserExtensions.EdgeId,
Constants.BrowserExtensions.OperaId
};
}
});
}
@@ -855,6 +872,11 @@ public static class ServiceCollectionExtensions
configuration: listenerConfiguration,
handler: provider.GetRequiredKeyedService<IEventMessageHandler>(serviceKey: listenerConfiguration.RoutingKey),
serviceBusService: provider.GetRequiredService<IAzureServiceBusService>(),
serviceBusOptions: new ServiceBusProcessorOptions()
{
PrefetchCount = listenerConfiguration.EventPrefetchCount,
MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls
},
loggerFactory: provider.GetRequiredService<ILoggerFactory>()
)
)
@@ -865,6 +887,11 @@ public static class ServiceCollectionExtensions
configuration: listenerConfiguration,
handler: provider.GetRequiredService<IIntegrationHandler<TConfig>>(),
serviceBusService: provider.GetRequiredService<IAzureServiceBusService>(),
serviceBusOptions: new ServiceBusProcessorOptions()
{
PrefetchCount = listenerConfiguration.IntegrationPrefetchCount,
MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls
},
loggerFactory: provider.GetRequiredService<ILoggerFactory>()
)
)
@@ -927,6 +954,11 @@ public static class ServiceCollectionExtensions
configuration: repositoryConfiguration,
handler: provider.GetRequiredService<AzureTableStorageEventHandler>(),
serviceBusService: provider.GetRequiredService<IAzureServiceBusService>(),
serviceBusOptions: new ServiceBusProcessorOptions()
{
PrefetchCount = repositoryConfiguration.EventPrefetchCount,
MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls
},
loggerFactory: provider.GetRequiredService<ILoggerFactory>()
)
)

View File

@@ -18,7 +18,7 @@ BEGIN
AND (@BeforeDate IS NOT NULL OR [Date] <= @EndDate)
AND (@BeforeDate IS NULL OR [Date] < @BeforeDate)
AND [OrganizationId] = @OrganizationId
AND [ServiceAccountId] = @ServiceAccountId
AND ([ServiceAccountId] = @ServiceAccountId OR [GrantedServiceAccountId] = @ServiceAccountId)
ORDER BY [Date] DESC
OFFSET 0 ROWS
FETCH NEXT @PageSize ROWS ONLY

View File

@@ -0,0 +1,45 @@
CREATE PROCEDURE [dbo].[Event_ReadPageByServiceAccountId]
@GrantedServiceAccountId UNIQUEIDENTIFIER,
@StartDate DATETIME2(7),
@EndDate DATETIME2(7),
@BeforeDate DATETIME2(7),
@PageSize INT
AS
BEGIN
SET NOCOUNT ON
SELECT
e.Id,
e.Date,
e.Type,
e.UserId,
e.OrganizationId,
e.InstallationId,
e.ProviderId,
e.CipherId,
e.CollectionId,
e.PolicyId,
e.GroupId,
e.OrganizationUserId,
e.ProviderUserId,
e.ProviderOrganizationId,
e.DeviceType,
e.IpAddress,
e.ActingUserId,
e.SystemUser,
e.DomainName,
e.SecretId,
e.ServiceAccountId,
e.ProjectId,
e.GrantedServiceAccountId
FROM
[dbo].[EventView] e
WHERE
[Date] >= @StartDate
AND (@BeforeDate IS NOT NULL OR [Date] <= @EndDate)
AND (@BeforeDate IS NULL OR [Date] < @BeforeDate)
AND [GrantedServiceAccountId] = @GrantedServiceAccountId
ORDER BY [Date] DESC
OFFSET 0 ROWS
FETCH NEXT @PageSize ROWS ONLY
END

View File

@@ -20,7 +20,8 @@
@DomainName VARCHAR(256),
@SecretId UNIQUEIDENTIFIER = null,
@ServiceAccountId UNIQUEIDENTIFIER = null,
@ProjectId UNIQUEIDENTIFIER = null
@ProjectId UNIQUEIDENTIFIER = null,
@GrantedServiceAccountId UNIQUEIDENTIFIER = null
AS
BEGIN
SET NOCOUNT ON
@@ -48,7 +49,8 @@ BEGIN
[DomainName],
[SecretId],
[ServiceAccountId],
[ProjectId]
[ProjectId],
[GrantedServiceAccountId]
)
VALUES
(
@@ -73,6 +75,7 @@ BEGIN
@DomainName,
@SecretId,
@ServiceAccountId,
@ProjectId
@ProjectId,
@GrantedServiceAccountId
)
END

View File

@@ -52,6 +52,16 @@ BEGIN
WHERE
[UserId] = @Id
-- Migrate DefaultUserCollection to SharedCollection before deleting CollectionUser records
DECLARE @OrgUserIds [dbo].[GuidIdArray]
INSERT INTO @OrgUserIds (Id)
SELECT [Id] FROM [dbo].[OrganizationUser] WHERE [UserId] = @Id
IF EXISTS (SELECT 1 FROM @OrgUserIds)
BEGIN
EXEC [dbo].[OrganizationUser_MigrateDefaultCollection] @OrgUserIds
END
-- Delete collection users
DELETE
CU

View File

@@ -66,6 +66,16 @@ BEGIN
WHERE
[UserId] IN (SELECT * FROM @ParsedIds)
-- Migrate DefaultUserCollection to SharedCollection before deleting CollectionUser records
DECLARE @OrgUserIds [dbo].[GuidIdArray]
INSERT INTO @OrgUserIds (Id)
SELECT [Id] FROM [dbo].[OrganizationUser] WHERE [UserId] IN (SELECT * FROM @ParsedIds)
IF EXISTS (SELECT 1 FROM @OrgUserIds)
BEGIN
EXEC [dbo].[OrganizationUser_MigrateDefaultCollection] @OrgUserIds
END
-- Delete collection users
DELETE
CU

View File

@@ -21,11 +21,12 @@
[SecretId] UNIQUEIDENTIFIER NULL,
[ServiceAccountId] UNIQUEIDENTIFIER NULL,
[ProjectId] UNIQUEIDENTIFIER NULL,
[GrantedServiceAccountId] UNIQUEIDENTIFIER NULL,
CONSTRAINT [PK_Event] PRIMARY KEY CLUSTERED ([Id] ASC)
);
GO
CREATE NONCLUSTERED INDEX [IX_Event_DateOrganizationIdUserId]
ON [dbo].[Event]([Date] DESC, [OrganizationId] ASC, [ActingUserId] ASC, [CipherId] ASC);
ON [dbo].[Event]([Date] DESC, [OrganizationId] ASC, [ActingUserId] ASC, [CipherId] ASC) INCLUDE ([ServiceAccountId], [GrantedServiceAccountId]);

View File

@@ -1,12 +1,18 @@
using Bit.Api.AdminConsole.Controllers;
#nullable enable
using Bit.Api.AdminConsole.Controllers;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Routing;
using Microsoft.Extensions.Time.Testing;
using NSubstitute;
using Xunit;
@@ -16,98 +22,312 @@ namespace Bit.Api.Test.AdminConsole.Controllers;
[SutProviderCustomize]
public class SlackIntegrationControllerTests
{
private const string _slackToken = "xoxb-test-token";
private const string _validSlackCode = "A_test_code";
[Theory, BitAutoData]
public async Task CreateAsync_AllParamsProvided_Succeeds(SutProvider<SlackIntegrationController> sutProvider, Guid organizationId)
public async Task CreateAsync_AllParamsProvided_Succeeds(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
var token = "xoxb-test-token";
integration.Type = IntegrationType.Slack;
integration.Configuration = null;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(Arg.Any<string>(), Arg.Any<string>())
.Returns(token);
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(_slackToken);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.CreateAsync(Arg.Any<OrganizationIntegration>())
.Returns(callInfo => callInfo.Arg<OrganizationIntegration>());
var requestAction = await sutProvider.Sut.CreateAsync(organizationId, "A_test_code");
.GetByIdAsync(integration.Id)
.Returns(integration);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
var requestAction = await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString());
await sutProvider.GetDependency<IOrganizationIntegrationRepository>().Received(1)
.CreateAsync(Arg.Any<OrganizationIntegration>());
.UpsertAsync(Arg.Any<OrganizationIntegration>());
Assert.IsType<CreatedResult>(requestAction);
}
[Theory, BitAutoData]
public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest(SutProvider<SlackIntegrationController> sutProvider, Guid organizationId)
public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
integration.Type = IntegrationType.Slack;
integration.Configuration = null;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(integration.Id)
.Returns(integration);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.CreateAsync(organizationId, string.Empty));
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.CreateAsync(string.Empty, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest(SutProvider<SlackIntegrationController> sutProvider, Guid organizationId)
public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
integration.Type = IntegrationType.Slack;
integration.Configuration = null;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(integration.Id)
.Returns(integration);
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(Arg.Any<string>(), Arg.Any<string>())
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(string.Empty);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.CreateAsync(organizationId, "A_test_code"));
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider<SlackIntegrationController> sutProvider, Guid organizationId)
public async Task CreateAsync_StateEmpty_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider)
{
var token = "xoxb-test-token";
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(false);
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(Arg.Any<string>(), Arg.Any<string>())
.Returns(token);
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(_slackToken);
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.CreateAsync(organizationId, "A_test_code"));
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, String.Empty));
}
[Theory, BitAutoData]
public async Task RedirectAsync_Success(SutProvider<SlackIntegrationController> sutProvider, Guid organizationId)
public async Task CreateAsync_StateExpired_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
var expectedUrl = $"https://localhost/{organizationId}";
var timeProvider = new FakeTimeProvider(new DateTime(2024, 4, 3, 2, 1, 0, DateTimeKind.Utc));
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(_slackToken);
var state = IntegrationOAuthState.FromIntegration(integration, timeProvider);
timeProvider.Advance(TimeSpan.FromMinutes(30));
sutProvider.SetDependency<TimeProvider>(timeProvider);
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_StateHasNonexistentIntegration_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(_slackToken);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_StateHasWrongOgranizationHash_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration,
OrganizationIntegration wrongOrgIntegration)
{
wrongOrgIntegration.Id = integration.Id;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ISlackService>().GetRedirectUrl(Arg.Any<string>()).Returns(expectedUrl);
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(_slackToken);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(integration.Id)
.Returns(wrongOrgIntegration);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_StateHasNonEmptyIntegration_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
integration.Type = IntegrationType.Slack;
integration.Configuration = "{}";
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(_slackToken);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(integration.Id)
.Returns(integration);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_StateHasNonSlackIntegration_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
integration.Type = IntegrationType.Hec;
integration.Configuration = null;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns("https://localhost");
sutProvider.GetDependency<ISlackService>()
.ObtainTokenViaOAuth(_validSlackCode, Arg.Any<string>())
.Returns(_slackToken);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(integration.Id)
.Returns(integration);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task RedirectAsync_Success(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
integration.Configuration = null;
var expectedUrl = "https://localhost/";
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns(expectedUrl);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(integration.OrganizationId)
.Returns(true);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetManyByOrganizationAsync(integration.OrganizationId)
.Returns([]);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.CreateAsync(Arg.Any<OrganizationIntegration>())
.Returns(integration);
sutProvider.GetDependency<ISlackService>().GetRedirectUrl(Arg.Any<string>(), Arg.Any<string>()).Returns(expectedUrl);
var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
var requestAction = await sutProvider.Sut.RedirectAsync(integration.OrganizationId);
Assert.IsType<RedirectResult>(requestAction);
await sutProvider.GetDependency<IOrganizationIntegrationRepository>().Received(1)
.CreateAsync(Arg.Any<OrganizationIntegration>());
sutProvider.GetDependency<ISlackService>().Received(1).GetRedirectUrl(Arg.Any<string>(), expectedState.ToString());
}
[Theory, BitAutoData]
public async Task RedirectAsync_IntegrationAlreadyExistsWithNullConfig_Success(
SutProvider<SlackIntegrationController> sutProvider,
Guid organizationId,
OrganizationIntegration integration)
{
integration.OrganizationId = organizationId;
integration.Configuration = null;
integration.Type = IntegrationType.Slack;
var expectedUrl = "https://localhost/";
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns(expectedUrl);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>()
.HttpContext.Request.Scheme
.Returns("https");
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetManyByOrganizationAsync(organizationId)
.Returns([integration]);
sutProvider.GetDependency<ISlackService>().GetRedirectUrl(Arg.Any<string>(), Arg.Any<string>()).Returns(expectedUrl);
var requestAction = await sutProvider.Sut.RedirectAsync(organizationId);
var redirectResult = Assert.IsType<RedirectResult>(requestAction);
Assert.Equal(expectedUrl, redirectResult.Url);
var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
Assert.IsType<RedirectResult>(requestAction);
sutProvider.GetDependency<ISlackService>().Received(1).GetRedirectUrl(Arg.Any<string>(), expectedState.ToString());
}
[Theory, BitAutoData]
public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound(SutProvider<SlackIntegrationController> sutProvider, Guid organizationId)
public async Task RedirectAsync_IntegrationAlreadyExistsWithConfig_ThrowsBadRequest(
SutProvider<SlackIntegrationController> sutProvider,
Guid organizationId,
OrganizationIntegration integration)
{
integration.OrganizationId = organizationId;
integration.Configuration = "{}";
integration.Type = IntegrationType.Slack;
var expectedUrl = "https://localhost/";
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ISlackService>().GetRedirectUrl(Arg.Any<string>()).Returns(string.Empty);
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns(expectedUrl);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetManyByOrganizationAsync(organizationId)
.Returns([integration]);
sutProvider.GetDependency<ISlackService>().GetRedirectUrl(Arg.Any<string>(), Arg.Any<string>()).Returns(expectedUrl);
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.RedirectAsync(organizationId));
}
[Theory, BitAutoData]
public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider,
Guid organizationId,
OrganizationIntegration integration)
{
integration.OrganizationId = organizationId;
integration.Configuration = null;
var expectedUrl = "https://localhost/";
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync)))
.Returns(expectedUrl);
sutProvider.GetDependency<ICurrentContext>()
.HttpContext.Request.Scheme
.Returns("https");
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetManyByOrganizationAsync(organizationId)
.Returns([]);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.CreateAsync(Arg.Any<OrganizationIntegration>())
.Returns(integration);
sutProvider.GetDependency<ISlackService>().GetRedirectUrl(Arg.Any<string>(), Arg.Any<string>()).Returns(string.Empty);
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.RedirectAsync(organizationId));
}
@@ -116,14 +336,9 @@ public class SlackIntegrationControllerTests
public async Task RedirectAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider<SlackIntegrationController> sutProvider,
Guid organizationId)
{
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ISlackService>().GetRedirectUrl(Arg.Any<string>()).Returns(string.Empty);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(false);
sutProvider.GetDependency<ICurrentContext>()
.HttpContext.Request.Scheme
.Returns("https");
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.RedirectAsync(organizationId));
}

View File

@@ -0,0 +1,117 @@
#nullable enable
using Bit.Api.AdminConsole.Models.Response.Organizations;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Enums;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Models.Response.Organizations;
public class OrganizationIntegrationResponseModelTests
{
[Theory, BitAutoData]
public void Status_CloudBillingSync_AlwaysNotApplicable(OrganizationIntegration oi)
{
oi.Type = IntegrationType.CloudBillingSync;
oi.Configuration = null;
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status);
model.Configuration = "{}";
Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status);
}
[Theory, BitAutoData]
public void Status_Scim_AlwaysNotApplicable(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Scim;
oi.Configuration = null;
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status);
model.Configuration = "{}";
Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status);
}
[Theory, BitAutoData]
public void Status_Slack_NullConfig_ReturnsInitiated(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Slack;
oi.Configuration = null;
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.Initiated, model.Status);
}
[Theory, BitAutoData]
public void Status_Slack_WithConfig_ReturnsCompleted(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Slack;
oi.Configuration = "{}";
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status);
}
[Theory, BitAutoData]
public void Status_Webhook_AlwaysCompleted(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Webhook;
oi.Configuration = null;
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status);
model.Configuration = "{}";
Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status);
}
[Theory, BitAutoData]
public void Status_Hec_NullConfig_ReturnsInvalid(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Hec;
oi.Configuration = null;
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.Invalid, model.Status);
}
[Theory, BitAutoData]
public void Status_Hec_WithConfig_ReturnsCompleted(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Hec;
oi.Configuration = "{}";
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status);
}
[Theory, BitAutoData]
public void Status_Datadog_NullConfig_ReturnsInvalid(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Datadog;
oi.Configuration = null;
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.Invalid, model.Status);
}
[Theory, BitAutoData]
public void Status_Datadog_WithConfig_ReturnsCompleted(OrganizationIntegration oi)
{
oi.Type = IntegrationType.Datadog;
oi.Configuration = "{}";
var model = new OrganizationIntegrationResponseModel(oi);
Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status);
}
}

View File

@@ -0,0 +1,313 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Auth.Models.Request.Organizations;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Services;
using Bit.Core.Sso;
using Microsoft.Extensions.Localization;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.Auth.Models.Request;
public class OrganizationSsoRequestModelTests
{
[Fact]
public void ToSsoConfig_WithOrganizationId_CreatesNewSsoConfig()
{
// Arrange
var organizationId = Guid.NewGuid();
var model = new OrganizationSsoRequestModel
{
Enabled = true,
Identifier = "test-identifier",
Data = new SsoConfigurationDataRequest
{
ConfigType = SsoType.OpenIdConnect,
Authority = "https://example.com",
ClientId = "test-client",
ClientSecret = "test-secret"
}
};
// Act
var result = model.ToSsoConfig(organizationId);
// Assert
Assert.NotNull(result);
Assert.Equal(organizationId, result.OrganizationId);
Assert.True(result.Enabled);
}
[Fact]
public void ToSsoConfig_WithExistingConfig_UpdatesExistingConfig()
{
// Arrange
var organizationId = Guid.NewGuid();
var existingConfig = new SsoConfig
{
Id = 1,
OrganizationId = organizationId,
Enabled = false
};
var model = new OrganizationSsoRequestModel
{
Enabled = true,
Identifier = "updated-identifier",
Data = new SsoConfigurationDataRequest
{
ConfigType = SsoType.Saml2,
IdpEntityId = "test-entity",
IdpSingleSignOnServiceUrl = "https://sso.example.com"
}
};
// Act
var result = model.ToSsoConfig(existingConfig);
// Assert
Assert.Same(existingConfig, result);
Assert.Equal(organizationId, result.OrganizationId);
Assert.True(result.Enabled);
}
}
public class SsoConfigurationDataRequestTests
{
private readonly TestI18nService _i18nService;
private readonly ValidationContext _validationContext;
public SsoConfigurationDataRequestTests()
{
_i18nService = new TestI18nService();
var serviceProvider = Substitute.For<IServiceProvider>();
serviceProvider.GetService(typeof(II18nService)).Returns(_i18nService);
_validationContext = new ValidationContext(new object(), serviceProvider, null);
}
[Fact]
public void ToConfigurationData_MapsProperties()
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.OpenIdConnect,
MemberDecryptionType = MemberDecryptionType.KeyConnector,
Authority = "https://authority.example.com",
ClientId = "test-client-id",
ClientSecret = "test-client-secret",
IdpX509PublicCert = "-----BEGIN CERTIFICATE-----\nMIIC...test\n-----END CERTIFICATE-----",
SpOutboundSigningAlgorithm = null // Test default
};
// Act
var result = model.ToConfigurationData();
// Assert
Assert.Equal(SsoType.OpenIdConnect, result.ConfigType);
Assert.Equal(MemberDecryptionType.KeyConnector, result.MemberDecryptionType);
Assert.Equal("https://authority.example.com", result.Authority);
Assert.Equal("test-client-id", result.ClientId);
Assert.Equal("test-client-secret", result.ClientSecret);
Assert.Equal("MIIC...test", result.IdpX509PublicCert); // PEM headers stripped
Assert.Equal(SamlSigningAlgorithms.Sha256, result.SpOutboundSigningAlgorithm); // Default applied
Assert.Null(result.IdpArtifactResolutionServiceUrl); // Always null
}
[Fact]
public void KeyConnectorEnabled_Setter_UpdatesMemberDecryptionType()
{
// Arrange
var model = new SsoConfigurationDataRequest();
// Act & Assert
#pragma warning disable CS0618 // Type or member is obsolete
model.KeyConnectorEnabled = true;
Assert.Equal(MemberDecryptionType.KeyConnector, model.MemberDecryptionType);
model.KeyConnectorEnabled = false;
Assert.Equal(MemberDecryptionType.MasterPassword, model.MemberDecryptionType);
#pragma warning restore CS0618 // Type or member is obsolete
}
// Validation Tests
[Fact]
public void Validate_OpenIdConnect_ValidData_NoErrors()
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.OpenIdConnect,
Authority = "https://example.com",
ClientId = "test-client",
ClientSecret = "test-secret"
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert
Assert.Empty(results);
}
[Theory]
[InlineData("", "test-client", "test-secret", "AuthorityValidationError")]
[InlineData("https://example.com", "", "test-secret", "ClientIdValidationError")]
[InlineData("https://example.com", "test-client", "", "ClientSecretValidationError")]
public void Validate_OpenIdConnect_MissingRequiredFields_ReturnsErrors(string authority, string clientId, string clientSecret, string expectedError)
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.OpenIdConnect,
Authority = authority,
ClientId = clientId,
ClientSecret = clientSecret
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert
Assert.Single(results);
Assert.Equal(expectedError, results[0].ErrorMessage);
}
[Fact]
public void Validate_Saml2_ValidData_NoErrors()
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.Saml2,
IdpEntityId = "https://idp.example.com",
IdpSingleSignOnServiceUrl = "https://sso.example.com",
IdpSingleLogoutServiceUrl = "https://logout.example.com"
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert
Assert.Empty(results);
}
[Theory]
[InlineData("", "https://sso.example.com", "IdpEntityIdValidationError")]
[InlineData("not-a-valid-uri", "", "IdpSingleSignOnServiceUrlValidationError")]
public void Validate_Saml2_MissingRequiredFields_ReturnsErrors(string entityId, string signOnUrl, string expectedError)
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.Saml2,
IdpEntityId = entityId,
IdpSingleSignOnServiceUrl = signOnUrl
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert
Assert.Contains(results, r => r.ErrorMessage == expectedError);
}
[Theory]
[InlineData("not-a-url")]
[InlineData("ftp://example.com")]
[InlineData("https://example.com<script>")]
[InlineData("https://example.com\"onclick")]
public void Validate_Saml2_InvalidUrls_ReturnsErrors(string invalidUrl)
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.Saml2,
IdpEntityId = "https://idp.example.com",
IdpSingleSignOnServiceUrl = invalidUrl,
IdpSingleLogoutServiceUrl = invalidUrl
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert
Assert.Contains(results, r => r.ErrorMessage == "IdpSingleSignOnServiceUrlInvalid");
Assert.Contains(results, r => r.ErrorMessage == "IdpSingleLogoutServiceUrlInvalid");
}
[Fact]
public void Validate_Saml2_MissingSignOnUrl_AlwaysReturnsError()
{
// Arrange - SignOnUrl is always required for SAML2, regardless of EntityId format
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.Saml2,
IdpEntityId = "https://idp.example.com", // Valid URI
IdpSingleSignOnServiceUrl = "" // Missing - always causes error
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert - Should always fail validation when SignOnUrl is missing
Assert.Contains(results, r => r.ErrorMessage == "IdpSingleSignOnServiceUrlValidationError");
}
[Fact]
public void Validate_Saml2_InvalidCertificate_ReturnsError()
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.Saml2,
IdpEntityId = "https://idp.example.com",
IdpSingleSignOnServiceUrl = "https://sso.example.com",
IdpX509PublicCert = "invalid-certificate-data"
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert
Assert.Contains(results, r => r.ErrorMessage.Contains("IdpX509PublicCert") && r.ErrorMessage.Contains("ValidationError"));
}
// TODO: On server, make public certificate required for SAML2 SSO: https://bitwarden.atlassian.net/browse/PM-26028
[Fact]
public void Validate_Saml2_EmptyCertificate_PassesValidation()
{
// Arrange
var model = new SsoConfigurationDataRequest
{
ConfigType = SsoType.Saml2,
IdpEntityId = "https://idp.example.com",
IdpSingleSignOnServiceUrl = "https://sso.example.com",
IdpX509PublicCert = ""
};
// Act
var results = model.Validate(_validationContext).ToList();
// Assert
Assert.DoesNotContain(results, r => r.MemberNames.Contains("IdpX509PublicCert"));
}
private class TestI18nService : I18nService
{
public TestI18nService() : base(CreateMockLocalizerFactory()) { }
private static IStringLocalizerFactory CreateMockLocalizerFactory()
{
var factory = Substitute.For<IStringLocalizerFactory>();
var localizer = Substitute.For<IStringLocalizer>();
localizer[Arg.Any<string>()].Returns(callInfo => new LocalizedString(callInfo.Arg<string>(), callInfo.Arg<string>()));
localizer[Arg.Any<string>(), Arg.Any<object[]>()].Returns(callInfo => new LocalizedString(callInfo.Arg<string>(), callInfo.Arg<string>()));
factory.Create(Arg.Any<string>(), Arg.Any<string>()).Returns(localizer);
return factory;
}
}
}

View File

@@ -361,7 +361,7 @@ public class ServiceAccountsControllerTests
[Theory]
[BitAutoData]
public async Task BulkDelete_ReturnsAccessDeniedForProjectsWithoutAccess_Success(SutProvider<ServiceAccountsController> sutProvider, List<ServiceAccount> data)
public async Task BulkDelete_ReturnsAccessDeniedForProjectsWithoutAccess_Success(SutProvider<ServiceAccountsController> sutProvider, List<ServiceAccount> data, Guid userId)
{
var ids = data.Select(sa => sa.Id).ToList();
var organizationId = data.First().OrganizationId;
@@ -377,6 +377,7 @@ public class ServiceAccountsControllerTests
Arg.Any<IEnumerable<IAuthorizationRequirement>>()).Returns(AuthorizationResult.Failed());
sutProvider.GetDependency<ICurrentContext>().AccessSecretsManager(Arg.Is(organizationId)).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IServiceAccountRepository>().GetManyByIds(Arg.Is(ids)).ReturnsForAnyArgs(data);
sutProvider.GetDependency<IUserService>().GetProperUserId(default).ReturnsForAnyArgs(userId);
var results = await sutProvider.Sut.BulkDeleteAsync(ids);
@@ -390,7 +391,7 @@ public class ServiceAccountsControllerTests
[Theory]
[BitAutoData]
public async Task BulkDelete_Success(SutProvider<ServiceAccountsController> sutProvider, List<ServiceAccount> data)
public async Task BulkDelete_Success(SutProvider<ServiceAccountsController> sutProvider, List<ServiceAccount> data, Guid userId)
{
var ids = data.Select(sa => sa.Id).ToList();
var organizationId = data.First().OrganizationId;
@@ -404,6 +405,7 @@ public class ServiceAccountsControllerTests
sutProvider.GetDependency<ICurrentContext>().AccessSecretsManager(Arg.Is(organizationId)).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IServiceAccountRepository>().GetManyByIds(Arg.Is(ids)).ReturnsForAnyArgs(data);
sutProvider.GetDependency<IUserService>().GetProperUserId(default).ReturnsForAnyArgs(userId);
var results = await sutProvider.Sut.BulkDeleteAsync(ids);

View File

@@ -0,0 +1,91 @@
#nullable enable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Time.Testing;
using Xunit;
namespace Bit.Core.Test.AdminConsole.Models.Data.EventIntegrations;
public class IntegrationOAuthStateTests
{
private readonly FakeTimeProvider _fakeTimeProvider = new(
new DateTime(2014, 3, 2, 1, 0, 0, DateTimeKind.Utc)
);
[Theory, BitAutoData]
public void FromIntegration_ToString_RoundTripsCorrectly(OrganizationIntegration integration)
{
var state = IntegrationOAuthState.FromIntegration(integration, _fakeTimeProvider);
var parsed = IntegrationOAuthState.FromString(state.ToString(), _fakeTimeProvider);
Assert.NotNull(parsed);
Assert.Equal(state.IntegrationId, parsed.IntegrationId);
Assert.True(parsed.ValidateOrg(integration.OrganizationId));
}
[Theory]
[InlineData("")]
[InlineData(" ")]
[InlineData("not-a-valid-state")]
public void FromString_InvalidString_ReturnsNull(string state)
{
var parsed = IntegrationOAuthState.FromString(state, _fakeTimeProvider);
Assert.Null(parsed);
}
[Fact]
public void FromString_InvalidGuid_ReturnsNull()
{
var badState = $"not-a-guid.ABCD1234.1706313600";
var parsed = IntegrationOAuthState.FromString(badState, _fakeTimeProvider);
Assert.Null(parsed);
}
[Theory, BitAutoData]
public void FromString_ExpiredState_ReturnsNull(OrganizationIntegration integration)
{
var state = IntegrationOAuthState.FromIntegration(integration, _fakeTimeProvider);
// Advance time 30 minutes to exceed the 20-minute max age
_fakeTimeProvider.Advance(TimeSpan.FromMinutes(30));
var parsed = IntegrationOAuthState.FromString(state.ToString(), _fakeTimeProvider);
Assert.Null(parsed);
}
[Theory, BitAutoData]
public void ValidateOrg_WithCorrectOrgId_ReturnsTrue(OrganizationIntegration integration)
{
var state = IntegrationOAuthState.FromIntegration(integration, _fakeTimeProvider);
Assert.True(state.ValidateOrg(integration.OrganizationId));
}
[Theory, BitAutoData]
public void ValidateOrg_WithWrongOrgId_ReturnsFalse(OrganizationIntegration integration)
{
var state = IntegrationOAuthState.FromIntegration(integration, _fakeTimeProvider);
Assert.False(state.ValidateOrg(Guid.NewGuid()));
}
[Theory, BitAutoData]
public void ValidateOrg_ModifiedTimestamp_ReturnsFalse(OrganizationIntegration integration)
{
var state = IntegrationOAuthState.FromIntegration(integration, _fakeTimeProvider);
var parts = state.ToString().Split('.');
parts[2] = $"{_fakeTimeProvider.GetUtcNow().ToUnixTimeSeconds() - 1}";
var modifiedState = IntegrationOAuthState.FromString(string.Join(".", parts), _fakeTimeProvider);
Assert.True(state.ValidateOrg(integration.OrganizationId));
Assert.NotNull(modifiedState);
Assert.False(modifiedState.ValidateOrg(integration.OrganizationId));
}
}

View File

@@ -13,4 +13,8 @@ public class TestListenerConfiguration : IIntegrationListenerConfiguration
public string IntegrationSubscriptionName => "integration_subscription";
public string IntegrationTopicName => "integration_topic";
public int MaxRetries => 3;
public int EventMaxConcurrentCalls => 1;
public int EventPrefetchCount => 0;
public int IntegrationMaxConcurrentCalls => 1;
public int IntegrationPrefetchCount => 0;
}

View File

@@ -2,6 +2,7 @@
using System.Net;
using System.Text.Json;
using System.Web;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -261,10 +262,19 @@ public class SlackServiceTests
var sutProvider = GetSutProvider();
var clientId = sutProvider.GetDependency<GlobalSettings>().Slack.ClientId;
var scopes = sutProvider.GetDependency<GlobalSettings>().Slack.Scopes;
var redirectUrl = "https://example.com/callback";
var expectedUrl = $"https://slack.com/oauth/v2/authorize?client_id={clientId}&scope={scopes}&redirect_uri={redirectUrl}";
var result = sutProvider.Sut.GetRedirectUrl(redirectUrl);
Assert.Equal(expectedUrl, result);
var callbackUrl = "https://example.com/callback";
var state = Guid.NewGuid().ToString();
var result = sutProvider.Sut.GetRedirectUrl(callbackUrl, state);
var uri = new Uri(result);
var query = HttpUtility.ParseQueryString(uri.Query);
Assert.Equal(clientId, query["client_id"]);
Assert.Equal(scopes, query["scope"]);
Assert.Equal(callbackUrl, query["redirect_uri"]);
Assert.Equal(state, query["state"]);
Assert.Equal("slack.com", uri.Host);
Assert.Equal("/oauth/v2/authorize", uri.AbsolutePath);
}
[Fact]

Some files were not shown because too many files have changed in this diff Show More