1
0
mirror of https://github.com/bitwarden/browser synced 2025-12-16 08:13:42 +00:00

[PM-11764] Implement account switching and sdk initialization (#11472)

* feat: update sdk service abstraction with documentation and new `userClient$` function

* feat: add uninitialized user client with cache

* feat: initialize user crypto

* feat: initialize org keys

* fix: org crypto not initializing properly

* feat: avoid creating clients unnecessarily

* chore: remove dev print/subscription

* fix: clean up cache

* chore: update sdk version

* feat: implement clean-up logic (#11504)

* chore: bump sdk version to fix build issues

* chore: bump sdk version to fix build issues

* fix: missing constructor parameters

* refactor: simplify free() and delete() calls

* refactor: use a named function for client creation

* fix: client never freeing after refactor

* fix: broken impl and race condition in tests
This commit is contained in:
Andreas Coroiu
2024-10-18 16:15:10 +02:00
committed by GitHub
parent cdd5bd4387
commit c787ecd22c
12 changed files with 355 additions and 21 deletions

View File

@@ -0,0 +1,132 @@
import { mock, MockProxy } from "jest-mock-extended";
import { BehaviorSubject, firstValueFrom, of } from "rxjs";
import { BitwardenClient } from "@bitwarden/sdk-internal";
import { ApiService } from "../../../abstractions/api.service";
import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service";
import { KdfConfigService } from "../../../auth/abstractions/kdf-config.service";
import { PBKDF2KdfConfig } from "../../../auth/models/domain/kdf-config";
import { UserId } from "../../../types/guid";
import { UserKey } from "../../../types/key";
import { CryptoService } from "../../abstractions/crypto.service";
import { Environment, EnvironmentService } from "../../abstractions/environment.service";
import { PlatformUtilsService } from "../../abstractions/platform-utils.service";
import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory";
import { EncryptedString } from "../../models/domain/enc-string";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";
import { DefaultSdkService } from "./default-sdk.service";
describe("DefaultSdkService", () => {
describe("userClient$", () => {
let sdkClientFactory!: MockProxy<SdkClientFactory>;
let environmentService!: MockProxy<EnvironmentService>;
let platformUtilsService!: MockProxy<PlatformUtilsService>;
let accountService!: MockProxy<AccountService>;
let kdfConfigService!: MockProxy<KdfConfigService>;
let cryptoService!: MockProxy<CryptoService>;
let apiService!: MockProxy<ApiService>;
let service!: DefaultSdkService;
let mockClient!: MockProxy<BitwardenClient>;
beforeEach(() => {
sdkClientFactory = mock<SdkClientFactory>();
environmentService = mock<EnvironmentService>();
platformUtilsService = mock<PlatformUtilsService>();
accountService = mock<AccountService>();
kdfConfigService = mock<KdfConfigService>();
cryptoService = mock<CryptoService>();
apiService = mock<ApiService>();
// Can't use `of(mock<Environment>())` for some reason
environmentService.environment$ = new BehaviorSubject(mock<Environment>());
service = new DefaultSdkService(
sdkClientFactory,
environmentService,
platformUtilsService,
accountService,
kdfConfigService,
cryptoService,
apiService,
);
mockClient = mock<BitwardenClient>();
mockClient.crypto.mockReturnValue(mock());
sdkClientFactory.createSdkClient.mockResolvedValue(mockClient);
});
describe("given the user is logged in", () => {
const userId = "user-id" as UserId;
beforeEach(() => {
accountService.accounts$ = of({
[userId]: { email: "email", emailVerified: true, name: "name" } as AccountInfo,
});
kdfConfigService.getKdfConfig$
.calledWith(userId)
.mockReturnValue(of(new PBKDF2KdfConfig()));
cryptoService.userKey$
.calledWith(userId)
.mockReturnValue(of(new SymmetricCryptoKey(new Uint8Array(64)) as UserKey));
cryptoService.userEncryptedPrivateKey$
.calledWith(userId)
.mockReturnValue(of("private-key" as EncryptedString));
cryptoService.encryptedOrgKeys$.calledWith(userId).mockReturnValue(of({}));
});
it("creates an SDK client when called the first time", async () => {
const result = await firstValueFrom(service.userClient$(userId));
expect(result).toBe(mockClient);
expect(sdkClientFactory.createSdkClient).toHaveBeenCalled();
});
it("does not create an SDK client when called the second time with same userId", async () => {
const subject_1 = new BehaviorSubject(undefined);
const subject_2 = new BehaviorSubject(undefined);
// Use subjects to ensure the subscription is kept alive
service.userClient$(userId).subscribe(subject_1);
service.userClient$(userId).subscribe(subject_2);
// Wait for the next tick to ensure all async operations are done
await new Promise(process.nextTick);
expect(subject_1.value).toBe(mockClient);
expect(subject_2.value).toBe(mockClient);
expect(sdkClientFactory.createSdkClient).toHaveBeenCalledTimes(1);
});
it("destroys the SDK client when all subscriptions are closed", async () => {
const subject_1 = new BehaviorSubject(undefined);
const subject_2 = new BehaviorSubject(undefined);
const subscription_1 = service.userClient$(userId).subscribe(subject_1);
const subscription_2 = service.userClient$(userId).subscribe(subject_2);
await new Promise(process.nextTick);
subscription_1.unsubscribe();
subscription_2.unsubscribe();
expect(mockClient.free).toHaveBeenCalledTimes(1);
});
it("destroys the SDK client when the userKey is unset (i.e. lock or logout)", async () => {
const userKey$ = new BehaviorSubject(new SymmetricCryptoKey(new Uint8Array(64)) as UserKey);
cryptoService.userKey$.calledWith(userId).mockReturnValue(userKey$);
const subject = new BehaviorSubject(undefined);
service.userClient$(userId).subscribe(subject);
await new Promise(process.nextTick);
userKey$.next(undefined);
await new Promise(process.nextTick);
expect(mockClient.free).toHaveBeenCalledTimes(1);
expect(subject.value).toBe(undefined);
});
});
});
});

View File

@@ -1,24 +1,45 @@
import { concatMap, firstValueFrom, shareReplay } from "rxjs";
import {
combineLatest,
concatMap,
firstValueFrom,
Observable,
shareReplay,
map,
distinctUntilChanged,
tap,
switchMap,
} from "rxjs";
import { LogLevel, DeviceType as SdkDeviceType } from "@bitwarden/sdk-internal";
import {
BitwardenClient,
ClientSettings,
LogLevel,
DeviceType as SdkDeviceType,
} from "@bitwarden/sdk-internal";
import { ApiService } from "../../../abstractions/api.service";
import { EncryptedOrganizationKeyData } from "../../../admin-console/models/data/encrypted-organization-key.data";
import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service";
import { KdfConfigService } from "../../../auth/abstractions/kdf-config.service";
import { KdfConfig } from "../../../auth/models/domain/kdf-config";
import { DeviceType } from "../../../enums/device-type.enum";
import { EnvironmentService } from "../../abstractions/environment.service";
import { OrganizationId, UserId } from "../../../types/guid";
import { UserKey } from "../../../types/key";
import { CryptoService } from "../../abstractions/crypto.service";
import { Environment, EnvironmentService } from "../../abstractions/environment.service";
import { PlatformUtilsService } from "../../abstractions/platform-utils.service";
import { SdkClientFactory } from "../../abstractions/sdk/sdk-client-factory";
import { SdkService } from "../../abstractions/sdk/sdk.service";
import { KdfType } from "../../enums";
import { compareValues } from "../../misc/compare-values";
import { EncryptedString } from "../../models/domain/enc-string";
export class DefaultSdkService implements SdkService {
private sdkClientCache = new Map<UserId, Observable<BitwardenClient>>();
client$ = this.environmentService.environment$.pipe(
concatMap(async (env) => {
const settings = {
apiUrl: env.getApiUrl(),
identityUrl: env.getIdentityUrl(),
deviceType: this.toDevice(this.platformUtilsService.getDevice()),
userAgent: this.userAgent ?? navigator.userAgent,
};
const settings = this.toSettings(env);
return await this.sdkClientFactory.createSdkClient(settings, LogLevel.Info);
}),
shareReplay({ refCount: true, bufferSize: 1 }),
@@ -34,10 +55,81 @@ export class DefaultSdkService implements SdkService {
private sdkClientFactory: SdkClientFactory,
private environmentService: EnvironmentService,
private platformUtilsService: PlatformUtilsService,
private accountService: AccountService,
private kdfConfigService: KdfConfigService,
private cryptoService: CryptoService,
private apiService: ApiService, // Yes we shouldn't import ApiService, but it's temporary
private userAgent: string = null,
) {}
userClient$(userId: UserId): Observable<BitwardenClient | undefined> {
// TODO: Figure out what happens when the user logs out
if (this.sdkClientCache.has(userId)) {
return this.sdkClientCache.get(userId);
}
const account$ = this.accountService.accounts$.pipe(
map((accounts) => accounts[userId]),
distinctUntilChanged(),
);
const kdfParams$ = this.kdfConfigService.getKdfConfig$(userId).pipe(distinctUntilChanged());
const privateKey$ = this.cryptoService
.userEncryptedPrivateKey$(userId)
.pipe(distinctUntilChanged());
const userKey$ = this.cryptoService.userKey$(userId).pipe(distinctUntilChanged());
const orgKeys$ = this.cryptoService.encryptedOrgKeys$(userId).pipe(
distinctUntilChanged(compareValues), // The upstream observable emits different objects with the same values
);
const client$ = combineLatest([
this.environmentService.environment$,
account$,
kdfParams$,
privateKey$,
userKey$,
orgKeys$,
]).pipe(
// switchMap is required to allow the clean-up logic to be executed when `combineLatest` emits a new value.
switchMap(([env, account, kdfParams, privateKey, userKey, orgKeys]) => {
// Create our own observable to be able to implement clean-up logic
return new Observable<BitwardenClient>((subscriber) => {
let client: BitwardenClient;
const createAndInitializeClient = async () => {
if (privateKey == null || userKey == null || orgKeys == null) {
return undefined;
}
const settings = this.toSettings(env);
client = await this.sdkClientFactory.createSdkClient(settings, LogLevel.Info);
await this.initializeClient(client, account, kdfParams, privateKey, userKey, orgKeys);
return client;
};
createAndInitializeClient()
.then((c) => {
client = c;
subscriber.next(c);
})
.catch((e) => {
subscriber.error(e);
});
return () => client?.free();
});
}),
tap({
finalize: () => this.sdkClientCache.delete(userId),
}),
shareReplay({ refCount: true, bufferSize: 1 }),
);
this.sdkClientCache.set(userId, client$);
return client$;
}
async failedToInitialize(): Promise<void> {
// Only log on cloud instances
if (
@@ -52,6 +144,49 @@ export class DefaultSdkService implements SdkService {
});
}
private async initializeClient(
client: BitwardenClient,
account: AccountInfo,
kdfParams: KdfConfig,
privateKey: EncryptedString,
userKey: UserKey,
orgKeys: Record<OrganizationId, EncryptedOrganizationKeyData>,
) {
await client.crypto().initialize_user_crypto({
email: account.email,
method: { decryptedKey: { decrypted_user_key: userKey.keyB64 } },
kdfParams:
kdfParams.kdfType === KdfType.PBKDF2_SHA256
? {
pBKDF2: { iterations: kdfParams.iterations },
}
: {
argon2id: {
iterations: kdfParams.iterations,
memory: kdfParams.memory,
parallelism: kdfParams.parallelism,
},
},
privateKey,
});
await client.crypto().initialize_org_crypto({
organizationKeys: new Map(
Object.entries(orgKeys)
.filter(([_, v]) => v.type === "organization")
.map(([k, v]) => [k, v.key]),
),
});
}
private toSettings(env: Environment): ClientSettings {
return {
apiUrl: env.getApiUrl(),
identityUrl: env.getIdentityUrl(),
deviceType: this.toDevice(this.platformUtilsService.getDevice()),
userAgent: this.userAgent ?? navigator.userAgent,
};
}
private toDevice(device: DeviceType): SdkDeviceType {
switch (device) {
case DeviceType.Android: