From 6126368705a2d949e929b7f776e51d2c5aef90d4 Mon Sep 17 00:00:00 2001 From: Bernd Schoolmann Date: Wed, 26 Nov 2025 15:47:20 +0100 Subject: [PATCH] [PM-27835] Implement register SDK service (#17632) * Implement register SDK service * Relative import * Relative import * Rename to registerClient * Update libs/common/src/platform/abstractions/sdk/register-sdk.service.ts Co-authored-by: Derek Nance * Rename --------- Co-authored-by: Derek Nance --- .../src/services/jslib-services.module.ts | 15 ++ .../abstractions/sdk/register-sdk.service.ts | 56 +++++ .../platform/abstractions/sdk/sdk.service.ts | 60 +++++- .../services/sdk/default-sdk.service.ts | 68 +----- .../services/sdk/register-sdk.service.spec.ts | 170 +++++++++++++++ .../services/sdk/register-sdk.service.ts | 196 ++++++++++++++++++ 6 files changed, 503 insertions(+), 62 deletions(-) create mode 100644 libs/common/src/platform/abstractions/sdk/register-sdk.service.ts create mode 100644 libs/common/src/platform/services/sdk/register-sdk.service.spec.ts create mode 100644 libs/common/src/platform/services/sdk/register-sdk.service.ts diff --git a/libs/angular/src/services/jslib-services.module.ts b/libs/angular/src/services/jslib-services.module.ts index bcb601a993c..c8a70cf5af6 100644 --- a/libs/angular/src/services/jslib-services.module.ts +++ b/libs/angular/src/services/jslib-services.module.ts @@ -223,6 +223,7 @@ import { I18nService as I18nServiceAbstraction } from "@bitwarden/common/platfor import { LogService } from "@bitwarden/common/platform/abstractions/log.service"; import { MessagingService as MessagingServiceAbstraction } from "@bitwarden/common/platform/abstractions/messaging.service"; import { PlatformUtilsService as PlatformUtilsServiceAbstraction } from "@bitwarden/common/platform/abstractions/platform-utils.service"; +import { RegisterSdkService } from "@bitwarden/common/platform/abstractions/sdk/register-sdk.service"; import { SdkClientFactory } from "@bitwarden/common/platform/abstractions/sdk/sdk-client-factory"; import { SdkService } from "@bitwarden/common/platform/abstractions/sdk/sdk.service"; import { StateService as StateServiceAbstraction } from "@bitwarden/common/platform/abstractions/state.service"; @@ -261,6 +262,7 @@ import { FileUploadService } from "@bitwarden/common/platform/services/file-uplo import { MigrationBuilderService } from "@bitwarden/common/platform/services/migration-builder.service"; import { MigrationRunner } from "@bitwarden/common/platform/services/migration-runner"; import { DefaultSdkService } from "@bitwarden/common/platform/services/sdk/default-sdk.service"; +import { DefaultRegisterSdkService } from "@bitwarden/common/platform/services/sdk/register-sdk.service"; import { StorageServiceProvider } from "@bitwarden/common/platform/services/storage-service.provider"; import { UserAutoUnlockKeyService } from "@bitwarden/common/platform/services/user-auto-unlock-key.service"; import { ValidationService } from "@bitwarden/common/platform/services/validation.service"; @@ -1586,6 +1588,19 @@ const safeProviders: SafeProvider[] = [ SsoLoginServiceAbstraction, ], }), + safeProvider({ + provide: RegisterSdkService, + useClass: DefaultRegisterSdkService, + deps: [ + SdkClientFactory, + EnvironmentService, + PlatformUtilsServiceAbstraction, + AccountServiceAbstraction, + ApiServiceAbstraction, + StateProvider, + ConfigService, + ], + }), safeProvider({ provide: SdkService, useClass: DefaultSdkService, diff --git a/libs/common/src/platform/abstractions/sdk/register-sdk.service.ts b/libs/common/src/platform/abstractions/sdk/register-sdk.service.ts new file mode 100644 index 00000000000..b340dd95ebe --- /dev/null +++ b/libs/common/src/platform/abstractions/sdk/register-sdk.service.ts @@ -0,0 +1,56 @@ +import { Observable } from "rxjs"; + +import { BitwardenClient, Uuid } from "@bitwarden/sdk-internal"; + +import { UserId } from "../../../types/guid"; +import { Rc } from "../../misc/reference-counting/rc"; +import { Utils } from "../../misc/utils"; + +export class UserNotLoggedInError extends Error { + constructor(userId: UserId) { + super(`User (${userId}) is not logged in`); + } +} + +export class InvalidUuid extends Error { + constructor(uuid: string) { + super(`Invalid UUID: ${uuid}`); + } +} + +/** + * Converts a string to UUID. Will throw an error if the UUID is non valid. + */ +export function asUuid(uuid: string): T { + if (Utils.isGuid(uuid)) { + return uuid as T; + } + + throw new InvalidUuid(uuid); +} + +/** + * Converts a UUID to the string representation. + */ +export function uuidAsString(uuid: T): string { + return uuid as unknown as string; +} + +export abstract class RegisterSdkService { + /** + * Retrieve a client with tokens for a specific user. + * This client is meant exclusively for registrations that require tokens, such as TDE and key-connector. + * + * - If the user is not logged when the subscription is created, the observable will complete + * immediately with {@link UserNotLoggedInError}. + * - If the user is logged in, the observable will emit the client and complete without an error + * when the user logs out. + * + * **WARNING:** Do not use `firstValueFrom(userClient$)`! Any operations on the client must be done within the observable. + * The client will be destroyed when the observable is no longer subscribed to. + * Please let platform know if you need a client that is not destroyed when the observable is no longer subscribed to. + * + * @param userId The user id for which to retrieve the client + */ + abstract registerClient$(userId: UserId): Observable>; +} diff --git a/libs/common/src/platform/abstractions/sdk/sdk.service.ts b/libs/common/src/platform/abstractions/sdk/sdk.service.ts index 9b7f32a8a0e..f34bb8fb612 100644 --- a/libs/common/src/platform/abstractions/sdk/sdk.service.ts +++ b/libs/common/src/platform/abstractions/sdk/sdk.service.ts @@ -1,7 +1,8 @@ import { Observable } from "rxjs"; -import { PasswordManagerClient, Uuid } from "@bitwarden/sdk-internal"; +import { PasswordManagerClient, Uuid, DeviceType as SdkDeviceType } from "@bitwarden/sdk-internal"; +import { DeviceType } from "../../../enums"; import { UserId } from "../../../types/guid"; import { Rc } from "../../misc/reference-counting/rc"; import { Utils } from "../../misc/utils"; @@ -18,6 +19,63 @@ export class InvalidUuid extends Error { } } +export function toSdkDevice(device: DeviceType): SdkDeviceType { + switch (device) { + case DeviceType.Android: + return "Android"; + case DeviceType.iOS: + return "iOS"; + case DeviceType.ChromeExtension: + return "ChromeExtension"; + case DeviceType.FirefoxExtension: + return "FirefoxExtension"; + case DeviceType.OperaExtension: + return "OperaExtension"; + case DeviceType.EdgeExtension: + return "EdgeExtension"; + case DeviceType.WindowsDesktop: + return "WindowsDesktop"; + case DeviceType.MacOsDesktop: + return "MacOsDesktop"; + case DeviceType.LinuxDesktop: + return "LinuxDesktop"; + case DeviceType.ChromeBrowser: + return "ChromeBrowser"; + case DeviceType.FirefoxBrowser: + return "FirefoxBrowser"; + case DeviceType.OperaBrowser: + return "OperaBrowser"; + case DeviceType.EdgeBrowser: + return "EdgeBrowser"; + case DeviceType.IEBrowser: + return "IEBrowser"; + case DeviceType.UnknownBrowser: + return "UnknownBrowser"; + case DeviceType.AndroidAmazon: + return "AndroidAmazon"; + case DeviceType.UWP: + return "UWP"; + case DeviceType.SafariBrowser: + return "SafariBrowser"; + case DeviceType.VivaldiBrowser: + return "VivaldiBrowser"; + case DeviceType.VivaldiExtension: + return "VivaldiExtension"; + case DeviceType.SafariExtension: + return "SafariExtension"; + case DeviceType.Server: + return "Server"; + case DeviceType.WindowsCLI: + return "WindowsCLI"; + case DeviceType.MacOsCLI: + return "MacOsCLI"; + case DeviceType.LinuxCLI: + return "LinuxCLI"; + default: + return "SDK"; + } +} + /** * Converts a string to UUID. Will throw an error if the UUID is non valid. */ diff --git a/libs/common/src/platform/services/sdk/default-sdk.service.ts b/libs/common/src/platform/services/sdk/default-sdk.service.ts index eb663c6f928..6e7bcbb197d 100644 --- a/libs/common/src/platform/services/sdk/default-sdk.service.ts +++ b/libs/common/src/platform/services/sdk/default-sdk.service.ts @@ -22,14 +22,12 @@ import { KeyService, KdfConfigService, KdfConfig, KdfType } from "@bitwarden/key import { PasswordManagerClient, ClientSettings, - DeviceType as SdkDeviceType, TokenProvider, UnsignedSharedKey, } from "@bitwarden/sdk-internal"; import { ApiService } from "../../../abstractions/api.service"; import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service"; -import { DeviceType } from "../../../enums/device-type.enum"; import { EncryptedString, EncString } from "../../../key-management/crypto/models/enc-string"; import { SecurityStateService } from "../../../key-management/security-state/abstractions/security-state.service"; import { SignedSecurityState, WrappedSigningKey } from "../../../key-management/types"; @@ -39,7 +37,12 @@ import { Environment, EnvironmentService } from "../../abstractions/environment. import { PlatformUtilsService } from "../../abstractions/platform-utils.service"; import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory"; import { SdkLoadService } from "../../abstractions/sdk/sdk-load.service"; -import { asUuid, SdkService, UserNotLoggedInError } from "../../abstractions/sdk/sdk.service"; +import { + asUuid, + SdkService, + toSdkDevice, + UserNotLoggedInError, +} from "../../abstractions/sdk/sdk.service"; import { compareValues } from "../../misc/compare-values"; import { Rc } from "../../misc/reference-counting/rc"; import { StateProvider } from "../../state"; @@ -297,65 +300,8 @@ export class DefaultSdkService implements SdkService { return { apiUrl: env.getApiUrl(), identityUrl: env.getIdentityUrl(), - deviceType: this.toDevice(this.platformUtilsService.getDevice()), + deviceType: toSdkDevice(this.platformUtilsService.getDevice()), userAgent: this.userAgent ?? navigator.userAgent, }; } - - private toDevice(device: DeviceType): SdkDeviceType { - switch (device) { - case DeviceType.Android: - return "Android"; - case DeviceType.iOS: - return "iOS"; - case DeviceType.ChromeExtension: - return "ChromeExtension"; - case DeviceType.FirefoxExtension: - return "FirefoxExtension"; - case DeviceType.OperaExtension: - return "OperaExtension"; - case DeviceType.EdgeExtension: - return "EdgeExtension"; - case DeviceType.WindowsDesktop: - return "WindowsDesktop"; - case DeviceType.MacOsDesktop: - return "MacOsDesktop"; - case DeviceType.LinuxDesktop: - return "LinuxDesktop"; - case DeviceType.ChromeBrowser: - return "ChromeBrowser"; - case DeviceType.FirefoxBrowser: - return "FirefoxBrowser"; - case DeviceType.OperaBrowser: - return "OperaBrowser"; - case DeviceType.EdgeBrowser: - return "EdgeBrowser"; - case DeviceType.IEBrowser: - return "IEBrowser"; - case DeviceType.UnknownBrowser: - return "UnknownBrowser"; - case DeviceType.AndroidAmazon: - return "AndroidAmazon"; - case DeviceType.UWP: - return "UWP"; - case DeviceType.SafariBrowser: - return "SafariBrowser"; - case DeviceType.VivaldiBrowser: - return "VivaldiBrowser"; - case DeviceType.VivaldiExtension: - return "VivaldiExtension"; - case DeviceType.SafariExtension: - return "SafariExtension"; - case DeviceType.Server: - return "Server"; - case DeviceType.WindowsCLI: - return "WindowsCLI"; - case DeviceType.MacOsCLI: - return "MacOsCLI"; - case DeviceType.LinuxCLI: - return "LinuxCLI"; - default: - return "SDK"; - } - } } diff --git a/libs/common/src/platform/services/sdk/register-sdk.service.spec.ts b/libs/common/src/platform/services/sdk/register-sdk.service.spec.ts new file mode 100644 index 00000000000..0a05ac8dbf4 --- /dev/null +++ b/libs/common/src/platform/services/sdk/register-sdk.service.spec.ts @@ -0,0 +1,170 @@ +import { mock, MockProxy } from "jest-mock-extended"; +import { BehaviorSubject, firstValueFrom, of } from "rxjs"; + +import { BitwardenClient } from "@bitwarden/sdk-internal"; + +import { + ObservableTracker, + FakeAccountService, + FakeStateProvider, + mockAccountServiceWith, +} from "../../../../spec"; +import { ApiService } from "../../../abstractions/api.service"; +import { AccountInfo } from "../../../auth/abstractions/account.service"; +import { UserId } from "../../../types/guid"; +import { ConfigService } from "../../abstractions/config/config.service"; +import { Environment, EnvironmentService } from "../../abstractions/environment.service"; +import { PlatformUtilsService } from "../../abstractions/platform-utils.service"; +import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory"; +import { SdkLoadService } from "../../abstractions/sdk/sdk-load.service"; +import { UserNotLoggedInError } from "../../abstractions/sdk/sdk.service"; +import { Rc } from "../../misc/reference-counting/rc"; +import { Utils } from "../../misc/utils"; + +import { DefaultRegisterSdkService } from "./register-sdk.service"; + +class TestSdkLoadService extends SdkLoadService { + protected override load(): Promise { + // Simulate successful WASM load + return Promise.resolve(); + } +} + +describe("DefaultRegisterSdkService", () => { + describe("userClient$", () => { + let sdkClientFactory!: MockProxy; + let environmentService!: MockProxy; + let platformUtilsService!: MockProxy; + let configService!: MockProxy; + let service!: DefaultRegisterSdkService; + let accountService!: FakeAccountService; + let fakeStateProvider!: FakeStateProvider; + let apiService!: MockProxy; + + beforeEach(async () => { + await new TestSdkLoadService().loadAndInit(); + + sdkClientFactory = mock(); + environmentService = mock(); + platformUtilsService = mock(); + apiService = mock(); + const mockUserId = Utils.newGuid() as UserId; + accountService = mockAccountServiceWith(mockUserId); + fakeStateProvider = new FakeStateProvider(accountService); + configService = mock(); + + configService.serverConfig$ = new BehaviorSubject(null); + + // Can't use `of(mock())` for some reason + environmentService.environment$ = new BehaviorSubject(mock()); + + service = new DefaultRegisterSdkService( + sdkClientFactory, + environmentService, + platformUtilsService, + accountService, + apiService, + fakeStateProvider, + configService, + ); + }); + + describe("given the user is logged in", () => { + const userId = "0da62ebd-98bb-4f42-a846-64e8555087d7" as UserId; + beforeEach(() => { + environmentService.getEnvironment$ + .calledWith(userId) + .mockReturnValue(new BehaviorSubject(mock())); + accountService.accounts$ = of({ + [userId]: { email: "email", emailVerified: true, name: "name" } as AccountInfo, + }); + }); + + let mockClient!: MockProxy; + + beforeEach(() => { + mockClient = createMockClient(); + sdkClientFactory.createSdkClient.mockResolvedValue(mockClient); + }); + + it("creates an internal SDK client when called the first time", async () => { + await firstValueFrom(service.registerClient$(userId)); + + expect(sdkClientFactory.createSdkClient).toHaveBeenCalled(); + }); + + it("does not create an SDK client when called the second time with same userId", async () => { + const subject_1 = new BehaviorSubject | undefined>(undefined); + const subject_2 = new BehaviorSubject | undefined>(undefined); + + // Use subjects to ensure the subscription is kept alive + service.registerClient$(userId).subscribe(subject_1); + service.registerClient$(userId).subscribe(subject_2); + + // Wait for the next tick to ensure all async operations are done + await new Promise(process.nextTick); + + expect(subject_1.value.take().value).toBe(mockClient); + expect(subject_2.value.take().value).toBe(mockClient); + expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1); + }); + + it("destroys the internal SDK client when all subscriptions are closed", async () => { + const subject_1 = new BehaviorSubject | undefined>(undefined); + const subject_2 = new BehaviorSubject | undefined>(undefined); + const subscription_1 = service.registerClient$(userId).subscribe(subject_1); + const subscription_2 = service.registerClient$(userId).subscribe(subject_2); + await new Promise(process.nextTick); + + subscription_1.unsubscribe(); + subscription_2.unsubscribe(); + + await new Promise(process.nextTick); + expect(mockClient.free).toHaveBeenCalledTimes(1); + }); + + it("destroys the internal SDK client when the account is removed (logout)", async () => { + const accounts$ = new BehaviorSubject({ + [userId]: { email: "email", emailVerified: true, name: "name" } as AccountInfo, + }); + accountService.accounts$ = accounts$; + + const userClientTracker = new ObservableTracker(service.registerClient$(userId), false); + await userClientTracker.pauseUntilReceived(1); + + accounts$.next({}); + await userClientTracker.expectCompletion(); + + expect(mockClient.free).toHaveBeenCalledTimes(1); + }); + }); + + describe("given the user is not logged in", () => { + const userId = "0da62ebd-98bb-4f42-a846-64e8555087d7" as UserId; + + beforeEach(() => { + environmentService.getEnvironment$ + .calledWith(userId) + .mockReturnValue(new BehaviorSubject(mock())); + accountService.accounts$ = of({}); + }); + + it("throws UserNotLoggedInError when user has no account", async () => { + const result = () => firstValueFrom(service.registerClient$(userId)); + + await expect(result).rejects.toThrow(UserNotLoggedInError); + }); + }); + }); +}); + +function createMockClient(): MockProxy { + const client = mock(); + client.platform.mockReturnValue({ + state: jest.fn().mockReturnValue(mock()), + load_flags: jest.fn().mockReturnValue(mock()), + free: mock(), + [Symbol.dispose]: jest.fn(), + }); + return client; +} diff --git a/libs/common/src/platform/services/sdk/register-sdk.service.ts b/libs/common/src/platform/services/sdk/register-sdk.service.ts new file mode 100644 index 00000000000..a222807640f --- /dev/null +++ b/libs/common/src/platform/services/sdk/register-sdk.service.ts @@ -0,0 +1,196 @@ +import { + combineLatest, + concatMap, + Observable, + shareReplay, + map, + distinctUntilChanged, + tap, + switchMap, + BehaviorSubject, + of, + takeWhile, + throwIfEmpty, + firstValueFrom, +} from "rxjs"; + +import { PasswordManagerClient, ClientSettings, TokenProvider } from "@bitwarden/sdk-internal"; + +import { ApiService } from "../../../abstractions/api.service"; +import { AccountService } from "../../../auth/abstractions/account.service"; +import { ConfigService } from "../../../platform/abstractions/config/config.service"; +import { UserId } from "../../../types/guid"; +import { Environment, EnvironmentService } from "../../abstractions/environment.service"; +import { PlatformUtilsService } from "../../abstractions/platform-utils.service"; +import { RegisterSdkService } from "../../abstractions/sdk/register-sdk.service"; +import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory"; +import { SdkLoadService } from "../../abstractions/sdk/sdk-load.service"; +import { toSdkDevice, UserNotLoggedInError } from "../../abstractions/sdk/sdk.service"; +import { Rc } from "../../misc/reference-counting/rc"; +import { StateProvider } from "../../state"; + +import { initializeState } from "./client-managed-state"; + +// A symbol that represents an overridden client that is explicitly set to undefined, +// blocking the creation of an internal client for that user. +const UnsetClient = Symbol("UnsetClient"); + +/** + * A token provider that exposes the access token to the SDK. + */ +class JsTokenProvider implements TokenProvider { + constructor( + private apiService: ApiService, + private userId?: UserId, + ) {} + + async get_access_token(): Promise { + if (this.userId == null) { + return undefined; + } + + return await this.apiService.getActiveBearerToken(this.userId); + } +} + +export class DefaultRegisterSdkService implements RegisterSdkService { + private sdkClientOverrides = new BehaviorSubject<{ + [userId: UserId]: Rc | typeof UnsetClient; + }>({}); + private sdkClientCache = new Map>>(); + + client$ = this.environmentService.environment$.pipe( + concatMap(async (env) => { + await SdkLoadService.Ready; + const settings = this.toSettings(env); + const client = await this.sdkClientFactory.createSdkClient( + new JsTokenProvider(this.apiService), + settings, + ); + await this.loadFeatureFlags(client); + return client; + }), + shareReplay({ refCount: true, bufferSize: 1 }), + ); + + constructor( + private sdkClientFactory: SdkClientFactory, + private environmentService: EnvironmentService, + private platformUtilsService: PlatformUtilsService, + private accountService: AccountService, + private apiService: ApiService, + private stateProvider: StateProvider, + private configService: ConfigService, + private userAgent: string | null = null, + ) {} + + registerClient$(userId: UserId): Observable> { + return this.sdkClientOverrides.pipe( + takeWhile((clients) => clients[userId] !== UnsetClient, false), + map((clients) => { + if (clients[userId] === UnsetClient) { + throw new Error("Encountered UnsetClient even though it should have been filtered out"); + } + return clients[userId] as Rc; + }), + distinctUntilChanged(), + switchMap((clientOverride) => { + if (clientOverride) { + return of(clientOverride); + } + + return this.internalClient$(userId); + }), + takeWhile((client) => client !== undefined, false), + throwIfEmpty(() => new UserNotLoggedInError(userId)), + ); + } + + /** + * This method is used to create a client for a specific user by using the existing state of the application. + * This client is token-only and does not initialize any encryption keys. + * @param userId The user id for which to create the client + * @returns An observable that emits the client for the user + */ + private internalClient$(userId: UserId): Observable> { + const cached = this.sdkClientCache.get(userId); + if (cached !== undefined) { + return cached; + } + + const account$ = this.accountService.accounts$.pipe( + map((accounts) => accounts[userId]), + distinctUntilChanged(), + ); + + const client$ = combineLatest([ + this.environmentService.getEnvironment$(userId), + account$, + SdkLoadService.Ready, // Makes sure we wait (once) for the SDK to be loaded + ]).pipe( + // switchMap is required to allow the clean-up logic to be executed when `combineLatest` emits a new value. + switchMap(([env, account]) => { + // Create our own observable to be able to implement clean-up logic + return new Observable>((subscriber) => { + const createAndInitializeClient = async () => { + if (env == null || account == null) { + return undefined; + } + + const settings = this.toSettings(env); + const client = await this.sdkClientFactory.createSdkClient( + new JsTokenProvider(this.apiService, userId), + settings, + ); + + // Initialize the SDK managed database and the client managed repositories. + await initializeState(userId, client.platform().state(), this.stateProvider); + + await this.loadFeatureFlags(client); + + return client; + }; + + let client: Rc | undefined; + createAndInitializeClient() + .then((c) => { + client = c === undefined ? undefined : new Rc(c); + + subscriber.next(client); + }) + .catch((e) => { + subscriber.error(e); + }); + + return () => client?.markForDisposal(); + }); + }), + tap({ finalize: () => this.sdkClientCache.delete(userId) }), + shareReplay({ refCount: true, bufferSize: 1 }), + ); + + this.sdkClientCache.set(userId, client$); + return client$; + } + + private async loadFeatureFlags(client: PasswordManagerClient) { + const serverConfig = await firstValueFrom(this.configService.serverConfig$); + + const featureFlagMap = new Map( + Object.entries(serverConfig?.featureStates ?? {}) + .filter(([, value]) => typeof value === "boolean") // The SDK only supports boolean feature flags at this time + .map(([key, value]) => [key, value] as [string, boolean]), + ); + + client.platform().load_flags(featureFlagMap); + } + + private toSettings(env: Environment): ClientSettings { + return { + apiUrl: env.getApiUrl(), + identityUrl: env.getIdentityUrl(), + deviceType: toSdkDevice(this.platformUtilsService.getDevice()), + userAgent: this.userAgent ?? navigator.userAgent, + }; + } +}