1
0
mirror of https://github.com/bitwarden/browser synced 2025-12-11 13:53:34 +00:00

[PM-17408] Create new method on sdk service to allow explicit addition of a new client instance (#13309)

* feat: allow the user client to be overriden by an external provider

* feat: add ability to unset client

* feat: add `setClient` to interface (and add some docs)

* fix: re-add undefined

* fix: strict typing issues
This commit is contained in:
Andreas Coroiu
2025-02-24 11:29:47 +01:00
committed by GitHub
parent 5f390e6151
commit a9862d2a19
4 changed files with 265 additions and 74 deletions

View File

@@ -1,13 +1,33 @@
// FIXME: Update this file to be type safe and remove this and next line // FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore // @ts-strict-ignore
import { firstValueFrom, Observable, Subject, Subscription, throwError, timeout } from "rxjs"; import {
filter,
firstValueFrom,
lastValueFrom,
Observable,
Subject,
Subscription,
throwError,
timeout,
} from "rxjs";
/** Test class to enable async awaiting of observable emissions */ /** Test class to enable async awaiting of observable emissions */
export class ObservableTracker<T> { export class ObservableTracker<T> {
private subscription: Subscription; private subscription: Subscription;
private emissionReceived = new Subject<T>(); private emissionReceived = new Subject<T>();
emissions: T[] = []; emissions: T[] = [];
constructor(observable: Observable<T>) {
/**
* Creates a new ObservableTracker and instantly subscribes to the given observable.
* @param observable The observable to track
* @param clone Whether to clone tracked emissions or not, defaults to true.
* Cloning can be necessary if the observable emits objects that are mutated after emission. Cloning makes it
* harder to compare the original and the tracked emission using reference equality (e.g. `expect().toBe()`).
*/
constructor(
observable: Observable<T>,
private clone = true,
) {
this.emissions = this.trackEmissions(observable); this.emissions = this.trackEmissions(observable);
} }
@@ -33,6 +53,19 @@ export class ObservableTracker<T> {
); );
} }
async expectCompletion(msTimeout = 50): Promise<void> {
return await lastValueFrom(
this.emissionReceived.pipe(
filter(() => false),
timeout({
first: msTimeout,
with: () => throwError(() => new Error("Timeout exceeded waiting for completion.")),
}),
),
{ defaultValue: undefined },
);
}
/** Awaits until the total number of emissions observed by this tracker equals or exceeds {@link count} /** Awaits until the total number of emissions observed by this tracker equals or exceeds {@link count}
* @param count The number of emissions to wait for * @param count The number of emissions to wait for
*/ */
@@ -48,26 +81,31 @@ export class ObservableTracker<T> {
this.emissionReceived.subscribe((value) => { this.emissionReceived.subscribe((value) => {
emissions.push(value); emissions.push(value);
}); });
this.subscription = observable.subscribe((value) => { this.subscription = observable.subscribe({
if (value == null) { next: (value) => {
this.emissionReceived.next(null); if (value == null) {
return; this.emissionReceived.next(null);
} return;
switch (typeof value) {
case "string":
case "number":
case "boolean":
this.emissionReceived.next(value);
break;
case "symbol":
// Cheating types to make symbols work at all
this.emissionReceived.next(value as T);
break;
default: {
this.emissionReceived.next(clone(value));
} }
}
switch (typeof value) {
case "string":
case "number":
case "boolean":
this.emissionReceived.next(value);
break;
case "symbol":
// Cheating types to make symbols work at all
this.emissionReceived.next(value as T);
break;
default: {
this.emissionReceived.next(this.clone ? clone(value) : value);
}
}
},
complete: () => {
this.emissionReceived.complete();
},
}); });
return emissions; return emissions;

View File

@@ -5,6 +5,12 @@ import { BitwardenClient } from "@bitwarden/sdk-internal";
import { UserId } from "../../../types/guid"; import { UserId } from "../../../types/guid";
import { Rc } from "../../misc/reference-counting/rc"; import { Rc } from "../../misc/reference-counting/rc";
export class UserNotLoggedInError extends Error {
constructor(userId: UserId) {
super(`User (${userId}) is not logged in`);
}
}
export abstract class SdkService { export abstract class SdkService {
/** /**
* Retrieve the version of the SDK. * Retrieve the version of the SDK.
@@ -26,7 +32,20 @@ export abstract class SdkService {
* The client will be destroyed when the observable is no longer subscribed to. * The client will be destroyed when the observable is no longer subscribed to.
* Please let platform know if you need a client that is not destroyed when the observable is no longer subscribed to. * Please let platform know if you need a client that is not destroyed when the observable is no longer subscribed to.
* *
* @param userId * @param userId The user id for which to retrieve the client
*
* @throws {UserNotLoggedInError} If the user is not logged in
*/ */
abstract userClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined>; abstract userClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined>;
/**
* This method is used during/after an authentication procedure to set a new client for a specific user.
* It can also be used to unset the client when a user logs out, this will result in:
* - The client being disposed of
* - All subscriptions to the client being completed
* - Any new subscribers receiving an error
* @param userId The user id for which to set the client
* @param client The client to set for the user. If undefined, the client will be unset.
*/
abstract setClient(userId: UserId, client: BitwardenClient | undefined): void;
} }

View File

@@ -4,12 +4,14 @@ import { BehaviorSubject, firstValueFrom, of } from "rxjs";
import { KdfConfigService, KeyService, PBKDF2KdfConfig } from "@bitwarden/key-management"; import { KdfConfigService, KeyService, PBKDF2KdfConfig } from "@bitwarden/key-management";
import { BitwardenClient } from "@bitwarden/sdk-internal"; import { BitwardenClient } from "@bitwarden/sdk-internal";
import { ObservableTracker } from "../../../../spec";
import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service"; import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service";
import { UserId } from "../../../types/guid"; import { UserId } from "../../../types/guid";
import { UserKey } from "../../../types/key"; import { UserKey } from "../../../types/key";
import { Environment, EnvironmentService } from "../../abstractions/environment.service"; import { Environment, EnvironmentService } from "../../abstractions/environment.service";
import { PlatformUtilsService } from "../../abstractions/platform-utils.service"; import { PlatformUtilsService } from "../../abstractions/platform-utils.service";
import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory"; import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory";
import { UserNotLoggedInError } from "../../abstractions/sdk/sdk.service";
import { Rc } from "../../misc/reference-counting/rc"; import { Rc } from "../../misc/reference-counting/rc";
import { EncryptedString } from "../../models/domain/enc-string"; import { EncryptedString } from "../../models/domain/enc-string";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key"; import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";
@@ -26,8 +28,6 @@ describe("DefaultSdkService", () => {
let keyService!: MockProxy<KeyService>; let keyService!: MockProxy<KeyService>;
let service!: DefaultSdkService; let service!: DefaultSdkService;
let mockClient!: MockProxy<BitwardenClient>;
beforeEach(() => { beforeEach(() => {
sdkClientFactory = mock<SdkClientFactory>(); sdkClientFactory = mock<SdkClientFactory>();
environmentService = mock<EnvironmentService>(); environmentService = mock<EnvironmentService>();
@@ -47,15 +47,10 @@ describe("DefaultSdkService", () => {
kdfConfigService, kdfConfigService,
keyService, keyService,
); );
mockClient = mock<BitwardenClient>();
mockClient.crypto.mockReturnValue(mock());
sdkClientFactory.createSdkClient.mockResolvedValue(mockClient);
}); });
describe("given the user is logged in", () => { describe("given the user is logged in", () => {
const userId = "user-id" as UserId; const userId = "user-id" as UserId;
beforeEach(() => { beforeEach(() => {
environmentService.getEnvironment$ environmentService.getEnvironment$
.calledWith(userId) .calledWith(userId)
@@ -75,56 +70,144 @@ describe("DefaultSdkService", () => {
keyService.encryptedOrgKeys$.calledWith(userId).mockReturnValue(of({})); keyService.encryptedOrgKeys$.calledWith(userId).mockReturnValue(of({}));
}); });
it("creates an SDK client when called the first time", async () => { describe("given no client override has been set for the user", () => {
await firstValueFrom(service.userClient$(userId)); let mockClient!: MockProxy<BitwardenClient>;
expect(sdkClientFactory.createSdkClient).toHaveBeenCalled(); beforeEach(() => {
mockClient = createMockClient();
sdkClientFactory.createSdkClient.mockResolvedValue(mockClient);
});
it("creates an internal SDK client when called the first time", async () => {
await firstValueFrom(service.userClient$(userId));
expect(sdkClientFactory.createSdkClient).toHaveBeenCalled();
});
it("does not create an SDK client when called the second time with same userId", async () => {
const subject_1 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
const subject_2 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
// Use subjects to ensure the subscription is kept alive
service.userClient$(userId).subscribe(subject_1);
service.userClient$(userId).subscribe(subject_2);
// Wait for the next tick to ensure all async operations are done
await new Promise(process.nextTick);
expect(subject_1.value.take().value).toBe(mockClient);
expect(subject_2.value.take().value).toBe(mockClient);
expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1);
});
it("destroys the internal SDK client when all subscriptions are closed", async () => {
const subject_1 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
const subject_2 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
const subscription_1 = service.userClient$(userId).subscribe(subject_1);
const subscription_2 = service.userClient$(userId).subscribe(subject_2);
await new Promise(process.nextTick);
subscription_1.unsubscribe();
subscription_2.unsubscribe();
await new Promise(process.nextTick);
expect(mockClient.free).toHaveBeenCalledTimes(1);
});
it("destroys the internal SDK client when the userKey is unset (i.e. lock or logout)", async () => {
const userKey$ = new BehaviorSubject(
new SymmetricCryptoKey(new Uint8Array(64)) as UserKey,
);
keyService.userKey$.calledWith(userId).mockReturnValue(userKey$);
const subject = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
service.userClient$(userId).subscribe(subject);
await new Promise(process.nextTick);
userKey$.next(undefined);
await new Promise(process.nextTick);
expect(mockClient.free).toHaveBeenCalledTimes(1);
expect(subject.value).toBe(undefined);
});
}); });
it("does not create an SDK client when called the second time with same userId", async () => { describe("given overrides are used", () => {
const subject_1 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined); it("does not create a new client and emits the override client when a client override has already been set ", async () => {
const subject_2 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined); const mockClient = mock<BitwardenClient>();
service.setClient(userId, mockClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
// Use subjects to ensure the subscription is kept alive expect(sdkClientFactory.createSdkClient).not.toHaveBeenCalled();
service.userClient$(userId).subscribe(subject_1); expect(userClientTracker.emissions[0].take().value).toBe(mockClient);
service.userClient$(userId).subscribe(subject_2); });
// Wait for the next tick to ensure all async operations are done it("emits the internal client then switches to override when an override is set", async () => {
await new Promise(process.nextTick); const mockInternalClient = createMockClient();
const mockOverrideClient = createMockClient();
sdkClientFactory.createSdkClient.mockResolvedValue(mockInternalClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
expect(subject_1.value.take().value).toBe(mockClient); await userClientTracker.pauseUntilReceived(1);
expect(subject_2.value.take().value).toBe(mockClient); expect(userClientTracker.emissions[0].take().value).toBe(mockInternalClient);
expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1);
});
it("destroys the SDK client when all subscriptions are closed", async () => { service.setClient(userId, mockOverrideClient);
const subject_1 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
const subject_2 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
const subscription_1 = service.userClient$(userId).subscribe(subject_1);
const subscription_2 = service.userClient$(userId).subscribe(subject_2);
await new Promise(process.nextTick);
subscription_1.unsubscribe(); await userClientTracker.pauseUntilReceived(2);
subscription_2.unsubscribe(); expect(userClientTracker.emissions[1].take().value).toBe(mockOverrideClient);
});
await new Promise(process.nextTick); it("throws error when the client has explicitly been set as undefined", async () => {
expect(mockClient.free).toHaveBeenCalledTimes(1); service.setClient(userId, undefined);
});
it("destroys the SDK client when the userKey is unset (i.e. lock or logout)", async () => { const result = () => firstValueFrom(service.userClient$(userId));
const userKey$ = new BehaviorSubject(new SymmetricCryptoKey(new Uint8Array(64)) as UserKey);
keyService.userKey$.calledWith(userId).mockReturnValue(userKey$);
const subject = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined); await expect(result).rejects.toThrow(UserNotLoggedInError);
service.userClient$(userId).subscribe(subject); });
await new Promise(process.nextTick);
userKey$.next(undefined); it("completes the subscription when the override is set to undefined after having been defined", async () => {
await new Promise(process.nextTick); const mockOverrideClient = createMockClient();
service.setClient(userId, mockOverrideClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
expect(mockClient.free).toHaveBeenCalledTimes(1); service.setClient(userId, undefined);
expect(subject.value).toBe(undefined);
await userClientTracker.expectCompletion();
});
it("destroys the internal client when an override is set", async () => {
const mockInternalClient = createMockClient();
const mockOverrideClient = createMockClient();
sdkClientFactory.createSdkClient.mockResolvedValue(mockInternalClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
service.setClient(userId, mockOverrideClient);
await userClientTracker.pauseUntilReceived(2);
expect(mockInternalClient.free).toHaveBeenCalled();
});
it("destroys the override client when explicitly setting the client to undefined", async () => {
const mockOverrideClient = createMockClient();
service.setClient(userId, mockOverrideClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
service.setClient(userId, undefined);
await userClientTracker.expectCompletion();
expect(mockOverrideClient.free).toHaveBeenCalled();
});
}); });
}); });
}); });
}); });
function createMockClient(): MockProxy<BitwardenClient> {
const client = mock<BitwardenClient>();
client.crypto.mockReturnValue(mock());
return client;
}

