1
0
mirror of https://github.com/bitwarden/browser synced 2025-12-12 14:23:32 +00:00

feat(SSO): (Auth/[PM-22110] Remove Alternate Login Options when SSO Required (#16340)

If a user is part of an org that has the `RequireSso` policy, when that user successfully logs in we add their email to a local `ssoRequiredCache` on their device. The next time this user goes to the `/login` screen on this device, we will use that cache to determine that for this email we should only show the "Use single sign-on" button and disable the alternate login buttons.

These changes are behind the flag: `PM22110_DisableAlternateLoginMethods`
This commit is contained in:
rr-bw
2025-09-22 08:32:20 -07:00
committed by GitHub
parent b455cb5986
commit 3bbc6c564c
15 changed files with 539 additions and 19 deletions

View File

@@ -900,7 +900,11 @@ export default class MainBackground {
this.restrictedItemTypesService,
);
this.ssoLoginService = new SsoLoginService(this.stateProvider, this.logService);
this.ssoLoginService = new SsoLoginService(
this.stateProvider,
this.logService,
this.policyService,
);
this.userVerificationApiService = new UserVerificationApiService(this.apiService);

View File

@@ -1,5 +1,13 @@
import { Component, Inject, OnDestroy, OnInit } from "@angular/core";
import { combineLatest, map, Observable, Subject, switchMap, takeUntil } from "rxjs";
import {
combineLatest,
firstValueFrom,
map,
Observable,
Subject,
switchMap,
takeUntil,
} from "rxjs";
import {
OrganizationUserApiService,
@@ -14,8 +22,10 @@ import { PolicyType } from "@bitwarden/common/admin-console/enums";
import { Organization } from "@bitwarden/common/admin-console/models/domain/organization";
import { Policy } from "@bitwarden/common/admin-console/models/domain/policy";
import { AccountService } from "@bitwarden/common/auth/abstractions/account.service";
import { SsoLoginServiceAbstraction } from "@bitwarden/common/auth/abstractions/sso-login.service.abstraction";
import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction";
import { getUserId } from "@bitwarden/common/auth/services/account.service";
import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum";
import { ConfigService } from "@bitwarden/common/platform/abstractions/config/config.service";
import { I18nService } from "@bitwarden/common/platform/abstractions/i18n.service";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
@@ -65,6 +75,7 @@ export class OrganizationOptionsComponent implements OnInit, OnDestroy {
private keyService: KeyService,
private accountService: AccountService,
private linkSsoService: LinkSsoService,
private ssoLoginService: SsoLoginServiceAbstraction,
) {}
async ngOnInit() {
@@ -167,6 +178,14 @@ export class OrganizationOptionsComponent implements OnInit, OnDestroy {
title: "",
message: this.i18nService.t("unlinkedSso"),
});
const disableAlternateLoginMethodsFlagEnabled = await this.configService.getFeatureFlag(
FeatureFlag.PM22110_DisableAlternateLoginMethods,
);
if (disableAlternateLoginMethodsFlagEnabled) {
await this.removeEmailFromSsoRequiredCacheIfPresent();
}
} catch (e) {
this.logService.error(e);
}
@@ -186,16 +205,36 @@ export class OrganizationOptionsComponent implements OnInit, OnDestroy {
try {
this.actionPromise = this.organizationApiService.leave(org.id);
await this.actionPromise;
this.toastService.showToast({
variant: "success",
title: "",
message: this.i18nService.t("leftOrganization"),
});
const disableAlternateLoginMethodsFlagEnabled = await this.configService.getFeatureFlag(
FeatureFlag.PM22110_DisableAlternateLoginMethods,
);
if (disableAlternateLoginMethodsFlagEnabled) {
await this.removeEmailFromSsoRequiredCacheIfPresent();
}
} catch (e) {
this.logService.error(e);
}
}
private async removeEmailFromSsoRequiredCacheIfPresent() {
const activeAccount = await firstValueFrom(this.accountService.activeAccount$);
if (!activeAccount) {
this.logService.error("Active account not found.");
return;
}
await this.ssoLoginService.removeFromSsoRequiredCacheIfPresent(activeAccount.email);
}
async toggleResetPasswordEnrollment(org: Organization) {
if (!this.organization.resetPasswordEnrolled) {
await EnrollMasterPasswordReset.open(

View File

@@ -6725,6 +6725,15 @@
"disabledSso": {
"message": "SSO turned on"
},
"emailMustLoginWithSso": {
"message": "$EMAIL$ must login with Single Sign-on",
"placeholders": {
"email": {
"content": "$1",
"example": "name@example.com"
}
}
},
"enabledKeyConnector": {
"message": "Key Connector activated"
},

View File

@@ -877,7 +877,7 @@ const safeProviders: SafeProvider[] = [
safeProvider({
provide: SsoLoginServiceAbstraction,
useClass: SsoLoginService,
deps: [StateProvider, LogService],
deps: [StateProvider, LogService, PolicyServiceAbstraction],
}),
safeProvider({
provide: StateServiceAbstraction,
@@ -1561,7 +1561,14 @@ const safeProviders: SafeProvider[] = [
safeProvider({
provide: LoginSuccessHandlerService,
useClass: DefaultLoginSuccessHandlerService,
deps: [SyncService, UserAsymmetricKeysRegenerationService, LoginEmailService],
deps: [
ConfigService,
LoginEmailService,
SsoLoginServiceAbstraction,
SyncService,
UserAsymmetricKeysRegenerationService,
LogService,
],
}),
safeProvider({
provide: TaskService,

View File

@@ -38,7 +38,14 @@
<div class="tw-grid tw-gap-3">
<!-- Continue button -->
<button type="button" bitButton block buttonType="primary" (click)="continuePressed()">
<button
type="button"
bitButton
block
buttonType="primary"
(click)="continuePressed()"
[disabled]="ssoRequired"
>
{{ "continue" | i18n }}
</button>
@@ -52,6 +59,7 @@
block
buttonType="secondary"
(click)="handleLoginWithPasskeyClick()"
[disabled]="ssoRequired"
>
<i class="bwi bwi-passkey tw-mr-1" aria-hidden="true"></i>
{{ "logInWithPasskey" | i18n }}

View File

@@ -1,5 +1,14 @@
import { CommonModule } from "@angular/common";
import { Component, ElementRef, NgZone, OnDestroy, OnInit, ViewChild } from "@angular/core";
import {
Component,
DestroyRef,
ElementRef,
NgZone,
OnDestroy,
OnInit,
ViewChild,
} from "@angular/core";
import { takeUntilDestroyed } from "@angular/core/rxjs-interop";
import { FormBuilder, FormControl, ReactiveFormsModule, Validators } from "@angular/forms";
import { ActivatedRoute, Router, RouterModule } from "@angular/router";
import { firstValueFrom, Subject, take, takeUntil } from "rxjs";
@@ -17,9 +26,10 @@ import { PolicyData } from "@bitwarden/common/admin-console/models/data/policy.d
import { MasterPasswordPolicyOptions } from "@bitwarden/common/admin-console/models/domain/master-password-policy-options";
import { Policy } from "@bitwarden/common/admin-console/models/domain/policy";
import { DevicesApiServiceAbstraction } from "@bitwarden/common/auth/abstractions/devices-api.service.abstraction";
import { SsoLoginServiceAbstraction } from "@bitwarden/common/auth/abstractions/sso-login.service.abstraction";
import { AuthResult } from "@bitwarden/common/auth/models/domain/auth-result";
import { ClientType, HttpStatusCode } from "@bitwarden/common/enums";
import { MasterPasswordServiceAbstraction } from "@bitwarden/common/key-management/master-password/abstractions/master-password.service.abstraction";
import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum";
import { ErrorResponse } from "@bitwarden/common/models/response/error.response";
import { AppIdService } from "@bitwarden/common/platform/abstractions/app-id.service";
import { BroadcasterService } from "@bitwarden/common/platform/abstractions/broadcaster.service";
@@ -83,6 +93,7 @@ export class LoginComponent implements OnInit, OnDestroy {
LoginUiState = LoginUiState;
isKnownDevice = false;
loginUiState: LoginUiState = LoginUiState.EMAIL_ENTRY;
ssoRequired = false;
formGroup = this.formBuilder.group(
{
@@ -108,6 +119,7 @@ export class LoginComponent implements OnInit, OnDestroy {
private anonLayoutWrapperDataService: AnonLayoutWrapperDataService,
private appIdService: AppIdService,
private broadcasterService: BroadcasterService,
private destroyRef: DestroyRef,
private devicesApiService: DevicesApiServiceAbstraction,
private formBuilder: FormBuilder,
private i18nService: I18nService,
@@ -124,8 +136,8 @@ export class LoginComponent implements OnInit, OnDestroy {
private logService: LogService,
private validationService: ValidationService,
private loginSuccessHandlerService: LoginSuccessHandlerService,
private masterPasswordService: MasterPasswordServiceAbstraction,
private configService: ConfigService,
private ssoLoginService: SsoLoginServiceAbstraction,
) {
this.clientType = this.platformUtilsService.getClientType();
}
@@ -184,6 +196,15 @@ export class LoginComponent implements OnInit, OnDestroy {
if (!this.activatedRoute) {
await this.loadRememberedEmail();
}
const disableAlternateLoginMethodsFlagEnabled = await this.configService.getFeatureFlag(
FeatureFlag.PM22110_DisableAlternateLoginMethods,
);
if (disableAlternateLoginMethodsFlagEnabled) {
// This SSO required check should come after email has had a chance to be pre-filled (if it
// was found in query params or was the remembered email)
await this.determineIfSsoRequired();
}
}
private async desktopOnInit(): Promise<void> {
@@ -210,6 +231,40 @@ export class LoginComponent implements OnInit, OnDestroy {
this.messagingService.send("getWindowIsFocused");
}
private async determineIfSsoRequired() {
const ssoRequiredCache = await firstValueFrom(this.ssoLoginService.ssoRequiredCache$);
// Only perform initial update and setup a subscription if there is actually a populated ssoRequiredCache
if (ssoRequiredCache != null && ssoRequiredCache.size > 0) {
// If the pre-filled/remembered email field value exists in the cache, set to true
if (
this.emailFormControl.value &&
ssoRequiredCache.has(this.emailFormControl.value.toLowerCase())
) {
this.ssoRequired = true;
}
this.listenForEmailChanges(ssoRequiredCache);
}
}
private listenForEmailChanges(ssoRequiredCache: Set<string>) {
// On subsequent email field value changes, check and set again. This allows alternate login buttons
// to dynamically enable/disable depending on whether or not the entered email is in the ssoRequiredCache
this.formGroup.controls.email.valueChanges
.pipe(takeUntilDestroyed(this.destroyRef))
.subscribe(() => {
if (
this.emailFormControl.value &&
ssoRequiredCache.has(this.emailFormControl.value.toLowerCase())
) {
this.ssoRequired = true;
} else {
this.ssoRequired = false;
}
});
}
submit = async (): Promise<void> => {
if (this.clientType === ClientType.Desktop) {
if (this.loginUiState !== LoginUiState.MASTER_PASSWORD_ENTRY) {

View File

@@ -8,7 +8,6 @@
<bit-label>{{ "ssoIdentifier" | i18n }}</bit-label>
<input bitInput type="text" formControlName="identifier" appAutofocus />
</bit-form-field>
<hr />
<div class="tw-flex tw-gap-2">
<button type="submit" bitButton bitFormButton buttonType="primary" [block]="true">
{{ "continue" | i18n }}

View File

@@ -290,6 +290,7 @@ export class SsoComponent implements OnInit {
this.identifier = this.identifierFormControl.value ?? "";
await this.ssoLoginService.setOrganizationSsoIdentifier(this.identifier);
this.ssoComponentService.setDocumentCookies?.();
try {
await this.submitSso();
} catch (error) {

View File

@@ -0,0 +1,120 @@
import { MockProxy, mock } from "jest-mock-extended";
import { SsoLoginServiceAbstraction } from "@bitwarden/common/auth/abstractions/sso-login.service.abstraction";
import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum";
import { ConfigService } from "@bitwarden/common/platform/abstractions/config/config.service";
import { SyncService } from "@bitwarden/common/platform/sync";
import { UserId } from "@bitwarden/common/types/guid";
import { UserAsymmetricKeysRegenerationService } from "@bitwarden/key-management";
import { LogService } from "@bitwarden/logging";
import { LoginEmailService } from "../login-email/login-email.service";
import { DefaultLoginSuccessHandlerService } from "./default-login-success-handler.service";
describe("DefaultLoginSuccessHandlerService", () => {
let service: DefaultLoginSuccessHandlerService;
let configService: MockProxy<ConfigService>;
let loginEmailService: MockProxy<LoginEmailService>;
let ssoLoginService: MockProxy<SsoLoginServiceAbstraction>;
let syncService: MockProxy<SyncService>;
let userAsymmetricKeysRegenerationService: MockProxy<UserAsymmetricKeysRegenerationService>;
let logService: MockProxy<LogService>;
const userId = "USER_ID" as UserId;
const testEmail = "test@bitwarden.com";
beforeEach(() => {
configService = mock<ConfigService>();
loginEmailService = mock<LoginEmailService>();
ssoLoginService = mock<SsoLoginServiceAbstraction>();
syncService = mock<SyncService>();
userAsymmetricKeysRegenerationService = mock<UserAsymmetricKeysRegenerationService>();
logService = mock<LogService>();
service = new DefaultLoginSuccessHandlerService(
configService,
loginEmailService,
ssoLoginService,
syncService,
userAsymmetricKeysRegenerationService,
logService,
);
syncService.fullSync.mockResolvedValue(true);
});
afterEach(() => {
jest.clearAllMocks();
});
describe("run", () => {
it("should call required services on successful login", async () => {
await service.run(userId);
expect(syncService.fullSync).toHaveBeenCalledWith(true, { skipTokenRefresh: true });
expect(userAsymmetricKeysRegenerationService.regenerateIfNeeded).toHaveBeenCalledWith(userId);
expect(loginEmailService.clearLoginEmail).toHaveBeenCalled();
});
describe("when PM22110_DisableAlternateLoginMethods flag is disabled", () => {
beforeEach(() => {
configService.getFeatureFlag.mockResolvedValue(false);
});
it("should not check SSO requirements", async () => {
await service.run(userId);
expect(ssoLoginService.getSsoEmail).not.toHaveBeenCalled();
expect(ssoLoginService.updateSsoRequiredCache).not.toHaveBeenCalled();
});
});
describe("given PM22110_DisableAlternateLoginMethods flag is enabled", () => {
beforeEach(() => {
configService.getFeatureFlag.mockResolvedValue(true);
});
it("should check feature flag", async () => {
await service.run(userId);
expect(configService.getFeatureFlag).toHaveBeenCalledWith(
FeatureFlag.PM22110_DisableAlternateLoginMethods,
);
});
it("should get SSO email", async () => {
await service.run(userId);
expect(ssoLoginService.getSsoEmail).toHaveBeenCalled();
});
describe("given SSO email is not found", () => {
beforeEach(() => {
ssoLoginService.getSsoEmail.mockResolvedValue(null);
});
it("should log error and return early", async () => {
await service.run(userId);
expect(logService.error).toHaveBeenCalledWith("SSO login email not found.");
expect(ssoLoginService.updateSsoRequiredCache).not.toHaveBeenCalled();
});
});
describe("given SSO email is found", () => {
beforeEach(() => {
ssoLoginService.getSsoEmail.mockResolvedValue(testEmail);
});
it("should call updateSsoRequiredCache() and clearSsoEmail()", async () => {
await service.run(userId);
expect(ssoLoginService.updateSsoRequiredCache).toHaveBeenCalledWith(testEmail, userId);
expect(ssoLoginService.clearSsoEmail).toHaveBeenCalled();
});
});
});
});
});

View File

@@ -1,19 +1,42 @@
import { SsoLoginServiceAbstraction } from "@bitwarden/common/auth/abstractions/sso-login.service.abstraction";
import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum";
import { ConfigService } from "@bitwarden/common/platform/abstractions/config/config.service";
import { SyncService } from "@bitwarden/common/platform/sync";
import { UserId } from "@bitwarden/common/types/guid";
import { UserAsymmetricKeysRegenerationService } from "@bitwarden/key-management";
import { LogService } from "@bitwarden/logging";
import { LoginSuccessHandlerService } from "../../abstractions/login-success-handler.service";
import { LoginEmailService } from "../login-email/login-email.service";
export class DefaultLoginSuccessHandlerService implements LoginSuccessHandlerService {
constructor(
private configService: ConfigService,
private loginEmailService: LoginEmailService,
private ssoLoginService: SsoLoginServiceAbstraction,
private syncService: SyncService,
private userAsymmetricKeysRegenerationService: UserAsymmetricKeysRegenerationService,
private loginEmailService: LoginEmailService,
private logService: LogService,
) {}
async run(userId: UserId): Promise<void> {
await this.syncService.fullSync(true, { skipTokenRefresh: true });
await this.userAsymmetricKeysRegenerationService.regenerateIfNeeded(userId);
await this.loginEmailService.clearLoginEmail();
const disableAlternateLoginMethodsFlagEnabled = await this.configService.getFeatureFlag(
FeatureFlag.PM22110_DisableAlternateLoginMethods,
);
if (disableAlternateLoginMethodsFlagEnabled) {
const ssoLoginEmail = await this.ssoLoginService.getSsoEmail();
if (!ssoLoginEmail) {
this.logService.error("SSO login email not found.");
return;
}
await this.ssoLoginService.updateSsoRequiredCache(ssoLoginEmail, userId);
await this.ssoLoginService.clearSsoEmail();
}
}
}

View File

@@ -1,3 +1,5 @@
import { Observable } from "rxjs";
import { UserId } from "@bitwarden/common/types/guid";
export abstract class SsoLoginServiceAbstraction {
@@ -70,6 +72,10 @@ export abstract class SsoLoginServiceAbstraction {
*
*/
abstract setSsoEmail: (email: string) => Promise<void>;
/**
* Clear the SSO email
*/
abstract clearSsoEmail: () => Promise<void>;
/**
* Gets the value of the active user's organization sso identifier.
*
@@ -86,4 +92,24 @@ export abstract class SsoLoginServiceAbstraction {
organizationIdentifier: string,
userId: UserId | undefined,
) => Promise<void>;
/**
* A cache list of user emails for whom the `PolicyType.RequireSso` policy is applied (that is, a list
* of users who are required to authenticate via SSO only). The cache lives on the current device only.
*/
abstract ssoRequiredCache$: Observable<Set<string> | null>;
/**
* Remove an email from the cached list of emails that must authenticate via SSO.
*/
abstract removeFromSsoRequiredCacheIfPresent: (email: string) => Promise<void>;
/**
* Check if the user is required to authenticate via SSO. If so, add their email to a cache list.
* We'll use this cache list to display ONLY the "Use single sign-on" button to the
* user the next time they are on the /login page.
*
* If the user is not required to authenticate via SSO, remove their email from the cache list if it is present.
*/
abstract updateSsoRequiredCache: (ssoLoginEmail: string, userId: UserId) => Promise<void>;
}

View File

@@ -1,9 +1,13 @@
import { mock, MockProxy } from "jest-mock-extended";
import { of } from "rxjs";
import { PolicyService } from "@bitwarden/common/admin-console/abstractions/policy/policy.service.abstraction";
import { PolicyType } from "@bitwarden/common/admin-console/enums";
import {
CODE_VERIFIER,
GLOBAL_ORGANIZATION_SSO_IDENTIFIER,
SSO_EMAIL,
SSO_REQUIRED_CACHE,
SSO_STATE,
SsoLoginService,
USER_ORGANIZATION_SSO_IDENTIFIER,
@@ -18,8 +22,9 @@ describe("SSOLoginService ", () => {
let sut: SsoLoginService;
let accountService: FakeAccountService;
let mockSingleUserStateProvider: FakeStateProvider;
let mockStateProvider: FakeStateProvider;
let mockLogService: MockProxy<LogService>;
let mockPolicyService: MockProxy<PolicyService>;
let userId: UserId;
beforeEach(() => {
@@ -27,10 +32,11 @@ describe("SSOLoginService ", () => {
userId = Utils.newGuid() as UserId;
accountService = mockAccountServiceWith(userId);
mockSingleUserStateProvider = new FakeStateProvider(accountService);
mockStateProvider = new FakeStateProvider(accountService);
mockLogService = mock<LogService>();
mockPolicyService = mock<PolicyService>();
sut = new SsoLoginService(mockSingleUserStateProvider, mockLogService);
sut = new SsoLoginService(mockStateProvider, mockLogService, mockPolicyService);
});
it("instantiates", () => {
@@ -40,7 +46,7 @@ describe("SSOLoginService ", () => {
it("gets and sets code verifier", async () => {
const codeVerifier = "test-code-verifier";
await sut.setCodeVerifier(codeVerifier);
mockSingleUserStateProvider.getGlobal(CODE_VERIFIER);
mockStateProvider.getGlobal(CODE_VERIFIER);
const result = await sut.getCodeVerifier();
expect(result).toBe(codeVerifier);
@@ -49,7 +55,7 @@ describe("SSOLoginService ", () => {
it("gets and sets SSO state", async () => {
const ssoState = "test-sso-state";
await sut.setSsoState(ssoState);
mockSingleUserStateProvider.getGlobal(SSO_STATE);
mockStateProvider.getGlobal(SSO_STATE);
const result = await sut.getSsoState();
expect(result).toBe(ssoState);
@@ -58,7 +64,7 @@ describe("SSOLoginService ", () => {
it("gets and sets organization SSO identifier", async () => {
const orgIdentifier = "test-org-identifier";
await sut.setOrganizationSsoIdentifier(orgIdentifier);
mockSingleUserStateProvider.getGlobal(GLOBAL_ORGANIZATION_SSO_IDENTIFIER);
mockStateProvider.getGlobal(GLOBAL_ORGANIZATION_SSO_IDENTIFIER);
const result = await sut.getOrganizationSsoIdentifier();
expect(result).toBe(orgIdentifier);
@@ -67,7 +73,7 @@ describe("SSOLoginService ", () => {
it("gets and sets SSO email", async () => {
const email = "test@example.com";
await sut.setSsoEmail(email);
mockSingleUserStateProvider.getGlobal(SSO_EMAIL);
mockStateProvider.getGlobal(SSO_EMAIL);
const result = await sut.getSsoEmail();
expect(result).toBe(email);
@@ -77,7 +83,7 @@ describe("SSOLoginService ", () => {
const userId = Utils.newGuid() as UserId;
const orgIdentifier = "test-active-org-identifier";
await sut.setActiveUserOrganizationSsoIdentifier(orgIdentifier, userId);
mockSingleUserStateProvider.getUser(userId, USER_ORGANIZATION_SSO_IDENTIFIER);
mockStateProvider.getUser(userId, USER_ORGANIZATION_SSO_IDENTIFIER);
const result = await sut.getActiveUserOrganizationSsoIdentifier(userId);
expect(result).toBe(orgIdentifier);
@@ -91,4 +97,153 @@ describe("SSOLoginService ", () => {
"Tried to set a user organization sso identifier with an undefined user id.",
);
});
describe("updateSsoRequiredCache()", () => {
it("should add email to cache when SSO is required", async () => {
const email = "test@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next([]);
mockStateProvider.global.getFake(SSO_EMAIL).stateSubject.next(email);
mockPolicyService.policyAppliesToUser$.mockReturnValue(of(true));
await sut.updateSsoRequiredCache(email, userId);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).toHaveBeenCalledWith([email.toLowerCase()]);
});
it("should add email to existing cache when SSO is required and email is not already present", async () => {
const existingEmail = "existing@example.com";
const newEmail = "new@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next([existingEmail]);
mockStateProvider.global.getFake(SSO_EMAIL).stateSubject.next(newEmail);
mockPolicyService.policyAppliesToUser$.mockReturnValue(of(true));
await sut.updateSsoRequiredCache(newEmail, userId);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).toHaveBeenCalledWith([existingEmail, newEmail.toLowerCase()]);
});
it("should not add duplicate email to cache when SSO is required", async () => {
const duplicateEmail = "duplicate@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next([duplicateEmail]);
mockStateProvider.global.getFake(SSO_EMAIL).stateSubject.next(duplicateEmail);
mockPolicyService.policyAppliesToUser$.mockReturnValue(of(true));
await sut.updateSsoRequiredCache(duplicateEmail, userId);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).not.toHaveBeenCalled();
});
it("should initialize new cache with email when SSO is required and no cache exists", async () => {
const email = "test@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next(null);
mockStateProvider.global.getFake(SSO_EMAIL).stateSubject.next(email);
mockPolicyService.policyAppliesToUser$.mockReturnValue(of(true));
await sut.updateSsoRequiredCache(email, userId);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).toHaveBeenCalledWith([email.toLowerCase()]);
});
it("should remove email from cache when SSO is not required", async () => {
const emailToRemove = "remove@example.com";
const remainingEmail = "keep@example.com";
mockStateProvider.global
.getFake(SSO_REQUIRED_CACHE)
.stateSubject.next([emailToRemove, remainingEmail]);
mockStateProvider.global.getFake(SSO_EMAIL).stateSubject.next(emailToRemove);
mockPolicyService.policyAppliesToUser$.mockReturnValue(of(false));
await sut.updateSsoRequiredCache(emailToRemove, userId);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).toHaveBeenCalledWith([remainingEmail]);
});
it("should not update cache when SSO is not required and email is not present", async () => {
const existingEmail = "existing@example.com";
const nonExistentEmail = "nonexistent@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next([existingEmail]);
mockStateProvider.global.getFake(SSO_EMAIL).stateSubject.next(nonExistentEmail);
mockPolicyService.policyAppliesToUser$.mockReturnValue(of(false));
await sut.updateSsoRequiredCache(nonExistentEmail, userId);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).not.toHaveBeenCalled();
});
it("should check policy for correct PolicyType and userId", async () => {
const email = "test@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next([]);
mockPolicyService.policyAppliesToUser$.mockReturnValue(of(true));
await sut.updateSsoRequiredCache(email, userId);
expect(mockPolicyService.policyAppliesToUser$).toHaveBeenCalledWith(
PolicyType.RequireSso,
userId,
);
});
});
describe("removeFromSsoRequiredCacheIfPresent()", () => {
it("should remove email from cache when present", async () => {
const emailToRemove = "remove@example.com";
const remainingEmail = "keep@example.com";
mockStateProvider.global
.getFake(SSO_REQUIRED_CACHE)
.stateSubject.next([emailToRemove, remainingEmail]);
await sut.removeFromSsoRequiredCacheIfPresent(emailToRemove);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).toHaveBeenCalledWith([remainingEmail]);
});
it("should not update cache when email is not present", async () => {
const existingEmail = "existing@example.com";
const nonExistentEmail = "nonexistent@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next([existingEmail]);
await sut.removeFromSsoRequiredCacheIfPresent(nonExistentEmail);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).not.toHaveBeenCalled();
});
it("should not update cache when cache is already null", async () => {
const email = "test@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next(null);
await sut.removeFromSsoRequiredCacheIfPresent(email);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).not.toHaveBeenCalled();
});
it("should result in an empty array when removing last email", async () => {
const email = "test@example.com";
mockStateProvider.global.getFake(SSO_REQUIRED_CACHE).stateSubject.next([email]);
await sut.removeFromSsoRequiredCacheIfPresent(email);
const cacheState = mockStateProvider.global.getFake(SSO_REQUIRED_CACHE);
expect(cacheState.nextMock).toHaveBeenCalledWith([]);
});
});
});

View File

@@ -1,5 +1,7 @@
import { firstValueFrom } from "rxjs";
import { firstValueFrom, map, Observable } from "rxjs";
import { PolicyService } from "@bitwarden/common/admin-console/abstractions/policy/policy.service.abstraction";
import { PolicyType } from "@bitwarden/common/admin-console/enums";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
import { UserId } from "@bitwarden/common/types/guid";
@@ -8,6 +10,7 @@ import {
KeyDefinition,
SingleUserState,
SSO_DISK,
SSO_DISK_LOCAL,
StateProvider,
UserKeyDefinition,
} from "../../platform/state";
@@ -57,20 +60,35 @@ export const SSO_EMAIL = new KeyDefinition<string>(SSO_DISK, "ssoEmail", {
deserializer: (state) => state,
});
/**
* A cache list of user emails for whom the `PolicyType.RequireSso` policy is applied (that is, a list
* of users who are required to authenticate via SSO only). The cache lives on the current device only.
*/
export const SSO_REQUIRED_CACHE = new KeyDefinition<string[]>(SSO_DISK_LOCAL, "ssoRequiredCache", {
deserializer: (ssoRequiredCache) => ssoRequiredCache,
});
export class SsoLoginService implements SsoLoginServiceAbstraction {
private codeVerifierState: GlobalState<string>;
private ssoState: GlobalState<string>;
private orgSsoIdentifierState: GlobalState<string>;
private ssoEmailState: GlobalState<string>;
private ssoRequiredCacheState: GlobalState<string[]>;
ssoRequiredCache$: Observable<Set<string> | null>;
constructor(
private stateProvider: StateProvider,
private logService: LogService,
private policyService: PolicyService,
) {
this.codeVerifierState = this.stateProvider.getGlobal(CODE_VERIFIER);
this.ssoState = this.stateProvider.getGlobal(SSO_STATE);
this.orgSsoIdentifierState = this.stateProvider.getGlobal(GLOBAL_ORGANIZATION_SSO_IDENTIFIER);
this.ssoEmailState = this.stateProvider.getGlobal(SSO_EMAIL);
this.ssoRequiredCacheState = this.stateProvider.getGlobal(SSO_REQUIRED_CACHE);
this.ssoRequiredCache$ = this.ssoRequiredCacheState.state$.pipe(map((cache) => new Set(cache)));
}
getCodeVerifier(): Promise<string | null> {
@@ -105,6 +123,10 @@ export class SsoLoginService implements SsoLoginServiceAbstraction {
await this.ssoEmailState.update((_) => email);
}
async clearSsoEmail(): Promise<void> {
await this.ssoEmailState.update((_) => null);
}
getActiveUserOrganizationSsoIdentifier(userId: UserId): Promise<string | null> {
return firstValueFrom(this.userOrgSsoIdentifierState(userId).state$);
}
@@ -125,4 +147,53 @@ export class SsoLoginService implements SsoLoginServiceAbstraction {
private userOrgSsoIdentifierState(userId: UserId): SingleUserState<string> {
return this.stateProvider.getUser(userId, USER_ORGANIZATION_SSO_IDENTIFIER);
}
/**
* Add an email to the cached list of emails that must authenticate via SSO.
*/
private async addToSsoRequiredCache(email: string): Promise<void> {
await this.ssoRequiredCacheState.update(
(cache) => (cache == null ? [email] : [...cache, email]),
{
shouldUpdate: (cache) => {
if (cache == null) {
return true;
}
return !cache.includes(email);
},
},
);
}
async removeFromSsoRequiredCacheIfPresent(email: string): Promise<void> {
await this.ssoRequiredCacheState.update(
(cache) => cache?.filter((cachedEmail) => cachedEmail !== email) ?? cache,
{
shouldUpdate: (cache) => {
if (cache == null) {
return false;
}
return cache.includes(email);
},
},
);
}
async updateSsoRequiredCache(ssoLoginEmail: string, userId: UserId): Promise<void> {
const ssoRequired = await firstValueFrom(
this.policyService.policyAppliesToUser$(PolicyType.RequireSso, userId),
);
if (ssoRequired) {
await this.addToSsoRequiredCache(ssoLoginEmail.toLowerCase());
} else {
/**
* If user is not required to authenticate via SSO, remove email from the cache
* list (if it was on the list). This is necessary because the user may have been
* required to authenticate via SSO at some point in the past, but now their org
* no longer requires SSO authenticaiton.
*/
await this.removeFromSsoRequiredCacheIfPresent(ssoLoginEmail.toLowerCase());
}
}
}

View File

@@ -16,6 +16,7 @@ export enum FeatureFlag {
/* Auth */
PM14938_BrowserExtensionLoginApproval = "pm-14938-browser-extension-login-approvals",
PM22110_DisableAlternateLoginMethods = "pm-22110-disable-alternate-login-methods",
/* Autofill */
MacOsNativeCredentialSync = "macos-native-credential-sync",
@@ -98,6 +99,7 @@ export const DefaultFeatureFlagValue = {
/* Auth */
[FeatureFlag.PM14938_BrowserExtensionLoginApproval]: FALSE,
[FeatureFlag.PM22110_DisableAlternateLoginMethods]: FALSE,
/* Billing */
[FeatureFlag.TrialPaymentOptional]: FALSE,

View File

@@ -66,6 +66,7 @@ export const PIN_DISK = new StateDefinition("pinUnlock", "disk");
export const PIN_MEMORY = new StateDefinition("pinUnlock", "memory");
export const ROUTER_DISK = new StateDefinition("router", "disk");
export const SSO_DISK = new StateDefinition("ssoLogin", "disk");
export const SSO_DISK_LOCAL = new StateDefinition("ssoLoginLocal", "disk", { web: "disk-local" });
export const TOKEN_DISK = new StateDefinition("token", "disk");
export const TOKEN_DISK_LOCAL = new StateDefinition("tokenDiskLocal", "disk", {
web: "disk-local",