1
0
mirror of https://github.com/bitwarden/browser synced 2025-12-11 05:43:41 +00:00

[PM-17408] Create new method on sdk service to allow explicit addition of a new client instance (#13309)

* feat: allow the user client to be overriden by an external provider

* feat: add ability to unset client

* feat: add `setClient` to interface (and add some docs)

* fix: re-add undefined

* fix: strict typing issues
This commit is contained in:
Andreas Coroiu
2025-02-24 11:29:47 +01:00
committed by GitHub
parent 5f390e6151
commit a9862d2a19
4 changed files with 265 additions and 74 deletions

View File

@@ -1,13 +1,33 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { firstValueFrom, Observable, Subject, Subscription, throwError, timeout } from "rxjs";
import {
filter,
firstValueFrom,
lastValueFrom,
Observable,
Subject,
Subscription,
throwError,
timeout,
} from "rxjs";
/** Test class to enable async awaiting of observable emissions */
export class ObservableTracker<T> {
private subscription: Subscription;
private emissionReceived = new Subject<T>();
emissions: T[] = [];
constructor(observable: Observable<T>) {
/**
* Creates a new ObservableTracker and instantly subscribes to the given observable.
* @param observable The observable to track
* @param clone Whether to clone tracked emissions or not, defaults to true.
* Cloning can be necessary if the observable emits objects that are mutated after emission. Cloning makes it
* harder to compare the original and the tracked emission using reference equality (e.g. `expect().toBe()`).
*/
constructor(
observable: Observable<T>,
private clone = true,
) {
this.emissions = this.trackEmissions(observable);
}
@@ -33,6 +53,19 @@ export class ObservableTracker<T> {
);
}
async expectCompletion(msTimeout = 50): Promise<void> {
return await lastValueFrom(
this.emissionReceived.pipe(
filter(() => false),
timeout({
first: msTimeout,
with: () => throwError(() => new Error("Timeout exceeded waiting for completion.")),
}),
),
{ defaultValue: undefined },
);
}
/** Awaits until the total number of emissions observed by this tracker equals or exceeds {@link count}
* @param count The number of emissions to wait for
*/
@@ -48,7 +81,8 @@ export class ObservableTracker<T> {
this.emissionReceived.subscribe((value) => {
emissions.push(value);
});
this.subscription = observable.subscribe((value) => {
this.subscription = observable.subscribe({
next: (value) => {
if (value == null) {
this.emissionReceived.next(null);
return;
@@ -65,9 +99,13 @@ export class ObservableTracker<T> {
this.emissionReceived.next(value as T);
break;
default: {
this.emissionReceived.next(clone(value));
this.emissionReceived.next(this.clone ? clone(value) : value);
}
}
},
complete: () => {
this.emissionReceived.complete();
},
});
return emissions;

View File

@@ -5,6 +5,12 @@ import { BitwardenClient } from "@bitwarden/sdk-internal";
import { UserId } from "../../../types/guid";
import { Rc } from "../../misc/reference-counting/rc";
export class UserNotLoggedInError extends Error {
constructor(userId: UserId) {
super(`User (${userId}) is not logged in`);
}
}
export abstract class SdkService {
/**
* Retrieve the version of the SDK.
@@ -26,7 +32,20 @@ export abstract class SdkService {
* The client will be destroyed when the observable is no longer subscribed to.
* Please let platform know if you need a client that is not destroyed when the observable is no longer subscribed to.
*
* @param userId
* @param userId The user id for which to retrieve the client
*
* @throws {UserNotLoggedInError} If the user is not logged in
*/
abstract userClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined>;
/**
* This method is used during/after an authentication procedure to set a new client for a specific user.
* It can also be used to unset the client when a user logs out, this will result in:
* - The client being disposed of
* - All subscriptions to the client being completed
* - Any new subscribers receiving an error
* @param userId The user id for which to set the client
* @param client The client to set for the user. If undefined, the client will be unset.
*/
abstract setClient(userId: UserId, client: BitwardenClient | undefined): void;
}

View File

@@ -4,12 +4,14 @@ import { BehaviorSubject, firstValueFrom, of } from "rxjs";
import { KdfConfigService, KeyService, PBKDF2KdfConfig } from "@bitwarden/key-management";
import { BitwardenClient } from "@bitwarden/sdk-internal";
import { ObservableTracker } from "../../../../spec";
import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service";
import { UserId } from "../../../types/guid";
import { UserKey } from "../../../types/key";
import { Environment, EnvironmentService } from "../../abstractions/environment.service";
import { PlatformUtilsService } from "../../abstractions/platform-utils.service";
import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory";
import { UserNotLoggedInError } from "../../abstractions/sdk/sdk.service";
import { Rc } from "../../misc/reference-counting/rc";
import { EncryptedString } from "../../models/domain/enc-string";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";
@@ -26,8 +28,6 @@ describe("DefaultSdkService", () => {
let keyService!: MockProxy<KeyService>;
let service!: DefaultSdkService;
let mockClient!: MockProxy<BitwardenClient>;
beforeEach(() => {
sdkClientFactory = mock<SdkClientFactory>();
environmentService = mock<EnvironmentService>();
@@ -47,15 +47,10 @@ describe("DefaultSdkService", () => {
kdfConfigService,
keyService,
);
mockClient = mock<BitwardenClient>();
mockClient.crypto.mockReturnValue(mock());
sdkClientFactory.createSdkClient.mockResolvedValue(mockClient);
});
describe("given the user is logged in", () => {
const userId = "user-id" as UserId;
beforeEach(() => {
environmentService.getEnvironment$
.calledWith(userId)
@@ -75,7 +70,15 @@ describe("DefaultSdkService", () => {
keyService.encryptedOrgKeys$.calledWith(userId).mockReturnValue(of({}));
});
it("creates an SDK client when called the first time", async () => {
describe("given no client override has been set for the user", () => {
let mockClient!: MockProxy<BitwardenClient>;
beforeEach(() => {
mockClient = createMockClient();
sdkClientFactory.createSdkClient.mockResolvedValue(mockClient);
});
it("creates an internal SDK client when called the first time", async () => {
await firstValueFrom(service.userClient$(userId));
expect(sdkClientFactory.createSdkClient).toHaveBeenCalled();
@@ -97,7 +100,7 @@ describe("DefaultSdkService", () => {
expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1);
});
it("destroys the SDK client when all subscriptions are closed", async () => {
it("destroys the internal SDK client when all subscriptions are closed", async () => {
const subject_1 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
const subject_2 = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
const subscription_1 = service.userClient$(userId).subscribe(subject_1);
@@ -111,8 +114,10 @@ describe("DefaultSdkService", () => {
expect(mockClient.free).toHaveBeenCalledTimes(1);
});
it("destroys the SDK client when the userKey is unset (i.e. lock or logout)", async () => {
const userKey$ = new BehaviorSubject(new SymmetricCryptoKey(new Uint8Array(64)) as UserKey);
it("destroys the internal SDK client when the userKey is unset (i.e. lock or logout)", async () => {
const userKey$ = new BehaviorSubject(
new SymmetricCryptoKey(new Uint8Array(64)) as UserKey,
);
keyService.userKey$.calledWith(userId).mockReturnValue(userKey$);
const subject = new BehaviorSubject<Rc<BitwardenClient> | undefined>(undefined);
@@ -126,5 +131,83 @@ describe("DefaultSdkService", () => {
expect(subject.value).toBe(undefined);
});
});
describe("given overrides are used", () => {
it("does not create a new client and emits the override client when a client override has already been set ", async () => {
const mockClient = mock<BitwardenClient>();
service.setClient(userId, mockClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
expect(sdkClientFactory.createSdkClient).not.toHaveBeenCalled();
expect(userClientTracker.emissions[0].take().value).toBe(mockClient);
});
it("emits the internal client then switches to override when an override is set", async () => {
const mockInternalClient = createMockClient();
const mockOverrideClient = createMockClient();
sdkClientFactory.createSdkClient.mockResolvedValue(mockInternalClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
expect(userClientTracker.emissions[0].take().value).toBe(mockInternalClient);
service.setClient(userId, mockOverrideClient);
await userClientTracker.pauseUntilReceived(2);
expect(userClientTracker.emissions[1].take().value).toBe(mockOverrideClient);
});
it("throws error when the client has explicitly been set as undefined", async () => {
service.setClient(userId, undefined);
const result = () => firstValueFrom(service.userClient$(userId));
await expect(result).rejects.toThrow(UserNotLoggedInError);
});
it("completes the subscription when the override is set to undefined after having been defined", async () => {
const mockOverrideClient = createMockClient();
service.setClient(userId, mockOverrideClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
service.setClient(userId, undefined);
await userClientTracker.expectCompletion();
});
it("destroys the internal client when an override is set", async () => {
const mockInternalClient = createMockClient();
const mockOverrideClient = createMockClient();
sdkClientFactory.createSdkClient.mockResolvedValue(mockInternalClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
service.setClient(userId, mockOverrideClient);
await userClientTracker.pauseUntilReceived(2);
expect(mockInternalClient.free).toHaveBeenCalled();
});
it("destroys the override client when explicitly setting the client to undefined", async () => {
const mockOverrideClient = createMockClient();
service.setClient(userId, mockOverrideClient);
const userClientTracker = new ObservableTracker(service.userClient$(userId), false);
await userClientTracker.pauseUntilReceived(1);
service.setClient(userId, undefined);
await userClientTracker.expectCompletion();
expect(mockOverrideClient.free).toHaveBeenCalled();
});
});
});
});
});
function createMockClient(): MockProxy<BitwardenClient> {
const client = mock<BitwardenClient>();
client.crypto.mockReturnValue(mock());
return client;
}

View File

@@ -1,5 +1,3 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import {
combineLatest,
concatMap,
@@ -10,6 +8,10 @@ import {
tap,
switchMap,
catchError,
BehaviorSubject,
of,
takeWhile,
throwIfEmpty,
} from "rxjs";
import { KeyService, KdfConfigService, KdfConfig, KdfType } from "@bitwarden/key-management";
@@ -28,12 +30,19 @@ import { UserKey } from "../../../types/key";
import { Environment, EnvironmentService } from "../../abstractions/environment.service";
import { PlatformUtilsService } from "../../abstractions/platform-utils.service";
import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory";
import { SdkService } from "../../abstractions/sdk/sdk.service";
import { SdkService, UserNotLoggedInError } from "../../abstractions/sdk/sdk.service";
import { compareValues } from "../../misc/compare-values";
import { Rc } from "../../misc/reference-counting/rc";
import { EncryptedString } from "../../models/domain/enc-string";
// A symbol that represents an overriden client that is explicitly set to undefined,
// blocking the creation of an internal client for that user.
const UnsetClient = Symbol("UnsetClient");
export class DefaultSdkService implements SdkService {
private sdkClientOverrides = new BehaviorSubject<{
[userId: UserId]: Rc<BitwardenClient> | typeof UnsetClient;
}>({});
private sdkClientCache = new Map<UserId, Observable<Rc<BitwardenClient>>>();
client$ = this.environmentService.environment$.pipe(
@@ -56,13 +65,54 @@ export class DefaultSdkService implements SdkService {
private accountService: AccountService,
private kdfConfigService: KdfConfigService,
private keyService: KeyService,
private userAgent: string = null,
private userAgent: string | null = null,
) {}
userClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined> {
// TODO: Figure out what happens when the user logs out
if (this.sdkClientCache.has(userId)) {
return this.sdkClientCache.get(userId);
return this.sdkClientOverrides.pipe(
takeWhile((clients) => clients[userId] !== UnsetClient, false),
map((clients) => {
if (clients[userId] === UnsetClient) {
throw new Error("Encountered UnsetClient even though it should have been filtered out");
}
return clients[userId] as Rc<BitwardenClient>;
}),
distinctUntilChanged(),
switchMap((clientOverride) => {
if (clientOverride) {
return of(clientOverride);
}
return this.internalClient$(userId);
}),
throwIfEmpty(() => new UserNotLoggedInError(userId)),
);
}
setClient(userId: UserId, client: BitwardenClient | undefined) {
const previousValue = this.sdkClientOverrides.value[userId];
this.sdkClientOverrides.next({
...this.sdkClientOverrides.value,
[userId]: client ? new Rc(client) : UnsetClient,
});
if (previousValue !== UnsetClient && previousValue !== undefined) {
previousValue.markForDisposal();
}
}
/**
* This method is used to create a client for a specific user by using the existing state of the application.
* This methods is a fallback for when no client has been provided by Auth. As Auth starts implementing the
* client creation, this method will be deprecated.
* @param userId The user id for which to create the client
* @returns An observable that emits the client for the user
*/
private internalClient$(userId: UserId): Observable<Rc<BitwardenClient> | undefined> {
const cached = this.sdkClientCache.get(userId);
if (cached !== undefined) {
return cached;
}
const account$ = this.accountService.accounts$.pipe(
@@ -91,7 +141,7 @@ export class DefaultSdkService implements SdkService {
// Create our own observable to be able to implement clean-up logic
return new Observable<Rc<BitwardenClient>>((subscriber) => {
const createAndInitializeClient = async () => {
if (privateKey == null || userKey == null) {
if (env == null || kdfParams == null || privateKey == null || userKey == null) {
return undefined;
}
@@ -103,10 +153,11 @@ export class DefaultSdkService implements SdkService {
return client;
};
let client: Rc<BitwardenClient>;
let client: Rc<BitwardenClient> | undefined;
createAndInitializeClient()
.then((c) => {
client = c === undefined ? undefined : new Rc(c);
subscriber.next(client);
})
.catch((e) => {