View File

@@ -1,5 +1,3 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { import {
combineLatest, combineLatest,
concatMap, concatMap,
@@ -10,6 +8,10 @@ import {
tap, tap,
switchMap, switchMap,
catchError, catchError,
BehaviorSubject,
of,
takeWhile,
throwIfEmpty,
} from "rxjs"; } from "rxjs";
import { KeyService, KdfConfigService, KdfConfig, KdfType } from "@bitwarden/key-management"; import { KeyService, KdfConfigService, KdfConfig, KdfType } from "@bitwarden/key-management";
@@ -28,12 +30,19 @@ import { UserKey } from "../../../types/key";
import { Environment, EnvironmentService } from "../../abstractions/environment.service"; import { Environment, EnvironmentService } from "../../abstractions/environment.service";
import { PlatformUtilsService } from "../../abstractions/platform-utils.service"; import { PlatformUtilsService } from "../../abstractions/platform-utils.service";
import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory"; import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory";
import { SdkService } from "../../abstractions/sdk/sdk.service"; import { SdkService, UserNotLoggedInError } from "../../abstractions/sdk/sdk.service";
import { compareValues } from "../../misc/compare-values"; import { compareValues } from "../../misc/compare-values";
import { Rc } from "../../misc/reference-counting/rc"; import { Rc } from "../../misc/reference-counting/rc";
import { EncryptedString } from "../../models/domain/enc-string"; import { EncryptedString } from "../../models/domain/enc-string";
// A symbol that represents an overriden client that is explicitly set to undefined,
// blocking the creation of an internal client for that user.
const UnsetClient = Symbol("UnsetClient");
export class DefaultSdkService implements SdkService { export class DefaultSdkService implements SdkService {
private sdkClientOverrides = new BehaviorSubject<{
[userId: UserId]: Rc<BitwardenClient> | typeof UnsetClient;
}>({});
private sdkClientCache = new Map<UserId, Observable<Rc<BitwardenClient>>>(); private sdkClientCache = new Map<UserId, Observable<Rc<BitwardenClient>>>();
client$ = this.environmentService.environment$.pipe( client$ = this.environmentService.environment$.pipe(
@@ -56,13 +65,54 @@ export class DefaultSdkService implements SdkService {
private accountService: AccountService, private accountService: AccountService,
private kdfConfigService: KdfConfigService, private kdfConfigService: KdfConfigService,
private keyService: KeyService, private keyService: KeyService,
private userAgent: string = null, private userAgent: string | null = null,
) {} ) {}
userClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined> { userClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined> {
// TODO: Figure out what happens when the user logs out return this.sdkClientOverrides.pipe(
if (this.sdkClientCache.has(userId)) { takeWhile((clients) => clients[userId] !== UnsetClient, false),
return this.sdkClientCache.get(userId); map((clients) => {
if (clients[userId] === UnsetClient) {
throw new Error("Encountered UnsetClient even though it should have been filtered out");
}
return clients[userId] as Rc<BitwardenClient>;
}),
distinctUntilChanged(),
switchMap((clientOverride) => {
if (clientOverride) {
return of(clientOverride);
}
return this.internalClient$(userId);
}),
throwIfEmpty(() => new UserNotLoggedInError(userId)),
);
}
setClient(userId: UserId, client: BitwardenClient | undefined) {
const previousValue = this.sdkClientOverrides.value[userId];
this.sdkClientOverrides.next({
...this.sdkClientOverrides.value,
[userId]: client ? new Rc(client) : UnsetClient,
});
if (previousValue !== UnsetClient && previousValue !== undefined) {
previousValue.markForDisposal();
}
}
/**
* This method is used to create a client for a specific user by using the existing state of the application.
* This methods is a fallback for when no client has been provided by Auth. As Auth starts implementing the
* client creation, this method will be deprecated.
* @param userId The user id for which to create the client
* @returns An observable that emits the client for the user
*/
private internalClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined> {
const cached = this.sdkClientCache.get(userId);
if (cached !== undefined) {
return cached;
} }
const account$ = this.accountService.accounts$.pipe( const account$ = this.accountService.accounts$.pipe(
@@ -91,7 +141,7 @@ export class DefaultSdkService implements SdkService {
// Create our own observable to be able to implement clean-up logic // Create our own observable to be able to implement clean-up logic
return new Observable<Rc<BitwardenClient>>((subscriber) => { return new Observable<Rc<BitwardenClient>>((subscriber) => {
const createAndInitializeClient = async () => { const createAndInitializeClient = async () => {
if (privateKey == null || userKey == null) { if (env == null || kdfParams == null || privateKey == null || userKey == null) {
return undefined; return undefined;
} }
@@ -103,10 +153,11 @@ export class DefaultSdkService implements SdkService {
return client; return client;
}; };
let client: Rc<BitwardenClient>; let client: Rc<BitwardenClient> | undefined;
createAndInitializeClient() createAndInitializeClient()
.then((c) => { .then((c) => {
client = c === undefined ? undefined : new Rc(c); client = c === undefined ? undefined : new Rc(c);
subscriber.next(client); subscriber.next(client);
}) })
.catch((e) => { .catch((e) => {