diff --git a/libs/common/spec/observable-tracker.ts b/libs/common/spec/observable-tracker.ts index 0ec4aa812f8..adefabb7472 100644 --- a/libs/common/spec/observable-tracker.ts +++ b/libs/common/spec/observable-tracker.ts @@ -1,13 +1,33 @@ // FIXME: Update this file to be type safe and remove this and next line // @ts-strict-ignore -import { firstValueFrom, Observable, Subject, Subscription, throwError, timeout } from "rxjs"; +import { + filter, + firstValueFrom, + lastValueFrom, + Observable, + Subject, + Subscription, + throwError, + timeout, +} from "rxjs"; /** Test class to enable async awaiting of observable emissions */ export class ObservableTracker { private subscription: Subscription; private emissionReceived = new Subject(); emissions: T[] = []; - constructor(observable: Observable) { + + /** + * Creates a new ObservableTracker and instantly subscribes to the given observable. + * @param observable The observable to track + * @param clone Whether to clone tracked emissions or not, defaults to true. + * Cloning can be necessary if the observable emits objects that are mutated after emission. Cloning makes it + * harder to compare the original and the tracked emission using reference equality (e.g. `expect().toBe()`). + */ + constructor( + observable: Observable, + private clone = true, + ) { this.emissions = this.trackEmissions(observable); } @@ -33,6 +53,19 @@ export class ObservableTracker { ); } + async expectCompletion(msTimeout = 50): Promise { + return await lastValueFrom( + this.emissionReceived.pipe( + filter(() => false), + timeout({ + first: msTimeout, + with: () => throwError(() => new Error("Timeout exceeded waiting for completion.")), + }), + ), + { defaultValue: undefined }, + ); + } + /** Awaits until the total number of emissions observed by this tracker equals or exceeds {@link count} * @param count The number of emissions to wait for */ @@ -48,26 +81,31 @@ export class ObservableTracker { this.emissionReceived.subscribe((value) => { emissions.push(value); }); - this.subscription = observable.subscribe((value) => { - if (value == null) { - this.emissionReceived.next(null); - return; - } - - switch (typeof value) { - case "string": - case "number": - case "boolean": - this.emissionReceived.next(value); - break; - case "symbol": - // Cheating types to make symbols work at all - this.emissionReceived.next(value as T); - break; - default: { - this.emissionReceived.next(clone(value)); + this.subscription = observable.subscribe({ + next: (value) => { + if (value == null) { + this.emissionReceived.next(null); + return; } - } + + switch (typeof value) { + case "string": + case "number": + case "boolean": + this.emissionReceived.next(value); + break; + case "symbol": + // Cheating types to make symbols work at all + this.emissionReceived.next(value as T); + break; + default: { + this.emissionReceived.next(this.clone ? clone(value) : value); + } + } + }, + complete: () => { + this.emissionReceived.complete(); + }, }); return emissions; diff --git a/libs/common/src/platform/abstractions/sdk/sdk.service.ts b/libs/common/src/platform/abstractions/sdk/sdk.service.ts index 22ad2b44ff9..3adf3291bbf 100644 --- a/libs/common/src/platform/abstractions/sdk/sdk.service.ts +++ b/libs/common/src/platform/abstractions/sdk/sdk.service.ts @@ -5,6 +5,12 @@ import { BitwardenClient } from "@bitwarden/sdk-internal"; import { UserId } from "../../../types/guid"; import { Rc } from "../../misc/reference-counting/rc"; +export class UserNotLoggedInError extends Error { + constructor(userId: UserId) { + super(`User (${userId}) is not logged in`); + } +} + export abstract class SdkService { /** * Retrieve the version of the SDK. @@ -26,7 +32,20 @@ export abstract class SdkService { * The client will be destroyed when the observable is no longer subscribed to. * Please let platform know if you need a client that is not destroyed when the observable is no longer subscribed to. * - * @param userId + * @param userId The user id for which to retrieve the client + * + * @throws {UserNotLoggedInError} If the user is not logged in */ abstract userClient$(userId: UserId): Observable | undefined>; + + /** + * This method is used during/after an authentication procedure to set a new client for a specific user. + * It can also be used to unset the client when a user logs out, this will result in: + * - The client being disposed of + * - All subscriptions to the client being completed + * - Any new subscribers receiving an error + * @param userId The user id for which to set the client + * @param client The client to set for the user. If undefined, the client will be unset. + */ + abstract setClient(userId: UserId, client: BitwardenClient | undefined): void; } diff --git a/libs/common/src/platform/services/sdk/default-sdk.service.spec.ts b/libs/common/src/platform/services/sdk/default-sdk.service.spec.ts index e8dfde863ec..fed4746acd3 100644 --- a/libs/common/src/platform/services/sdk/default-sdk.service.spec.ts +++ b/libs/common/src/platform/services/sdk/default-sdk.service.spec.ts @@ -4,12 +4,14 @@ import { BehaviorSubject, firstValueFrom, of } from "rxjs"; import { KdfConfigService, KeyService, PBKDF2KdfConfig } from "@bitwarden/key-management"; import { BitwardenClient } from "@bitwarden/sdk-internal"; +import { ObservableTracker } from "../../../../spec"; import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service"; import { UserId } from "../../../types/guid"; import { UserKey } from "../../../types/key"; import { Environment, EnvironmentService } from "../../abstractions/environment.service"; import { PlatformUtilsService } from "../../abstractions/platform-utils.service"; import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory"; +import { UserNotLoggedInError } from "../../abstractions/sdk/sdk.service"; import { Rc } from "../../misc/reference-counting/rc"; import { EncryptedString } from "../../models/domain/enc-string"; import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key"; @@ -26,8 +28,6 @@ describe("DefaultSdkService", () => { let keyService!: MockProxy; let service!: DefaultSdkService; - let mockClient!: MockProxy; - beforeEach(() => { sdkClientFactory = mock(); environmentService = mock(); @@ -47,15 +47,10 @@ describe("DefaultSdkService", () => { kdfConfigService, keyService, ); - - mockClient = mock(); - mockClient.crypto.mockReturnValue(mock()); - sdkClientFactory.createSdkClient.mockResolvedValue(mockClient); }); describe("given the user is logged in", () => { const userId = "user-id" as UserId; - beforeEach(() => { environmentService.getEnvironment$ .calledWith(userId) @@ -75,56 +70,144 @@ describe("DefaultSdkService", () => { keyService.encryptedOrgKeys$.calledWith(userId).mockReturnValue(of({})); }); - it("creates an SDK client when called the first time", async () => { - await firstValueFrom(service.userClient$(userId)); + describe("given no client override has been set for the user", () => { + let mockClient!: MockProxy; - expect(sdkClientFactory.createSdkClient).toHaveBeenCalled(); + beforeEach(() => { + mockClient = createMockClient(); + sdkClientFactory.createSdkClient.mockResolvedValue(mockClient); + }); + + it("creates an internal SDK client when called the first time", async () => { + await firstValueFrom(service.userClient$(userId)); + + expect(sdkClientFactory.createSdkClient).toHaveBeenCalled(); + }); + + it("does not create an SDK client when called the second time with same userId", async () => { + const subject_1 = new BehaviorSubject | undefined>(undefined); + const subject_2 = new BehaviorSubject | undefined>(undefined); + + // Use subjects to ensure the subscription is kept alive + service.userClient$(userId).subscribe(subject_1); + service.userClient$(userId).subscribe(subject_2); + + // Wait for the next tick to ensure all async operations are done + await new Promise(process.nextTick); + + expect(subject_1.value.take().value).toBe(mockClient); + expect(subject_2.value.take().value).toBe(mockClient); + expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1); + }); + + it("destroys the internal SDK client when all subscriptions are closed", async () => { + const subject_1 = new BehaviorSubject | undefined>(undefined); + const subject_2 = new BehaviorSubject | undefined>(undefined); + const subscription_1 = service.userClient$(userId).subscribe(subject_1); + const subscription_2 = service.userClient$(userId).subscribe(subject_2); + await new Promise(process.nextTick); + + subscription_1.unsubscribe(); + subscription_2.unsubscribe(); + + await new Promise(process.nextTick); + expect(mockClient.free).toHaveBeenCalledTimes(1); + }); + + it("destroys the internal SDK client when the userKey is unset (i.e. lock or logout)", async () => { + const userKey$ = new BehaviorSubject( + new SymmetricCryptoKey(new Uint8Array(64)) as UserKey, + ); + keyService.userKey$.calledWith(userId).mockReturnValue(userKey$); + + const subject = new BehaviorSubject | undefined>(undefined); + service.userClient$(userId).subscribe(subject); + await new Promise(process.nextTick); + + userKey$.next(undefined); + await new Promise(process.nextTick); + + expect(mockClient.free).toHaveBeenCalledTimes(1); + expect(subject.value).toBe(undefined); + }); }); - it("does not create an SDK client when called the second time with same userId", async () => { - const subject_1 = new BehaviorSubject | undefined>(undefined); - const subject_2 = new BehaviorSubject | undefined>(undefined); + describe("given overrides are used", () => { + it("does not create a new client and emits the override client when a client override has already been set ", async () => { + const mockClient = mock(); + service.setClient(userId, mockClient); + const userClientTracker = new ObservableTracker(service.userClient$(userId), false); + await userClientTracker.pauseUntilReceived(1); - // Use subjects to ensure the subscription is kept alive - service.userClient$(userId).subscribe(subject_1); - service.userClient$(userId).subscribe(subject_2); + expect(sdkClientFactory.createSdkClient).not.toHaveBeenCalled(); + expect(userClientTracker.emissions[0].take().value).toBe(mockClient); + }); - // Wait for the next tick to ensure all async operations are done - await new Promise(process.nextTick); + it("emits the internal client then switches to override when an override is set", async () => { + const mockInternalClient = createMockClient(); + const mockOverrideClient = createMockClient(); + sdkClientFactory.createSdkClient.mockResolvedValue(mockInternalClient); + const userClientTracker = new ObservableTracker(service.userClient$(userId), false); - expect(subject_1.value.take().value).toBe(mockClient); - expect(subject_2.value.take().value).toBe(mockClient); - expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1); - }); + await userClientTracker.pauseUntilReceived(1); + expect(userClientTracker.emissions[0].take().value).toBe(mockInternalClient); - it("destroys the SDK client when all subscriptions are closed", async () => { - const subject_1 = new BehaviorSubject | undefined>(undefined); - const subject_2 = new BehaviorSubject | undefined>(undefined); - const subscription_1 = service.userClient$(userId).subscribe(subject_1); - const subscription_2 = service.userClient$(userId).subscribe(subject_2); - await new Promise(process.nextTick); + service.setClient(userId, mockOverrideClient); - subscription_1.unsubscribe(); - subscription_2.unsubscribe(); + await userClientTracker.pauseUntilReceived(2); + expect(userClientTracker.emissions[1].take().value).toBe(mockOverrideClient); + }); - await new Promise(process.nextTick); - expect(mockClient.free).toHaveBeenCalledTimes(1); - }); + it("throws error when the client has explicitly been set as undefined", async () => { + service.setClient(userId, undefined); - it("destroys the SDK client when the userKey is unset (i.e. lock or logout)", async () => { - const userKey$ = new BehaviorSubject(new SymmetricCryptoKey(new Uint8Array(64)) as UserKey); - keyService.userKey$.calledWith(userId).mockReturnValue(userKey$); + const result = () => firstValueFrom(service.userClient$(userId)); - const subject = new BehaviorSubject | undefined>(undefined); - service.userClient$(userId).subscribe(subject); - await new Promise(process.nextTick); + await expect(result).rejects.toThrow(UserNotLoggedInError); + }); - userKey$.next(undefined); - await new Promise(process.nextTick); + it("completes the subscription when the override is set to undefined after having been defined", async () => { + const mockOverrideClient = createMockClient(); + service.setClient(userId, mockOverrideClient); + const userClientTracker = new ObservableTracker(service.userClient$(userId), false); + await userClientTracker.pauseUntilReceived(1); - expect(mockClient.free).toHaveBeenCalledTimes(1); - expect(subject.value).toBe(undefined); + service.setClient(userId, undefined); + + await userClientTracker.expectCompletion(); + }); + + it("destroys the internal client when an override is set", async () => { + const mockInternalClient = createMockClient(); + const mockOverrideClient = createMockClient(); + sdkClientFactory.createSdkClient.mockResolvedValue(mockInternalClient); + const userClientTracker = new ObservableTracker(service.userClient$(userId), false); + + await userClientTracker.pauseUntilReceived(1); + service.setClient(userId, mockOverrideClient); + await userClientTracker.pauseUntilReceived(2); + + expect(mockInternalClient.free).toHaveBeenCalled(); + }); + + it("destroys the override client when explicitly setting the client to undefined", async () => { + const mockOverrideClient = createMockClient(); + service.setClient(userId, mockOverrideClient); + const userClientTracker = new ObservableTracker(service.userClient$(userId), false); + await userClientTracker.pauseUntilReceived(1); + + service.setClient(userId, undefined); + await userClientTracker.expectCompletion(); + + expect(mockOverrideClient.free).toHaveBeenCalled(); + }); }); }); }); }); + +function createMockClient(): MockProxy { + const client = mock(); + client.crypto.mockReturnValue(mock()); + return client; +} diff --git a/libs/common/src/platform/services/sdk/default-sdk.service.ts b/libs/common/src/platform/services/sdk/default-sdk.service.ts index 516334c7fb4..7e373faed48 100644 --- a/libs/common/src/platform/services/sdk/default-sdk.service.ts +++ b/libs/common/src/platform/services/sdk/default-sdk.service.ts @@ -1,5 +1,3 @@ -// FIXME: Update this file to be type safe and remove this and next line -// @ts-strict-ignore import { combineLatest, concatMap, @@ -10,6 +8,10 @@ import { tap, switchMap, catchError, + BehaviorSubject, + of, + takeWhile, + throwIfEmpty, } from "rxjs"; import { KeyService, KdfConfigService, KdfConfig, KdfType } from "@bitwarden/key-management"; @@ -28,12 +30,19 @@ import { UserKey } from "../../../types/key"; import { Environment, EnvironmentService } from "../../abstractions/environment.service"; import { PlatformUtilsService } from "../../abstractions/platform-utils.service"; import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory"; -import { SdkService } from "../../abstractions/sdk/sdk.service"; +import { SdkService, UserNotLoggedInError } from "../../abstractions/sdk/sdk.service"; import { compareValues } from "../../misc/compare-values"; import { Rc } from "../../misc/reference-counting/rc"; import { EncryptedString } from "../../models/domain/enc-string"; +// A symbol that represents an overriden client that is explicitly set to undefined, +// blocking the creation of an internal client for that user. +const UnsetClient = Symbol("UnsetClient"); + export class DefaultSdkService implements SdkService { + private sdkClientOverrides = new BehaviorSubject<{ + [userId: UserId]: Rc | typeof UnsetClient; + }>({}); private sdkClientCache = new Map>>(); client$ = this.environmentService.environment$.pipe( @@ -56,13 +65,54 @@ export class DefaultSdkService implements SdkService { private accountService: AccountService, private kdfConfigService: KdfConfigService, private keyService: KeyService, - private userAgent: string = null, + private userAgent: string | null = null, ) {} userClient$(userId: UserId): Observable | undefined> { - // TODO: Figure out what happens when the user logs out - if (this.sdkClientCache.has(userId)) { - return this.sdkClientCache.get(userId); + return this.sdkClientOverrides.pipe( + takeWhile((clients) => clients[userId] !== UnsetClient, false), + map((clients) => { + if (clients[userId] === UnsetClient) { + throw new Error("Encountered UnsetClient even though it should have been filtered out"); + } + return clients[userId] as Rc; + }), + distinctUntilChanged(), + switchMap((clientOverride) => { + if (clientOverride) { + return of(clientOverride); + } + + return this.internalClient$(userId); + }), + throwIfEmpty(() => new UserNotLoggedInError(userId)), + ); + } + + setClient(userId: UserId, client: BitwardenClient | undefined) { + const previousValue = this.sdkClientOverrides.value[userId]; + + this.sdkClientOverrides.next({ + ...this.sdkClientOverrides.value, + [userId]: client ? new Rc(client) : UnsetClient, + }); + + if (previousValue !== UnsetClient && previousValue !== undefined) { + previousValue.markForDisposal(); + } + } + + /** + * This method is used to create a client for a specific user by using the existing state of the application. + * This methods is a fallback for when no client has been provided by Auth. As Auth starts implementing the + * client creation, this method will be deprecated. + * @param userId The user id for which to create the client + * @returns An observable that emits the client for the user + */ + private internalClient$(userId: UserId): Observable | undefined> { + const cached = this.sdkClientCache.get(userId); + if (cached !== undefined) { + return cached; } const account$ = this.accountService.accounts$.pipe( @@ -91,7 +141,7 @@ export class DefaultSdkService implements SdkService { // Create our own observable to be able to implement clean-up logic return new Observable>((subscriber) => { const createAndInitializeClient = async () => { - if (privateKey == null || userKey == null) { + if (env == null || kdfParams == null || privateKey == null || userKey == null) { return undefined; } @@ -103,10 +153,11 @@ export class DefaultSdkService implements SdkService { return client; }; - let client: Rc; + let client: Rc | undefined; createAndInitializeClient() .then((c) => { client = c === undefined ? undefined : new Rc(c); + subscriber.next(client); }) .catch((e) => {