1
0
mirror of https://github.com/bitwarden/browser synced 2026-02-07 12:13:45 +00:00

Merge remote-tracking branch 'origin' into auth/pm-19877/notification-processing

This commit is contained in:
Patrick Pimentel
2025-08-04 15:46:17 -04:00
885 changed files with 17448 additions and 14057 deletions

View File

@@ -89,8 +89,7 @@ export class DefaultPolicyService implements PolicyService {
const policies$ = policies ? of(policies) : this.policies$(userId);
return policies$.pipe(
map((obsPolicies) => {
// TODO: replace with this.combinePoliciesIntoMasterPasswordPolicyOptions(obsPolicies)) once
// FeatureFlag.PM16117_ChangeExistingPasswordRefactor is removed.
// TODO ([PM-23777]): replace with this.combinePoliciesIntoMasterPasswordPolicyOptions(obsPolicies))
let enforcedOptions: MasterPasswordPolicyOptions | undefined = undefined;
const filteredPolicies =
obsPolicies.filter((p) => p.type === PolicyType.MasterPassword) ?? [];

View File

@@ -0,0 +1,19 @@
import { map, Observable } from "rxjs";
import { UserId } from "@bitwarden/user-core";
import { ActiveUserAccessor } from "../../platform/state";
import { AccountService } from "../abstractions/account.service";
/**
* Implementation for Platform so they can avoid a direct dependency on AccountService. Not for general consumption.
*/
export class DefaultActiveUserAccessor implements ActiveUserAccessor {
constructor(private readonly accountService: AccountService) {
this.activeUserId$ = this.accountService.activeAccount$.pipe(
map((a) => (a != null ? a.id : null)),
);
}
activeUserId$: Observable<UserId | null>;
}

View File

@@ -4,8 +4,6 @@ import { of } from "rxjs";
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
// eslint-disable-next-line no-restricted-imports
import {
PinLockType,
PinServiceAbstraction,
UserDecryptionOptions,
UserDecryptionOptionsServiceAbstraction,
} from "@bitwarden/auth/common";
@@ -21,6 +19,8 @@ import {
import { FakeAccountService, mockAccountServiceWith } from "../../../../spec";
import { InternalMasterPasswordServiceAbstraction } from "../../../key-management/master-password/abstractions/master-password.service.abstraction";
import { PinServiceAbstraction } from "../../../key-management/pin/pin.service.abstraction";
import { PinLockType } from "../../../key-management/pin/pin.service.implementation";
import { VaultTimeoutSettingsService } from "../../../key-management/vault-timeout";
import { I18nService } from "../../../platform/abstractions/i18n.service";
import { HashPurpose } from "../../../platform/enums";

View File

@@ -14,10 +14,8 @@ import {
KeyService,
} from "@bitwarden/key-management";
// FIXME: remove `src` and fix import
// eslint-disable-next-line no-restricted-imports
import { PinServiceAbstraction } from "../../../../../auth/src/common/abstractions/pin.service.abstraction";
import { InternalMasterPasswordServiceAbstraction } from "../../../key-management/master-password/abstractions/master-password.service.abstraction";
import { PinServiceAbstraction } from "../../../key-management/pin/pin.service.abstraction";
import { I18nService } from "../../../platform/abstractions/i18n.service";
import { HashPurpose } from "../../../platform/enums";
import { UserId } from "../../../types/guid";

View File

@@ -3,7 +3,6 @@ import { PreviewIndividualInvoiceRequest } from "../models/request/preview-indiv
import { PreviewOrganizationInvoiceRequest } from "../models/request/preview-organization-invoice.request";
import { PreviewTaxAmountForOrganizationTrialRequest } from "../models/request/tax";
import { PreviewInvoiceResponse } from "../models/response/preview-invoice.response";
import { PreviewTaxAmountResponse } from "../models/response/tax";
export abstract class TaxServiceAbstraction {
abstract getCountries(): CountryListItem[];
@@ -20,5 +19,5 @@ export abstract class TaxServiceAbstraction {
abstract previewTaxAmountForOrganizationTrial: (
request: PreviewTaxAmountForOrganizationTrialRequest,
) => Promise<PreviewTaxAmountResponse>;
) => Promise<number>;
}

View File

@@ -1,5 +1,4 @@
import { PreviewTaxAmountForOrganizationTrialRequest } from "@bitwarden/common/billing/models/request/tax";
import { PreviewTaxAmountResponse } from "@bitwarden/common/billing/models/response/tax";
import { ApiService } from "../../abstractions/api.service";
import { TaxServiceAbstraction } from "../abstractions/tax.service.abstraction";
@@ -306,13 +305,14 @@ export class TaxService implements TaxServiceAbstraction {
async previewTaxAmountForOrganizationTrial(
request: PreviewTaxAmountForOrganizationTrialRequest,
): Promise<PreviewTaxAmountResponse> {
return await this.apiService.send(
): Promise<number> {
const response = await this.apiService.send(
"POST",
"/tax/preview-amount/organization-trial",
request,
true,
true,
);
return response as number;
}
}

View File

@@ -1,10 +1 @@
// FIXME: update to use a const object instead of a typescript enum
// eslint-disable-next-line @bitwarden/platform/no-enums
export enum ClientType {
Web = "web",
Browser = "browser",
Desktop = "desktop",
// Mobile = "mobile",
Cli = "cli",
// DirectoryConnector = "connector",
}
export { ClientType } from "@bitwarden/client-type";

View File

@@ -14,8 +14,6 @@ export enum FeatureFlag {
CreateDefaultLocation = "pm-19467-create-default-location",
/* Auth */
PM16117_SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor",
PM16117_ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor",
PM14938_BrowserExtensionLoginApproval = "pm-14938-browser-extension-login-approvals",
/* Autofill */
@@ -34,17 +32,19 @@ export enum FeatureFlag {
UseOrganizationWarningsService = "use-organization-warnings-service",
AllowTrialLengthZero = "pm-20322-allow-trial-length-0",
PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout",
PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover",
/* Key Management */
PrivateKeyRegeneration = "pm-12241-private-key-regeneration",
PM4154_BulkEncryptionService = "PM-4154-bulk-encryption-service",
UseSDKForDecryption = "use-sdk-for-decryption",
PM17987_BlockType0 = "pm-17987-block-type-0",
EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation",
ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings",
/* Tools */
DesktopSendUIRefresh = "desktop-send-ui-refresh",
UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators",
/* DIRT */
EventBasedOrganizationIntegrations = "event-based-organization-integrations",
/* Vault */
PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge",
@@ -88,6 +88,10 @@ export const DefaultFeatureFlagValue = {
/* Tools */
[FeatureFlag.DesktopSendUIRefresh]: FALSE,
[FeatureFlag.UseSdkPasswordGenerators]: FALSE,
/* DIRT */
[FeatureFlag.EventBasedOrganizationIntegrations]: FALSE,
/* Vault */
[FeatureFlag.PM8851_BrowserOnboardingNudge]: FALSE,
@@ -101,8 +105,6 @@ export const DefaultFeatureFlagValue = {
[FeatureFlag.PM22136_SdkCipherEncryption]: FALSE,
/* Auth */
[FeatureFlag.PM16117_SetInitialPasswordRefactor]: FALSE,
[FeatureFlag.PM16117_ChangeExistingPasswordRefactor]: FALSE,
[FeatureFlag.PM14938_BrowserExtensionLoginApproval]: FALSE,
/* Billing */
@@ -113,12 +115,10 @@ export const DefaultFeatureFlagValue = {
[FeatureFlag.UseOrganizationWarningsService]: FALSE,
[FeatureFlag.AllowTrialLengthZero]: FALSE,
[FeatureFlag.PM21881_ManagePaymentDetailsOutsideCheckout]: FALSE,
[FeatureFlag.PM21821_ProviderPortalTakeover]: FALSE,
/* Key Management */
[FeatureFlag.PrivateKeyRegeneration]: FALSE,
[FeatureFlag.PM4154_BulkEncryptionService]: FALSE,
[FeatureFlag.UseSDKForDecryption]: FALSE,
[FeatureFlag.PM17987_BlockType0]: FALSE,
[FeatureFlag.EnrollAeadOnKeyRotation]: FALSE,
[FeatureFlag.ForceUpdateKDFSettings]: FALSE,

View File

@@ -1,13 +0,0 @@
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
import { InitializerMetadata } from "../../../platform/interfaces/initializer-metadata.interface";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
export abstract class BulkEncryptService {
abstract decryptItems<T extends InitializerMetadata>(
items: Decryptable<T>[],
key: SymmetricCryptoKey,
): Promise<T[]>;
abstract onServerConfigChange(newConfig: ServerConfig): void;
}

View File

@@ -6,12 +6,20 @@ import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-cr
import { CsprngArray } from "../../../types/csprng";
export abstract class CryptoFunctionService {
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract pbkdf2(
password: string | Uint8Array,
salt: string | Uint8Array,
algorithm: "sha256" | "sha512",
iterations: number,
): Promise<Uint8Array>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract hkdf(
ikm: Uint8Array,
salt: string | Uint8Array,
@@ -19,51 +27,76 @@ export abstract class CryptoFunctionService {
outputByteSize: number,
algorithm: "sha256" | "sha512",
): Promise<Uint8Array>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract hkdfExpand(
prk: Uint8Array,
info: string | Uint8Array,
outputByteSize: number,
algorithm: "sha256" | "sha512",
): Promise<Uint8Array>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract hash(
value: string | Uint8Array,
algorithm: "sha1" | "sha256" | "sha512" | "md5",
): Promise<Uint8Array>;
abstract hmac(
value: Uint8Array,
key: Uint8Array,
algorithm: "sha1" | "sha256" | "sha512",
): Promise<Uint8Array>;
abstract compare(a: Uint8Array, b: Uint8Array): Promise<boolean>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract hmacFast(
value: Uint8Array | string,
key: Uint8Array | string,
algorithm: "sha1" | "sha256" | "sha512",
): Promise<Uint8Array | string>;
abstract compareFast(a: Uint8Array | string, b: Uint8Array | string): Promise<boolean>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract aesDecryptFastParameters(
data: string,
iv: string,
mac: string,
key: SymmetricCryptoKey,
): CbcDecryptParameters<Uint8Array | string>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract aesDecryptFast({
mode,
parameters,
}:
| { mode: "cbc"; parameters: CbcDecryptParameters<Uint8Array | string> }
| { mode: "ecb"; parameters: EcbDecryptParameters<Uint8Array | string> }): Promise<string>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Only used by DDG integration until DDG uses PKCS#7 padding, and by lastpass importer.
*/
abstract aesDecrypt(
data: Uint8Array,
iv: Uint8Array,
key: Uint8Array,
mode: "cbc" | "ecb",
): Promise<Uint8Array>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract rsaEncrypt(
data: Uint8Array,
publicKey: Uint8Array,
algorithm: "sha1" | "sha256",
): Promise<Uint8Array>;
/**
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
*/
abstract rsaDecrypt(
data: Uint8Array,
privateKey: Uint8Array,
@@ -77,7 +110,6 @@ export abstract class CryptoFunctionService {
abstract aesGenerateKey(bitLength: 128 | 192 | 256 | 512): Promise<CsprngArray>;
/**
* Generates a random array of bytes of the given length. Uses a cryptographically secure random number generator.
*
* Do not use this for generating encryption keys. Use aesGenerateKey or rsaGenerateKeyPair instead.
*/
abstract randomBytes(length: number): Promise<CsprngArray>;

View File

@@ -1,51 +1,8 @@
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
import { Encrypted } from "../../../platform/interfaces/encrypted";
import { InitializerMetadata } from "../../../platform/interfaces/initializer-metadata.interface";
import { EncArrayBuffer } from "../../../platform/models/domain/enc-array-buffer";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
import { EncString } from "../models/enc-string";
export abstract class EncryptService {
/**
* @deprecated
* Decrypts an EncString to a string
* @param encString - The EncString to decrypt
* @param key - The key to decrypt the EncString with
* @param decryptTrace - A string to identify the context of the object being decrypted. This can include: field name, encryption type, cipher id, key type, but should not include
* sensitive information like encryption keys or data. This is used for logging when decryption errors occur in order to identify what failed to decrypt
* @returns The decrypted string
*/
abstract decryptToUtf8(
encString: EncString,
key: SymmetricCryptoKey,
decryptTrace?: string,
): Promise<string>;
/**
* @deprecated
* Decrypts an Encrypted object to a Uint8Array
* @param encThing - The Encrypted object to decrypt
* @param key - The key to decrypt the Encrypted object with
* @param decryptTrace - A string to identify the context of the object being decrypted. This can include: field name, encryption type, cipher id, key type, but should not include
* sensitive information like encryption keys or data. This is used for logging when decryption errors occur in order to identify what failed to decrypt
* @returns The decrypted Uint8Array
*/
abstract decryptToBytes(
encThing: Encrypted,
key: SymmetricCryptoKey,
decryptTrace?: string,
): Promise<Uint8Array | null>;
/**
* @deprecated Replaced by BulkEncryptService, remove once the feature is tested and the featureflag PM-4154-multi-worker-encryption-service is removed
* @param items The items to decrypt
* @param key The key to decrypt the items with
*/
abstract decryptItems<T extends InitializerMetadata>(
items: Decryptable<T>[],
key: SymmetricCryptoKey,
): Promise<T[]>;
/**
* Encrypts a string to an EncString
* @param plainValue - The value to encrypt
@@ -188,12 +145,6 @@ export abstract class EncryptService {
decapsulationKey: Uint8Array,
): Promise<SymmetricCryptoKey>;
/**
* @deprecated Use @see {@link encapsulateKeyUnsigned} instead
* @param data - The data to encrypt
* @param publicKey - The public key to encrypt with
*/
abstract rsaEncrypt(data: Uint8Array, publicKey: Uint8Array): Promise<EncString>;
/**
* @deprecated Use @see {@link decapsulateKeyUnsigned} instead
* @param data - The ciphertext to decrypt
@@ -210,6 +161,4 @@ export abstract class EncryptService {
value: string | Uint8Array,
algorithm: "sha1" | "sha256" | "sha512",
): Promise<string>;
abstract onServerConfigChange(newConfig: ServerConfig): void;
}

View File

@@ -4,7 +4,7 @@ import { mock, MockProxy } from "jest-mock-extended";
// eslint-disable-next-line no-restricted-imports
import { KeyService } from "@bitwarden/key-management";
import { makeEncString, makeStaticByteArray } from "../../../../spec";
import { makeStaticByteArray } from "../../../../spec";
import { EncryptionType } from "../../../platform/enums";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
import { ContainerService } from "../../../platform/services/container.service";
@@ -83,7 +83,7 @@ describe("EncString", () => {
const keyService = mock<KeyService>();
keyService.hasUserKey.mockResolvedValue(true);
keyService.getUserKeyWithLegacySupport.mockResolvedValue(
keyService.getUserKey.mockResolvedValue(
new SymmetricCryptoKey(makeStaticByteArray(32)) as UserKey,
);
@@ -114,67 +114,6 @@ describe("EncString", () => {
});
});
describe("decryptWithKey", () => {
const encString = new EncString(EncryptionType.Rsa2048_OaepSha256_B64, "data");
const keyService = mock<KeyService>();
const encryptService = mock<EncryptService>();
encryptService.decryptString
.calledWith(encString, expect.anything())
.mockResolvedValue("decrypted");
function setupEncryption() {
encryptService.encryptString.mockImplementation(async (data, key) => {
return makeEncString(data);
});
encryptService.decryptString.mockImplementation(async (encString, key) => {
return encString.data;
});
}
beforeEach(() => {
(window as any).bitwardenContainerService = new ContainerService(keyService, encryptService);
});
it("decrypts using the provided key and encryptService", async () => {
setupEncryption();
const key = new SymmetricCryptoKey(makeStaticByteArray(32));
await encString.decryptWithKey(key, encryptService);
expect(encryptService.decryptString).toHaveBeenCalledWith(encString, key);
});
it("fails to decrypt when key is null", async () => {
const decrypted = await encString.decryptWithKey(null, encryptService);
expect(decrypted).toBe("[error: cannot decrypt]");
expect(encString.decryptedValue).toBe("[error: cannot decrypt]");
});
it("fails to decrypt when encryptService is null", async () => {
const decrypted = await encString.decryptWithKey(
new SymmetricCryptoKey(makeStaticByteArray(32)),
null,
);
expect(decrypted).toBe("[error: cannot decrypt]");
expect(encString.decryptedValue).toBe("[error: cannot decrypt]");
});
it("fails to decrypt when encryptService throws", async () => {
encryptService.decryptString.mockRejectedValue("error");
const decrypted = await encString.decryptWithKey(
new SymmetricCryptoKey(makeStaticByteArray(32)),
encryptService,
);
expect(decrypted).toBe("[error: cannot decrypt]");
expect(encString.decryptedValue).toBe("[error: cannot decrypt]");
});
});
describe("AesCbc256_B64", () => {
it("constructor", () => {
const encString = new EncString(EncryptionType.AesCbc256_B64, "data", "iv");
@@ -343,7 +282,7 @@ describe("EncString", () => {
await encString.decrypt(null, key);
expect(keyService.getUserKeyWithLegacySupport).not.toHaveBeenCalled();
expect(keyService.getUserKey).not.toHaveBeenCalled();
expect(encryptService.decryptString).toHaveBeenCalledWith(encString, key);
});
@@ -361,11 +300,11 @@ describe("EncString", () => {
it("gets the user's decryption key if required", async () => {
const userKey = mock<UserKey>();
keyService.getUserKeyWithLegacySupport.mockResolvedValue(userKey);
keyService.getUserKey.mockResolvedValue(userKey);
await encString.decrypt(null, null);
expect(keyService.getUserKeyWithLegacySupport).toHaveBeenCalledWith();
expect(keyService.getUserKey).toHaveBeenCalledWith();
expect(encryptService.decryptString).toHaveBeenCalledWith(encString, userKey);
});
});

View File

@@ -1,17 +1,18 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Jsonify, Opaque } from "type-fest";
import { Jsonify } from "type-fest";
import { EncString as SdkEncString } from "@bitwarden/sdk-internal";
import { EncryptionType, EXPECTED_NUM_PARTS_BY_ENCRYPTION_TYPE } from "../../../platform/enums";
import { Encrypted } from "../../../platform/interfaces/encrypted";
import { Utils } from "../../../platform/misc/utils";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
import { EncryptService } from "../abstractions/encrypt.service";
export const DECRYPT_ERROR = "[error: cannot decrypt]";
export class EncString implements Encrypted {
encryptedString?: EncryptedString;
encryptedString?: SdkEncString;
encryptionType?: EncryptionType;
decryptedValue?: string;
data?: string;
@@ -43,7 +44,11 @@ export class EncString implements Encrypted {
return this.data == null ? null : Utils.fromB64ToArray(this.data);
}
toJSON() {
toSdk(): SdkEncString {
return this.encryptedString;
}
toJSON(): string {
return this.encryptedString as string;
}
@@ -57,14 +62,14 @@ export class EncString implements Encrypted {
private initFromData(encType: EncryptionType, data: string, iv: string, mac: string) {
if (iv != null) {
this.encryptedString = (encType + "." + iv + "|" + data) as EncryptedString;
this.encryptedString = (encType + "." + iv + "|" + data) as SdkEncString;
} else {
this.encryptedString = (encType + "." + data) as EncryptedString;
this.encryptedString = (encType + "." + data) as SdkEncString;
}
// mac
if (mac != null) {
this.encryptedString = (this.encryptedString + "|" + mac) as EncryptedString;
this.encryptedString = (this.encryptedString + "|" + mac) as SdkEncString;
}
this.encryptionType = encType;
@@ -74,7 +79,7 @@ export class EncString implements Encrypted {
}
private initFromEncryptedString(encryptedString: string) {
this.encryptedString = encryptedString as EncryptedString;
this.encryptedString = encryptedString as SdkEncString;
if (!this.encryptedString) {
return;
}
@@ -184,31 +189,14 @@ export class EncString implements Encrypted {
return this.decryptedValue;
}
async decryptWithKey(
key: SymmetricCryptoKey,
encryptService: EncryptService,
decryptTrace: string = "domain-withkey",
): Promise<string> {
try {
if (key == null) {
throw new Error("No key to decrypt EncString");
}
this.decryptedValue = await encryptService.decryptString(this, key);
// FIXME: Remove when updating file. Eslint update
// eslint-disable-next-line @typescript-eslint/no-unused-vars
} catch (e) {
this.decryptedValue = DECRYPT_ERROR;
}
return this.decryptedValue;
}
private async getKeyForDecryption(orgId: string) {
const keyService = Utils.getContainerService().getKeyService();
return orgId != null
? await keyService.getOrgKey(orgId)
: await keyService.getUserKeyWithLegacySupport();
return orgId != null ? await keyService.getOrgKey(orgId) : await keyService.getUserKey();
}
}
export type EncryptedString = Opaque<string, "EncString">;
/**
* Temporary type mapping until consumers are moved over.
* @deprecated - Use SdkEncString directly
*/
export type EncryptedString = SdkEncString;

View File

@@ -1,46 +0,0 @@
import { BulkEncryptService } from "@bitwarden/common/key-management/crypto/abstractions/bulk-encrypt.service";
import { CryptoFunctionService } from "@bitwarden/common/key-management/crypto/abstractions/crypto-function.service";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
import { DefaultFeatureFlagValue, FeatureFlag } from "../../../enums/feature-flag.enum";
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
/**
* @deprecated Will be deleted in an immediate subsequent PR
*/
export class BulkEncryptServiceImplementation implements BulkEncryptService {
protected useSDKForDecryption: boolean = DefaultFeatureFlagValue[FeatureFlag.UseSDKForDecryption];
constructor(
protected cryptoFunctionService: CryptoFunctionService,
protected logService: LogService,
) {}
/**
* Decrypts items using a web worker if the environment supports it.
* Will fall back to the main thread if the window object is not available.
*/
async decryptItems<T extends InitializerMetadata>(
items: Decryptable<T>[],
key: SymmetricCryptoKey,
): Promise<T[]> {
if (key == null) {
throw new Error("No encryption key provided.");
}
if (items == null || items.length < 1) {
return [];
}
const results = [];
for (let i = 0; i < items.length; i++) {
results.push(await items[i].decrypt(key));
}
return results;
}
onServerConfigChange(newConfig: ServerConfig): void {}
}

View File

@@ -5,15 +5,11 @@ import { EncString } from "@bitwarden/common/key-management/crypto/models/enc-st
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
import { SdkLoadService } from "@bitwarden/common/platform/abstractions/sdk/sdk-load.service";
import { EncryptionType } from "@bitwarden/common/platform/enums";
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
import { Encrypted } from "@bitwarden/common/platform/interfaces/encrypted";
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
import { Utils } from "@bitwarden/common/platform/misc/utils";
import { EncArrayBuffer } from "@bitwarden/common/platform/models/domain/enc-array-buffer";
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
import { PureCrypto } from "@bitwarden/sdk-internal";
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { EncryptService } from "../abstractions/encrypt.service";
export class EncryptServiceImplementation implements EncryptService {
@@ -23,7 +19,6 @@ export class EncryptServiceImplementation implements EncryptService {
protected logMacFailures: boolean,
) {}
// Proxy functions; Their implementation are temporary before moving at this level to the SDK
async encryptString(plainValue: string, key: SymmetricCryptoKey): Promise<EncString> {
if (plainValue == null) {
this.logService.warning(
@@ -171,36 +166,6 @@ export class EncryptServiceImplementation implements EncryptService {
return Utils.fromBufferToB64(hashArray);
}
// Handle updating private properties to turn on/off feature flags.
onServerConfigChange(newConfig: ServerConfig): void {}
async decryptToUtf8(
encString: EncString,
key: SymmetricCryptoKey,
_decryptContext: string = "no context",
): Promise<string> {
await SdkLoadService.Ready;
return PureCrypto.symmetric_decrypt(encString.encryptedString, key.toEncoded());
}
async decryptToBytes(
encThing: Encrypted,
key: SymmetricCryptoKey,
_decryptContext: string = "no context",
): Promise<Uint8Array | null> {
if (encThing.encryptionType == null || encThing.ivBytes == null || encThing.dataBytes == null) {
throw new Error("Cannot decrypt, missing type, IV, or data bytes.");
}
const buffer = EncArrayBuffer.fromParts(
encThing.encryptionType,
encThing.ivBytes,
encThing.dataBytes,
encThing.macBytes,
).buffer;
await SdkLoadService.Ready;
return PureCrypto.symmetric_decrypt_array_buffer(buffer, key.toEncoded());
}
async encapsulateKeyUnsigned(
sharedKey: SymmetricCryptoKey,
encapsulationKey: Uint8Array,
@@ -228,45 +193,14 @@ export class EncryptServiceImplementation implements EncryptService {
throw new Error("No decapsulationKey provided for decapsulation");
}
await SdkLoadService.Ready;
const keyBytes = PureCrypto.decapsulate_key_unsigned(
encryptedSharedKey.encryptedString,
decapsulationKey,
);
await SdkLoadService.Ready;
return new SymmetricCryptoKey(keyBytes);
}
/**
* @deprecated Replaced by BulkEncryptService (PM-4154)
*/
async decryptItems<T extends InitializerMetadata>(
items: Decryptable<T>[],
key: SymmetricCryptoKey,
): Promise<T[]> {
if (items == null || items.length < 1) {
return [];
}
// don't use promise.all because this task is not io bound
const results = [];
for (let i = 0; i < items.length; i++) {
results.push(await items[i].decrypt(key));
}
return results;
}
async rsaEncrypt(data: Uint8Array, publicKey: Uint8Array): Promise<EncString> {
if (data == null) {
throw new Error("No data provided for encryption.");
}
if (publicKey == null) {
throw new Error("No public key provided for encryption.");
}
const encrypted = await this.cryptoFunctionService.rsaEncrypt(data, publicKey, "sha1");
return new EncString(EncryptionType.Rsa2048_OaepSha1_B64, Utils.fromBufferToB64(encrypted));
}
async rsaDecrypt(data: EncString, privateKey: Uint8Array): Promise<Uint8Array> {
if (data == null) {
throw new Error("[Encrypt service] rsaDecrypt: No data provided for decryption.");

View File

@@ -303,12 +303,6 @@ describe("EncryptService", () => {
const actual = await encryptService.encapsulateKeyUnsigned(testKey, publicKey);
expect(actual).toEqual(new EncString("encapsulated_key_unsigned"));
});
it("throws if no data was provided", () => {
return expect(encryptService.rsaEncrypt(null, new Uint8Array(32))).rejects.toThrow(
"No data provided for encryption",
);
});
});
describe("decapsulateKeyUnsigned", () => {
@@ -338,23 +332,4 @@ describe("EncryptService", () => {
expect(cryptoFunctionService.hash).toHaveBeenCalledWith("test", "sha256");
});
});
describe("decryptItems", () => {
it("returns empty array if no items are provided", async () => {
const key = mock<SymmetricCryptoKey>();
const actual = await encryptService.decryptItems(null, key);
expect(actual).toEqual([]);
});
it("returns items decrypted with provided key", async () => {
const key = mock<SymmetricCryptoKey>();
const decryptable = {
decrypt: jest.fn().mockResolvedValue("decrypted"),
};
const items = [decryptable];
const actual = await encryptService.decryptItems(items as any, key);
expect(actual).toEqual(["decrypted"]);
expect(decryptable.decrypt).toHaveBeenCalledWith(key);
});
});
});

View File

@@ -1,81 +0,0 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Jsonify } from "type-fest";
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { LogService } from "../../../platform/abstractions/log.service";
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
import { ConsoleLogService } from "../../../platform/services/console-log.service";
import { ContainerService } from "../../../platform/services/container.service";
import { getClassInitializer } from "../../../platform/services/cryptography/get-class-initializer";
import {
DECRYPT_COMMAND,
SET_CONFIG_COMMAND,
ParsedDecryptCommandData,
} from "../types/worker-command.type";
import { EncryptServiceImplementation } from "./encrypt.service.implementation";
import { WebCryptoFunctionService } from "./web-crypto-function.service";
const workerApi: Worker = self as any;
let inited = false;
let encryptService: EncryptServiceImplementation;
let logService: LogService;
/**
* Bootstrap the worker environment with services required for decryption
*/
export function init() {
const cryptoFunctionService = new WebCryptoFunctionService(self);
logService = new ConsoleLogService(false);
encryptService = new EncryptServiceImplementation(cryptoFunctionService, logService, true);
const bitwardenContainerService = new ContainerService(null, encryptService);
bitwardenContainerService.attachToGlobal(self);
inited = true;
}
/**
* Listen for messages and decrypt their contents
*/
workerApi.addEventListener("message", async (event: { data: string }) => {
if (!inited) {
init();
}
const request: {
command: string;
} = JSON.parse(event.data);
switch (request.command) {
case DECRYPT_COMMAND:
return await handleDecrypt(request as unknown as ParsedDecryptCommandData);
case SET_CONFIG_COMMAND: {
const newConfig = (request as unknown as { newConfig: Jsonify<ServerConfig> }).newConfig;
return await handleSetConfig(newConfig);
}
default:
logService.error(`[EncryptWorker] unknown worker command`, request.command, request);
}
});
async function handleDecrypt(request: ParsedDecryptCommandData) {
const key = SymmetricCryptoKey.fromJSON(request.key);
const items = request.items.map((jsonItem) => {
const initializer = getClassInitializer<Decryptable<any>>(jsonItem.initializerKey);
return initializer(jsonItem);
});
const result = await encryptService.decryptItems(items, key);
workerApi.postMessage({
id: request.id,
items: JSON.stringify(result),
});
}
async function handleSetConfig(newConfig: Jsonify<ServerConfig>) {
encryptService.onServerConfigChange(ServerConfig.fromJSON(newConfig));
}

View File

@@ -1,34 +0,0 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { BulkEncryptService } from "@bitwarden/common/key-management/crypto/abstractions/bulk-encrypt.service";
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { EncryptService } from "../abstractions/encrypt.service";
/**
* @deprecated Will be deleted in an immediate subsequent PR
*/
export class FallbackBulkEncryptService implements BulkEncryptService {
private featureFlagEncryptService: BulkEncryptService;
private currentServerConfig: ServerConfig | undefined = undefined;
constructor(protected encryptService: EncryptService) {}
/**
* Decrypts items using a web worker if the environment supports it.
* Will fall back to the main thread if the window object is not available.
*/
async decryptItems<T extends InitializerMetadata>(
items: Decryptable<T>[],
key: SymmetricCryptoKey,
): Promise<T[]> {
return await this.encryptService.decryptItems(items, key);
}
async setFeatureFlagEncryptService(featureFlagEncryptService: BulkEncryptService) {}
onServerConfigChange(newConfig: ServerConfig): void {}
}

View File

@@ -1,27 +0,0 @@
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { EncryptServiceImplementation } from "./encrypt.service.implementation";
/**
* @deprecated Will be deleted in an immediate subsequent PR
*/
export class MultithreadEncryptServiceImplementation extends EncryptServiceImplementation {
protected useSDKForDecryption: boolean = true;
/**
* Sends items to a web worker to decrypt them.
* This utilises multithreading to decrypt items faster without interrupting other operations (e.g. updating UI).
*/
async decryptItems<T extends InitializerMetadata>(
items: Decryptable<T>[],
key: SymmetricCryptoKey,
): Promise<T[]> {
return await super.decryptItems(items, key);
}
override onServerConfigChange(newConfig: ServerConfig): void {}
}

View File

@@ -154,46 +154,6 @@ describe("WebCrypto Function Service", () => {
testHmac("sha512", Sha512Mac);
});
describe("compare", () => {
it("should successfully compare two of the same values", async () => {
const cryptoFunctionService = getWebCryptoFunctionService();
const a = new Uint8Array(2);
a[0] = 1;
a[1] = 2;
const equal = await cryptoFunctionService.compare(a, a);
expect(equal).toBe(true);
});
it("should successfully compare two different values of the same length", async () => {
const cryptoFunctionService = getWebCryptoFunctionService();
const a = new Uint8Array(2);
a[0] = 1;
a[1] = 2;
const b = new Uint8Array(2);
b[0] = 3;
b[1] = 4;
const equal = await cryptoFunctionService.compare(a, b);
expect(equal).toBe(false);
});
it("should successfully compare two different values of different lengths", async () => {
const cryptoFunctionService = getWebCryptoFunctionService();
const a = new Uint8Array(2);
a[0] = 1;
a[1] = 2;
const b = new Uint8Array(2);
b[0] = 3;
const equal = await cryptoFunctionService.compare(a, b);
expect(equal).toBe(false);
});
});
describe("hmacFast", () => {
testHmacFast("sha1", Sha1Mac);
testHmacFast("sha256", Sha256Mac);
testHmacFast("sha512", Sha512Mac);
});
describe("compareFast", () => {
it("should successfully compare two of the same values", async () => {
const cryptoFunctionService = getWebCryptoFunctionService();
@@ -523,20 +483,6 @@ function testHmac(algorithm: "sha1" | "sha256" | "sha512", mac: string) {
});
}
function testHmacFast(algorithm: "sha1" | "sha256" | "sha512", mac: string) {
it("should create valid " + algorithm + " hmac", async () => {
const cryptoFunctionService = getWebCryptoFunctionService();
const keyByteString = Utils.fromBufferToByteString(Utils.fromUtf8ToArray("secretkey"));
const dataByteString = Utils.fromBufferToByteString(Utils.fromUtf8ToArray("SignMe!!"));
const computedMac = await cryptoFunctionService.hmacFast(
dataByteString,
keyByteString,
algorithm,
);
expect(Utils.fromBufferToHex(Utils.fromByteStringToArray(computedMac))).toBe(mac);
});
}
function testRsaGenerateKeyPair(length: 1024 | 2048 | 4096) {
it(
"should successfully generate a " + length + " bit key pair",

View File

@@ -146,34 +146,6 @@ export class WebCryptoFunctionService implements CryptoFunctionService {
return new Uint8Array(buffer);
}
// Safely compare two values in a way that protects against timing attacks (Double HMAC Verification).
// ref: https://www.nccgroup.trust/us/about-us/newsroom-and-events/blog/2011/february/double-hmac-verification/
// ref: https://paragonie.com/blog/2015/11/preventing-timing-attacks-on-string-comparison-with-double-hmac-strategy
async compare(a: Uint8Array, b: Uint8Array): Promise<boolean> {
const macKey = await this.randomBytes(32);
const signingAlgorithm = {
name: "HMAC",
hash: { name: "SHA-256" },
};
const impKey = await this.subtle.importKey("raw", macKey, signingAlgorithm, false, ["sign"]);
const mac1 = await this.subtle.sign(signingAlgorithm, impKey, a);
const mac2 = await this.subtle.sign(signingAlgorithm, impKey, b);
if (mac1.byteLength !== mac2.byteLength) {
return false;
}
const arr1 = new Uint8Array(mac1);
const arr2 = new Uint8Array(mac2);
for (let i = 0; i < arr2.length; i++) {
if (arr1[i] !== arr2[i]) {
return false;
}
}
return true;
}
hmacFast(value: string, key: string, algorithm: "sha1" | "sha256" | "sha512"): Promise<string> {
const hmac = forge.hmac.create();
hmac.start(algorithm, key);
@@ -182,6 +154,9 @@ export class WebCryptoFunctionService implements CryptoFunctionService {
return Promise.resolve(bytes);
}
// Safely compare two values in a way that protects against timing attacks (Double HMAC Verification).
// ref: https://www.nccgroup.trust/us/about-us/newsroom-and-events/blog/2011/february/double-hmac-verification/
// ref: https://paragonie.com/blog/2015/11/preventing-timing-attacks-on-string-comparison-with-double-hmac-strategy
async compareFast(a: string, b: string): Promise<boolean> {
const rand = await this.randomBytes(32);
const bytes = new Uint32Array(rand);

View File

@@ -1,67 +0,0 @@
import { mock } from "jest-mock-extended";
import { makeStaticByteArray } from "../../../../spec";
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
import {
DECRYPT_COMMAND,
DecryptCommandData,
SET_CONFIG_COMMAND,
buildDecryptMessage,
buildSetConfigMessage,
} from "./worker-command.type";
describe("Worker command types", () => {
describe("buildDecryptMessage", () => {
it("builds a message with the correct command", () => {
const commandData = createDecryptCommandData();
const result = buildDecryptMessage(commandData);
const parsedResult = JSON.parse(result);
expect(parsedResult.command).toBe(DECRYPT_COMMAND);
});
it("includes the provided data in the message", () => {
const mockItems = [{ encrypted: "test-encrypted" } as unknown as Decryptable<any>];
const commandData = createDecryptCommandData(mockItems);
const result = buildDecryptMessage(commandData);
const parsedResult = JSON.parse(result);
expect(parsedResult.command).toBe(DECRYPT_COMMAND);
expect(parsedResult.id).toBe("test-id");
expect(parsedResult.items).toEqual(mockItems);
expect(SymmetricCryptoKey.fromJSON(parsedResult.key)).toEqual(commandData.key);
});
});
describe("buildSetConfigMessage", () => {
it("builds a message with the correct command", () => {
const result = buildSetConfigMessage({ newConfig: mock<ServerConfig>() });
const parsedResult = JSON.parse(result);
expect(parsedResult.command).toBe(SET_CONFIG_COMMAND);
});
it("includes the provided data in the message", () => {
const serverConfig = { version: "test-version" } as unknown as ServerConfig;
const result = buildSetConfigMessage({ newConfig: serverConfig });
const parsedResult = JSON.parse(result);
expect(parsedResult.command).toBe(SET_CONFIG_COMMAND);
expect(ServerConfig.fromJSON(parsedResult.newConfig).version).toEqual(serverConfig.version);
});
});
});
function createDecryptCommandData(items?: Decryptable<any>[]): DecryptCommandData {
return {
id: "test-id",
items: items ?? [],
key: new SymmetricCryptoKey(makeStaticByteArray(64)),
};
}

View File

@@ -1,36 +0,0 @@
import { Jsonify } from "type-fest";
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
export const DECRYPT_COMMAND = "decrypt";
export const SET_CONFIG_COMMAND = "updateConfig";
export type DecryptCommandData = {
id: string;
items: Decryptable<any>[];
key: SymmetricCryptoKey;
};
export type ParsedDecryptCommandData = {
id: string;
items: Jsonify<Decryptable<any>>[];
key: Jsonify<SymmetricCryptoKey>;
};
type SetConfigCommandData = { newConfig: ServerConfig };
export function buildDecryptMessage(data: DecryptCommandData): string {
return JSON.stringify({
command: DECRYPT_COMMAND,
...data,
});
}
export function buildSetConfigMessage(data: SetConfigCommandData): string {
return JSON.stringify({
command: SET_CONFIG_COMMAND,
...data,
});
}

View File

@@ -1,9 +1,17 @@
import { Observable } from "rxjs";
// eslint-disable-next-line no-restricted-imports
import { KdfConfig } from "@bitwarden/key-management";
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
import { UserId } from "../../../types/guid";
import { MasterKey, UserKey } from "../../../types/key";
import { EncString } from "../../crypto/models/enc-string";
import {
MasterPasswordAuthenticationData,
MasterPasswordSalt,
MasterPasswordUnlockData,
} from "../types/master-password.types";
export abstract class MasterPasswordServiceAbstraction {
/**
@@ -12,14 +20,23 @@ export abstract class MasterPasswordServiceAbstraction {
* @throws If the user ID is missing.
*/
abstract forceSetPasswordReason$: (userId: UserId) => Observable<ForceSetPasswordReason>;
/**
* An observable that emits the master password salt for the user.
* @param userId The user ID.
* @throws If the user ID is missing.
* @throws If the user ID is provided, but the user is not found.
*/
abstract saltForUser$: (userId: UserId) => Observable<MasterPasswordSalt>;
/**
* An observable that emits the master key for the user.
* @deprecated Interacting with the master-key directly is deprecated. Please use {@link makeMasterPasswordUnlockData}, {@link makeMasterPasswordAuthenticationData} or {@link unwrapUserKeyFromMasterPasswordUnlockData} instead.
* @param userId The user ID.
* @throws If the user ID is missing.
*/
abstract masterKey$: (userId: UserId) => Observable<MasterKey>;
/**
* An observable that emits the master key hash for the user.
* @deprecated Interacting with the master-key directly is deprecated. Please use {@link makeMasterPasswordAuthenticationData}.
* @param userId The user ID.
* @throws If the user ID is missing.
*/
@@ -32,6 +49,7 @@ export abstract class MasterPasswordServiceAbstraction {
abstract getMasterKeyEncryptedUserKey: (userId: UserId) => Promise<EncString>;
/**
* Decrypts the user key with the provided master key
* @deprecated Interacting with the master-key directly is deprecated. Please use {@link unwrapUserKeyFromMasterPasswordUnlockData} instead.
* @param masterKey The user's master key
* * @param userId The desired user
* @param userKey The user's encrypted symmetric key
@@ -44,12 +62,52 @@ export abstract class MasterPasswordServiceAbstraction {
userId: string,
userKey?: EncString,
) => Promise<UserKey | null>;
/**
* Makes the authentication hash for authenticating to the server with the master password.
* @param password The master password.
* @param kdf The KDF configuration.
* @param salt The master password salt to use. See {@link saltForUser$} for current salt.
* @throws If password, KDF or salt are null or undefined.
*/
abstract makeMasterPasswordAuthenticationData: (
password: string,
kdf: KdfConfig,
salt: MasterPasswordSalt,
) => Promise<MasterPasswordAuthenticationData>;
/**
* Creates a MasterPasswordUnlockData bundle that encrypts the user-key with a key derived from the password. The
* bundle also contains the KDF settings and salt used to derive the key, which are required to decrypt the user-key later.
* @param password The master password.
* @param kdf The KDF configuration.
* @param salt The master password salt to use. See {@link saltForUser$} for current salt.
* @param userKey The user's userKey to encrypt.
* @throws If password, KDF, salt, or userKey are null or undefined.
*/
abstract makeMasterPasswordUnlockData: (
password: string,
kdf: KdfConfig,
salt: MasterPasswordSalt,
userKey: UserKey,
) => Promise<MasterPasswordUnlockData>;
/**
* Unwraps a user-key that was wrapped with a password provided KDF settings. The same KDF settings and salt must be provided to unwrap the user-key, otherwise it will fail to decrypt.
* @throws If the encryption type is not supported.
* @throws If the password, KDF, or salt don't match the original wrapping parameters.
*/
abstract unwrapUserKeyFromMasterPasswordUnlockData: (
password: string,
masterPasswordUnlockData: MasterPasswordUnlockData,
) => Promise<UserKey>;
}
export abstract class InternalMasterPasswordServiceAbstraction extends MasterPasswordServiceAbstraction {
/**
* Set the master key for the user.
* Note: Use {@link clearMasterKey} to clear the master key.
* @deprecated Interacting with the master-key directly is deprecated.
* @param masterKey The master key.
* @param userId The user ID.
* @throws If the user ID or master key is missing.
@@ -57,6 +115,7 @@ export abstract class InternalMasterPasswordServiceAbstraction extends MasterPas
abstract setMasterKey: (masterKey: MasterKey, userId: UserId) => Promise<void>;
/**
* Clear the master key for the user.
* @deprecated Interacting with the master-key directly is deprecated.
* @param userId The user ID.
* @throws If the user ID is missing.
*/
@@ -64,6 +123,7 @@ export abstract class InternalMasterPasswordServiceAbstraction extends MasterPas
/**
* Set the master key hash for the user.
* Note: Use {@link clearMasterKeyHash} to clear the master key hash.
* @deprecated Interacting with the master-key directly is deprecated.
* @param masterKeyHash The master key hash.
* @param userId The user ID.
* @throws If the user ID or master key hash is missing.
@@ -71,6 +131,7 @@ export abstract class InternalMasterPasswordServiceAbstraction extends MasterPas
abstract setMasterKeyHash: (masterKeyHash: string, userId: UserId) => Promise<void>;
/**
* Clear the master key hash for the user.
* @deprecated Interacting with the master-key directly is deprecated.
* @param userId The user ID.
* @throws If the user ID is missing.
*/

View File

@@ -3,11 +3,20 @@
import { mock } from "jest-mock-extended";
import { ReplaySubject, Observable } from "rxjs";
// FIXME: Update this file to be type safe and remove this and next line
// eslint-disable-next-line no-restricted-imports
import { KdfConfig } from "@bitwarden/key-management";
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
import { UserId } from "../../../types/guid";
import { MasterKey, UserKey } from "../../../types/key";
import { EncString } from "../../crypto/models/enc-string";
import { InternalMasterPasswordServiceAbstraction } from "../abstractions/master-password.service.abstraction";
import {
MasterPasswordAuthenticationData,
MasterPasswordSalt,
MasterPasswordUnlockData,
} from "../types/master-password.types";
export class FakeMasterPasswordService implements InternalMasterPasswordServiceAbstraction {
mock = mock<InternalMasterPasswordServiceAbstraction>();
@@ -24,6 +33,10 @@ export class FakeMasterPasswordService implements InternalMasterPasswordServiceA
this.masterKeyHashSubject.next(initialMasterKeyHash);
}
saltForUser$(userId: UserId): Observable<MasterPasswordSalt> {
return this.mock.saltForUser$(userId);
}
masterKey$(userId: UserId): Observable<MasterKey> {
return this.masterKeySubject.asObservable();
}
@@ -71,4 +84,28 @@ export class FakeMasterPasswordService implements InternalMasterPasswordServiceA
): Promise<UserKey> {
return this.mock.decryptUserKeyWithMasterKey(masterKey, userId, userKey);
}
makeMasterPasswordAuthenticationData(
password: string,
kdf: KdfConfig,
salt: MasterPasswordSalt,
): Promise<MasterPasswordAuthenticationData> {
return this.mock.makeMasterPasswordAuthenticationData(password, kdf, salt);
}
makeMasterPasswordUnlockData(
password: string,
kdf: KdfConfig,
salt: MasterPasswordSalt,
userKey: UserKey,
): Promise<MasterPasswordUnlockData> {
return this.mock.makeMasterPasswordUnlockData(password, kdf, salt, userKey);
}
unwrapUserKeyFromMasterPasswordUnlockData(
password: string,
masterPasswordUnlockData: MasterPasswordUnlockData,
): Promise<UserKey> {
return this.mock.unwrapUserKeyFromMasterPasswordUnlockData(password, masterPasswordUnlockData);
}
}

View File

@@ -1,8 +1,17 @@
import { mock, MockProxy } from "jest-mock-extended";
import { of } from "rxjs";
import { firstValueFrom, of } from "rxjs";
import * as rxjs from "rxjs";
import { makeSymmetricCryptoKey } from "../../../../spec";
import { SdkLoadService } from "@bitwarden/common/platform/abstractions/sdk/sdk-load.service";
import { Utils } from "@bitwarden/common/platform/misc/utils";
// eslint-disable-next-line no-restricted-imports
import { KdfConfig, PBKDF2KdfConfig } from "@bitwarden/key-management";
import {
FakeAccountService,
makeSymmetricCryptoKey,
mockAccountServiceWith,
} from "../../../../spec";
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
import { KeyGenerationService } from "../../../platform/abstractions/key-generation.service";
import { LogService } from "../../../platform/abstractions/log.service";
@@ -10,9 +19,11 @@ import { StateService } from "../../../platform/abstractions/state.service";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
import { StateProvider } from "../../../platform/state";
import { UserId } from "../../../types/guid";
import { MasterKey } from "../../../types/key";
import { MasterKey, UserKey } from "../../../types/key";
import { CryptoFunctionService } from "../../crypto/abstractions/crypto-function.service";
import { EncryptService } from "../../crypto/abstractions/encrypt.service";
import { EncString } from "../../crypto/models/enc-string";
import { MasterPasswordSalt } from "../types/master-password.types";
import { MasterPasswordService } from "./master-password.service";
@@ -24,8 +35,10 @@ describe("MasterPasswordService", () => {
let keyGenerationService: MockProxy<KeyGenerationService>;
let encryptService: MockProxy<EncryptService>;
let logService: MockProxy<LogService>;
let cryptoFunctionService: MockProxy<CryptoFunctionService>;
let accountService: FakeAccountService;
const userId = "user-id" as UserId;
const userId = "00000000-0000-0000-0000-000000000000" as UserId;
const mockUserState = {
state$: of(null),
update: jest.fn().mockResolvedValue(null),
@@ -45,6 +58,8 @@ describe("MasterPasswordService", () => {
keyGenerationService = mock<KeyGenerationService>();
encryptService = mock<EncryptService>();
logService = mock<LogService>();
cryptoFunctionService = mock<CryptoFunctionService>();
accountService = mockAccountServiceWith(userId);
stateProvider.getUser.mockReturnValue(mockUserState as any);
@@ -56,10 +71,33 @@ describe("MasterPasswordService", () => {
keyGenerationService,
encryptService,
logService,
cryptoFunctionService,
accountService,
);
encryptService.unwrapSymmetricKey.mockResolvedValue(makeSymmetricCryptoKey(64, 1));
keyGenerationService.stretchKey.mockResolvedValue(makeSymmetricCryptoKey(64, 3));
Object.defineProperty(SdkLoadService, "Ready", {
value: Promise.resolve(),
configurable: true,
});
});
describe("saltForUser$", () => {
it("throws when userid not present", async () => {
expect(() => {
sut.saltForUser$(null as unknown as UserId);
}).toThrow("userId is null or undefined.");
});
it("throws when userid present but not in account service", async () => {
await expect(
firstValueFrom(sut.saltForUser$("00000000-0000-0000-0000-000000000001" as UserId)),
).rejects.toThrow("Cannot read properties of undefined (reading 'email')");
});
it("returns salt", async () => {
const salt = await firstValueFrom(sut.saltForUser$(userId));
expect(salt).toBeDefined();
});
});
describe("setForceSetPasswordReason", () => {
@@ -190,4 +228,97 @@ describe("MasterPasswordService", () => {
expect(updateFn(null)).toEqual(encryptedKey.toJSON());
});
});
describe("makeMasterPasswordAuthenticationData", () => {
const password = "test-password";
const kdf: KdfConfig = new PBKDF2KdfConfig(600_000);
const salt = "test@bitwarden.com" as MasterPasswordSalt;
const masterKey = makeSymmetricCryptoKey(32, 2);
const masterKeyHash = makeSymmetricCryptoKey(32, 3).toEncoded();
beforeEach(() => {
keyGenerationService.deriveKeyFromPassword.mockResolvedValue(masterKey);
cryptoFunctionService.pbkdf2.mockResolvedValue(masterKeyHash);
});
it("derives master key and creates authentication hash", async () => {
const result = await sut.makeMasterPasswordAuthenticationData(password, kdf, salt);
expect(keyGenerationService.deriveKeyFromPassword).toHaveBeenCalledWith(password, salt, kdf);
expect(cryptoFunctionService.pbkdf2).toHaveBeenCalledWith(
masterKey.toEncoded(),
password,
"sha256",
1,
);
expect(result).toEqual({
kdf,
salt,
masterPasswordAuthenticationHash: Utils.fromBufferToB64(masterKeyHash),
});
});
it("throws if password is null", async () => {
await expect(
sut.makeMasterPasswordAuthenticationData(null as unknown as string, kdf, salt),
).rejects.toThrow();
});
it("throws if kdf is null", async () => {
await expect(
sut.makeMasterPasswordAuthenticationData(password, null as unknown as KdfConfig, salt),
).rejects.toThrow();
});
it("throws if salt is null", async () => {
await expect(
sut.makeMasterPasswordAuthenticationData(
password,
kdf,
null as unknown as MasterPasswordSalt,
),
).rejects.toThrow();
});
});
describe("wrapUnwrapUserKeyWithPassword", () => {
const password = "test-password";
const kdf: KdfConfig = new PBKDF2KdfConfig(600_000);
const salt = "test@bitwarden.com" as MasterPasswordSalt;
const userKey = makeSymmetricCryptoKey(64, 2) as UserKey;
it("wraps and unwraps user key with password", async () => {
const unlockData = await sut.makeMasterPasswordUnlockData(password, kdf, salt, userKey);
const unwrappedUserkey = await sut.unwrapUserKeyFromMasterPasswordUnlockData(
password,
unlockData,
);
expect(unwrappedUserkey).toEqual(userKey);
});
it("throws if password is null", async () => {
await expect(
sut.makeMasterPasswordUnlockData(null as unknown as string, kdf, salt, userKey),
).rejects.toThrow();
});
it("throws if kdf is null", async () => {
await expect(
sut.makeMasterPasswordUnlockData(password, null as unknown as KdfConfig, salt, userKey),
).rejects.toThrow();
});
it("throws if salt is null", async () => {
await expect(
sut.makeMasterPasswordUnlockData(
password,
kdf,
null as unknown as MasterPasswordSalt,
userKey,
),
).rejects.toThrow();
});
it("throws if userKey is null", async () => {
await expect(
sut.makeMasterPasswordUnlockData(password, kdf, salt, null as unknown as UserKey),
).rejects.toThrow();
});
});
});

View File

@@ -2,6 +2,14 @@
// @ts-strict-ignore
import { firstValueFrom, map, Observable } from "rxjs";
import { AccountService } from "@bitwarden/common/auth/abstractions/account.service";
import { assertNonNullish } from "@bitwarden/common/auth/utils";
import { SdkLoadService } from "@bitwarden/common/platform/abstractions/sdk/sdk-load.service";
import { Utils } from "@bitwarden/common/platform/misc/utils";
// eslint-disable-next-line no-restricted-imports
import { KdfConfig } from "@bitwarden/key-management";
import { PureCrypto } from "@bitwarden/sdk-internal";
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
import { KeyGenerationService } from "../../../platform/abstractions/key-generation.service";
import { LogService } from "../../../platform/abstractions/log.service";
@@ -16,9 +24,17 @@ import {
} from "../../../platform/state";
import { UserId } from "../../../types/guid";
import { MasterKey, UserKey } from "../../../types/key";
import { CryptoFunctionService } from "../../crypto/abstractions/crypto-function.service";
import { EncryptService } from "../../crypto/abstractions/encrypt.service";
import { EncryptedString, EncString } from "../../crypto/models/enc-string";
import { InternalMasterPasswordServiceAbstraction } from "../abstractions/master-password.service.abstraction";
import {
MasterKeyWrappedUserKey,
MasterPasswordAuthenticationData,
MasterPasswordAuthenticationHash,
MasterPasswordSalt,
MasterPasswordUnlockData,
} from "../types/master-password.types";
/** Memory since master key shouldn't be available on lock */
const MASTER_KEY = new UserKeyDefinition<MasterKey>(MASTER_PASSWORD_MEMORY, "masterKey", {
@@ -59,8 +75,18 @@ export class MasterPasswordService implements InternalMasterPasswordServiceAbstr
private keyGenerationService: KeyGenerationService,
private encryptService: EncryptService,
private logService: LogService,
private cryptoFunctionService: CryptoFunctionService,
private accountService: AccountService,
) {}
saltForUser$(userId: UserId): Observable<MasterPasswordSalt> {
assertNonNullish(userId, "userId");
return this.accountService.accounts$.pipe(
map((accounts) => accounts[userId].email),
map((email) => this.emailToSalt(email)),
);
}
masterKey$(userId: UserId): Observable<MasterKey> {
if (userId == null) {
throw new Error("User ID is required.");
@@ -95,6 +121,10 @@ export class MasterPasswordService implements InternalMasterPasswordServiceAbstr
return EncString.fromJSON(key);
}
private emailToSalt(email: string): MasterPasswordSalt {
return email.toLowerCase().trim() as MasterPasswordSalt;
}
async setMasterKey(masterKey: MasterKey, userId: UserId): Promise<void> {
if (masterKey == null) {
throw new Error("Master key is required.");
@@ -202,4 +232,89 @@ export class MasterPasswordService implements InternalMasterPasswordServiceAbstr
return decUserKey as UserKey;
}
async makeMasterPasswordAuthenticationData(
password: string,
kdf: KdfConfig,
salt: MasterPasswordSalt,
): Promise<MasterPasswordAuthenticationData> {
assertNonNullish(password, "password");
assertNonNullish(kdf, "kdf");
assertNonNullish(salt, "salt");
// We don't trust callers to use masterpasswordsalt correctly. They may type assert incorrectly.
salt = salt.toLowerCase().trim() as MasterPasswordSalt;
const SERVER_AUTHENTICATION_HASH_ITERATIONS = 1;
const masterKey = (await this.keyGenerationService.deriveKeyFromPassword(
password,
salt,
kdf,
)) as MasterKey;
const masterPasswordAuthenticationHash = Utils.fromBufferToB64(
await this.cryptoFunctionService.pbkdf2(
masterKey.toEncoded(),
password,
"sha256",
SERVER_AUTHENTICATION_HASH_ITERATIONS,
),
) as MasterPasswordAuthenticationHash;
return {
salt,
kdf,
masterPasswordAuthenticationHash,
} as MasterPasswordAuthenticationData;
}
async makeMasterPasswordUnlockData(
password: string,
kdf: KdfConfig,
salt: MasterPasswordSalt,
userKey: UserKey,
): Promise<MasterPasswordUnlockData> {
assertNonNullish(password, "password");
assertNonNullish(kdf, "kdf");
assertNonNullish(salt, "salt");
assertNonNullish(userKey, "userKey");
// We don't trust callers to use masterpasswordsalt correctly. They may type assert incorrectly.
salt = salt.toLowerCase().trim() as MasterPasswordSalt;
await SdkLoadService.Ready;
const masterKeyWrappedUserKey = new EncString(
PureCrypto.encrypt_user_key_with_master_password(
userKey.toEncoded(),
password,
salt,
kdf.toSdkConfig(),
),
) as MasterKeyWrappedUserKey;
return {
salt,
kdf,
masterKeyWrappedUserKey,
};
}
async unwrapUserKeyFromMasterPasswordUnlockData(
password: string,
masterPasswordUnlockData: MasterPasswordUnlockData,
): Promise<UserKey> {
assertNonNullish(password, "password");
assertNonNullish(masterPasswordUnlockData, "masterPasswordUnlockData");
await SdkLoadService.Ready;
const userKey = new SymmetricCryptoKey(
PureCrypto.decrypt_user_key_with_master_password(
masterPasswordUnlockData.masterKeyWrappedUserKey.encryptedString,
password,
masterPasswordUnlockData.salt,
masterPasswordUnlockData.kdf.toSdkConfig(),
),
);
return userKey as UserKey;
}
}

View File

@@ -0,0 +1,34 @@
import { Opaque } from "type-fest";
// eslint-disable-next-line no-restricted-imports
import { KdfConfig } from "@bitwarden/key-management";
import { EncString } from "../../crypto/models/enc-string";
/**
* The Base64-encoded master password authentication hash, that is sent to the server for authentication.
*/
export type MasterPasswordAuthenticationHash = Opaque<string, "MasterPasswordAuthenticationHash">;
/**
* You MUST obtain this through the emailToSalt function in MasterPasswordService
*/
export type MasterPasswordSalt = Opaque<string, "MasterPasswordSalt">;
export type MasterKeyWrappedUserKey = Opaque<EncString, "MasterPasswordSalt">;
/**
* The data required to unlock with the master password.
*/
export type MasterPasswordUnlockData = {
salt: MasterPasswordSalt;
kdf: KdfConfig;
masterKeyWrappedUserKey: MasterKeyWrappedUserKey;
};
/**
* The data required to authenticate with the master password.
*/
export type MasterPasswordAuthenticationData = {
salt: MasterPasswordSalt;
kdf: KdfConfig;
masterPasswordAuthenticationHash: MasterPasswordAuthenticationHash;
};

View File

@@ -0,0 +1,128 @@
// eslint-disable-next-line no-restricted-imports
import { KdfConfig } from "@bitwarden/key-management";
import { EncString } from "../../key-management/crypto/models/enc-string";
import { UserId } from "../../types/guid";
import { PinKey, UserKey } from "../../types/key";
import { PinLockType } from "./pin.service.implementation";
/**
* The PinService is used for PIN-based unlocks. Below is a very basic overview of the PIN flow:
*
* -- Setting the PIN via {@link SetPinComponent} --
*
* When the user submits the setPinForm:
* 1. We encrypt the PIN with the UserKey and store it on disk as `userKeyEncryptedPin`.
*
* 2. We create a PinKey from the PIN, and then use that PinKey to encrypt the UserKey, resulting in
* a `pinKeyEncryptedUserKey`, which can be stored in one of two ways depending on what the user selects
* for the `requireMasterPasswordOnClientReset` checkbox.
*
* If `requireMasterPasswordOnClientReset` is:
* - TRUE, store in memory as `pinKeyEncryptedUserKeyEphemeral` (does NOT persist through a client reset)
* - FALSE, store on disk as `pinKeyEncryptedUserKeyPersistent` (persists through a client reset)
*
* -- Unlocking with the PIN via {@link LockComponent} --
*
* When the user enters their PIN, we decrypt their UserKey with the PIN and set that UserKey to state.
*/
export abstract class PinServiceAbstraction {
/**
* Gets the persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
*/
abstract getPinKeyEncryptedUserKeyPersistent: (userId: UserId) => Promise<EncString | null>;
/**
* Clears the persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
*/
abstract clearPinKeyEncryptedUserKeyPersistent(userId: UserId): Promise<void>;
/**
* Gets the ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
*/
abstract getPinKeyEncryptedUserKeyEphemeral: (userId: UserId) => Promise<EncString | null>;
/**
* Clears the ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
*/
abstract clearPinKeyEncryptedUserKeyEphemeral(userId: UserId): Promise<void>;
/**
* Creates a pinKeyEncryptedUserKey from the provided PIN and UserKey.
*/
abstract createPinKeyEncryptedUserKey: (
pin: string,
userKey: UserKey,
userId: UserId,
) => Promise<EncString>;
/**
* Stores the UserKey, encrypted by the PinKey.
* @param storeEphemeralVersion If true, stores an ephemeral version via the private {@link setPinKeyEncryptedUserKeyEphemeral} method.
* If false, stores a persistent version via the private {@link setPinKeyEncryptedUserKeyPersistent} method.
*/
abstract storePinKeyEncryptedUserKey: (
pinKeyEncryptedUserKey: EncString,
storeEphemeralVersion: boolean,
userId: UserId,
) => Promise<void>;
/**
* Gets the user's PIN, encrypted by the UserKey.
*/
abstract getUserKeyEncryptedPin: (userId: UserId) => Promise<EncString | null>;
/**
* Sets the user's PIN, encrypted by the UserKey.
*/
abstract setUserKeyEncryptedPin: (
userKeyEncryptedPin: EncString,
userId: UserId,
) => Promise<void>;
/**
* Creates a PIN, encrypted by the UserKey.
*/
abstract createUserKeyEncryptedPin: (pin: string, userKey: UserKey) => Promise<EncString>;
/**
* Clears the user's PIN, encrypted by the UserKey.
*/
abstract clearUserKeyEncryptedPin(userId: UserId): Promise<void>;
/**
* Makes a PinKey from the provided PIN.
*/
abstract makePinKey: (pin: string, salt: string, kdfConfig: KdfConfig) => Promise<PinKey>;
/**
* Gets the user's PinLockType {@link PinLockType}.
*/
abstract getPinLockType: (userId: UserId) => Promise<PinLockType>;
/**
* Declares whether or not the user has a PIN set (either persistent or ephemeral).
* Note: for ephemeral, this does not check if we actual have an ephemeral PIN-encrypted UserKey stored in memory.
* Decryption might not be possible even if this returns true. Use {@link isPinDecryptionAvailable} if decryption is required.
*/
abstract isPinSet: (userId: UserId) => Promise<boolean>;
/**
* Checks if PIN-encrypted keys are stored for the user.
* Used for unlock / user verification scenarios where we will need to decrypt the UserKey with the PIN.
*/
abstract isPinDecryptionAvailable: (userId: UserId) => Promise<boolean>;
/**
* Decrypts the UserKey with the provided PIN.
*
* @remarks - If the user has an old pinKeyEncryptedMasterKey (formerly called `pinProtected`), the UserKey
* will be obtained via the private {@link decryptAndMigrateOldPinKeyEncryptedMasterKey} method.
* - If the user does not have an old pinKeyEncryptedMasterKey, the UserKey will be obtained via the
* private {@link decryptUserKey} method.
* @returns UserKey
*/
abstract decryptUserKeyWithPin: (pin: string, userId: UserId) => Promise<UserKey | null>;
}

View File

@@ -0,0 +1,390 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { firstValueFrom, map } from "rxjs";
// eslint-disable-next-line no-restricted-imports
import { KdfConfig, KdfConfigService } from "@bitwarden/key-management";
import { AccountService } from "../../auth/abstractions/account.service";
import { CryptoFunctionService } from "../../key-management/crypto/abstractions/crypto-function.service";
import { EncryptService } from "../../key-management/crypto/abstractions/encrypt.service";
import { EncString, EncryptedString } from "../../key-management/crypto/models/enc-string";
import { KeyGenerationService } from "../../platform/abstractions/key-generation.service";
import { LogService } from "../../platform/abstractions/log.service";
import { PIN_DISK, PIN_MEMORY, StateProvider, UserKeyDefinition } from "../../platform/state";
import { UserId } from "../../types/guid";
import { PinKey, UserKey } from "../../types/key";
import { PinServiceAbstraction } from "./pin.service.abstraction";
/**
* - DISABLED : No PIN set.
* - PERSISTENT : PIN is set and persists through client reset.
* - EPHEMERAL : PIN is set, but does NOT persist through client reset. This means that
* after client reset the master password is required to unlock.
*/
export type PinLockType = "DISABLED" | "PERSISTENT" | "EPHEMERAL";
/**
* The persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
*
* @remarks Persists through a client reset. Used when `requireMasterPasswordOnClientRestart` is disabled.
* @see SetPinComponent.setPinForm.requireMasterPasswordOnClientRestart
*/
export const PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT = new UserKeyDefinition<EncryptedString>(
PIN_DISK,
"pinKeyEncryptedUserKeyPersistent",
{
deserializer: (jsonValue) => jsonValue,
clearOn: ["logout"],
},
);
/**
* The ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
*
* @remarks Does NOT persist through a client reset. Used when `requireMasterPasswordOnClientRestart` is enabled.
* @see SetPinComponent.setPinForm.requireMasterPasswordOnClientRestart
*/
export const PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL = new UserKeyDefinition<EncryptedString>(
PIN_MEMORY,
"pinKeyEncryptedUserKeyEphemeral",
{
deserializer: (jsonValue) => jsonValue,
clearOn: ["logout"],
},
);
/**
* The PIN, encrypted by the UserKey.
*/
export const USER_KEY_ENCRYPTED_PIN = new UserKeyDefinition<EncryptedString>(
PIN_DISK,
"userKeyEncryptedPin",
{
deserializer: (jsonValue) => jsonValue,
clearOn: ["logout"],
},
);
export class PinService implements PinServiceAbstraction {
constructor(
private accountService: AccountService,
private cryptoFunctionService: CryptoFunctionService,
private encryptService: EncryptService,
private kdfConfigService: KdfConfigService,
private keyGenerationService: KeyGenerationService,
private logService: LogService,
private stateProvider: StateProvider,
) {}
async getPinKeyEncryptedUserKeyPersistent(userId: UserId): Promise<EncString | null> {
this.validateUserId(userId, "Cannot get pinKeyEncryptedUserKeyPersistent.");
return EncString.fromJSON(
await firstValueFrom(
this.stateProvider.getUserState$(PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT, userId),
),
);
}
/**
* Sets the persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
*/
private async setPinKeyEncryptedUserKeyPersistent(
pinKeyEncryptedUserKey: EncString,
userId: UserId,
): Promise<void> {
this.validateUserId(userId, "Cannot set pinKeyEncryptedUserKeyPersistent.");
if (pinKeyEncryptedUserKey == null) {
throw new Error(
"No pinKeyEncryptedUserKey provided. Cannot set pinKeyEncryptedUserKeyPersistent.",
);
}
await this.stateProvider.setUserState(
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
pinKeyEncryptedUserKey?.encryptedString,
userId,
);
}
async clearPinKeyEncryptedUserKeyPersistent(userId: UserId): Promise<void> {
this.validateUserId(userId, "Cannot clear pinKeyEncryptedUserKeyPersistent.");
await this.stateProvider.setUserState(PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT, null, userId);
}
async getPinKeyEncryptedUserKeyEphemeral(userId: UserId): Promise<EncString | null> {
this.validateUserId(userId, "Cannot get pinKeyEncryptedUserKeyEphemeral.");
return EncString.fromJSON(
await firstValueFrom(
this.stateProvider.getUserState$(PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL, userId),
),
);
}
/**
* Sets the ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
*/
private async setPinKeyEncryptedUserKeyEphemeral(
pinKeyEncryptedUserKey: EncString,
userId: UserId,
): Promise<void> {
this.validateUserId(userId, "Cannot set pinKeyEncryptedUserKeyEphemeral.");
if (pinKeyEncryptedUserKey == null) {
throw new Error(
"No pinKeyEncryptedUserKey provided. Cannot set pinKeyEncryptedUserKeyEphemeral.",
);
}
await this.stateProvider.setUserState(
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
pinKeyEncryptedUserKey?.encryptedString,
userId,
);
}
async clearPinKeyEncryptedUserKeyEphemeral(userId: UserId): Promise<void> {
this.validateUserId(userId, "Cannot clear pinKeyEncryptedUserKeyEphemeral.");
await this.stateProvider.setUserState(PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL, null, userId);
}
async createPinKeyEncryptedUserKey(
pin: string,
userKey: UserKey,
userId: UserId,
): Promise<EncString> {
this.validateUserId(userId, "Cannot create pinKeyEncryptedUserKey.");
if (!userKey) {
throw new Error("No UserKey provided. Cannot create pinKeyEncryptedUserKey.");
}
const email = await firstValueFrom(
this.accountService.accounts$.pipe(map((accounts) => accounts[userId].email)),
);
const kdfConfig = await this.kdfConfigService.getKdfConfig(userId);
const pinKey = await this.makePinKey(pin, email, kdfConfig);
return await this.encryptService.wrapSymmetricKey(userKey, pinKey);
}
async storePinKeyEncryptedUserKey(
pinKeyEncryptedUserKey: EncString,
storeAsEphemeral: boolean,
userId: UserId,
): Promise<void> {
this.validateUserId(userId, "Cannot store pinKeyEncryptedUserKey.");
if (storeAsEphemeral) {
await this.setPinKeyEncryptedUserKeyEphemeral(pinKeyEncryptedUserKey, userId);
} else {
await this.setPinKeyEncryptedUserKeyPersistent(pinKeyEncryptedUserKey, userId);
}
}
async getUserKeyEncryptedPin(userId: UserId): Promise<EncString | null> {
this.validateUserId(userId, "Cannot get userKeyEncryptedPin.");
return EncString.fromJSON(
await firstValueFrom(this.stateProvider.getUserState$(USER_KEY_ENCRYPTED_PIN, userId)),
);
}
async setUserKeyEncryptedPin(userKeyEncryptedPin: EncString, userId: UserId): Promise<void> {
this.validateUserId(userId, "Cannot set userKeyEncryptedPin.");
await this.stateProvider.setUserState(
USER_KEY_ENCRYPTED_PIN,
userKeyEncryptedPin?.encryptedString,
userId,
);
}
async clearUserKeyEncryptedPin(userId: UserId): Promise<void> {
this.validateUserId(userId, "Cannot clear userKeyEncryptedPin.");
await this.stateProvider.setUserState(USER_KEY_ENCRYPTED_PIN, null, userId);
}
async createUserKeyEncryptedPin(pin: string, userKey: UserKey): Promise<EncString> {
if (!userKey) {
throw new Error("No UserKey provided. Cannot create userKeyEncryptedPin.");
}
return await this.encryptService.encryptString(pin, userKey);
}
async makePinKey(pin: string, salt: string, kdfConfig: KdfConfig): Promise<PinKey> {
const start = Date.now();
const pinKey = await this.keyGenerationService.deriveKeyFromPassword(pin, salt, kdfConfig);
this.logService.info(`[Pin Service] deriving pin key took ${Date.now() - start}ms`);
return (await this.keyGenerationService.stretchKey(pinKey)) as PinKey;
}
async getPinLockType(userId: UserId): Promise<PinLockType> {
this.validateUserId(userId, "Cannot get PinLockType.");
const aUserKeyEncryptedPinIsSet = !!(await this.getUserKeyEncryptedPin(userId));
const aPinKeyEncryptedUserKeyPersistentIsSet =
!!(await this.getPinKeyEncryptedUserKeyPersistent(userId));
if (aPinKeyEncryptedUserKeyPersistentIsSet) {
return "PERSISTENT";
} else if (aUserKeyEncryptedPinIsSet && !aPinKeyEncryptedUserKeyPersistentIsSet) {
return "EPHEMERAL";
} else {
return "DISABLED";
}
}
async isPinSet(userId: UserId): Promise<boolean> {
this.validateUserId(userId, "Cannot determine if PIN is set.");
return (await this.getPinLockType(userId)) !== "DISABLED";
}
async isPinDecryptionAvailable(userId: UserId): Promise<boolean> {
this.validateUserId(userId, "Cannot determine if decryption of user key via PIN is available.");
const pinLockType = await this.getPinLockType(userId);
switch (pinLockType) {
case "DISABLED":
return false;
case "PERSISTENT":
// The above getPinLockType call ensures that we have either a PinKeyEncryptedUserKey set.
return true;
case "EPHEMERAL": {
// The above getPinLockType call ensures that we have a UserKeyEncryptedPin set.
// However, we must additively check to ensure that we have a set PinKeyEncryptedUserKeyEphemeral b/c otherwise
// we cannot take a PIN, derive a PIN key, and decrypt the ephemeral UserKey.
const pinKeyEncryptedUserKeyEphemeral =
await this.getPinKeyEncryptedUserKeyEphemeral(userId);
return Boolean(pinKeyEncryptedUserKeyEphemeral);
}
default: {
// Compile-time check for exhaustive switch
const _exhaustiveCheck: never = pinLockType;
throw new Error(`Unexpected pinLockType: ${_exhaustiveCheck}`);
}
}
}
async decryptUserKeyWithPin(pin: string, userId: UserId): Promise<UserKey | null> {
this.validateUserId(userId, "Cannot decrypt user key with PIN.");
try {
const pinLockType = await this.getPinLockType(userId);
const pinKeyEncryptedUserKey = await this.getPinKeyEncryptedKeys(pinLockType, userId);
const email = await firstValueFrom(
this.accountService.accounts$.pipe(map((accounts) => accounts[userId].email)),
);
const kdfConfig = await this.kdfConfigService.getKdfConfig(userId);
const userKey: UserKey = await this.decryptUserKey(
userId,
pin,
email,
kdfConfig,
pinKeyEncryptedUserKey,
);
if (!userKey) {
this.logService.warning(`User key null after pin key decryption.`);
return null;
}
if (!(await this.validatePin(userKey, pin, userId))) {
this.logService.warning(`Pin key decryption successful but pin validation failed.`);
return null;
}
return userKey;
} catch (error) {
this.logService.error(`Error decrypting user key with pin: ${error}`);
return null;
}
}
/**
* Decrypts the UserKey with the provided PIN.
*/
private async decryptUserKey(
userId: UserId,
pin: string,
salt: string,
kdfConfig: KdfConfig,
pinKeyEncryptedUserKey?: EncString,
): Promise<UserKey> {
this.validateUserId(userId, "Cannot decrypt user key.");
pinKeyEncryptedUserKey ||= await this.getPinKeyEncryptedUserKeyPersistent(userId);
pinKeyEncryptedUserKey ||= await this.getPinKeyEncryptedUserKeyEphemeral(userId);
if (!pinKeyEncryptedUserKey) {
throw new Error("No pinKeyEncryptedUserKey found.");
}
const pinKey = await this.makePinKey(pin, salt, kdfConfig);
const userKey = await this.encryptService.unwrapSymmetricKey(pinKeyEncryptedUserKey, pinKey);
return userKey as UserKey;
}
/**
* Gets the user's `pinKeyEncryptedUserKey` (persistent or ephemeral)
* (if one exists) based on the user's PinLockType.
*
* @throws If PinLockType is 'DISABLED' or if userId is not provided
*/
private async getPinKeyEncryptedKeys(
pinLockType: PinLockType,
userId: UserId,
): Promise<EncString> {
this.validateUserId(userId, "Cannot get PinKey encrypted keys.");
switch (pinLockType) {
case "PERSISTENT": {
return await this.getPinKeyEncryptedUserKeyPersistent(userId);
}
case "EPHEMERAL": {
return await this.getPinKeyEncryptedUserKeyEphemeral(userId);
}
case "DISABLED":
throw new Error("Pin is disabled");
default: {
// Compile-time check for exhaustive switch
const _exhaustiveCheck: never = pinLockType;
return _exhaustiveCheck;
}
}
}
private async validatePin(userKey: UserKey, pin: string, userId: UserId): Promise<boolean> {
this.validateUserId(userId, "Cannot validate PIN.");
const userKeyEncryptedPin = await this.getUserKeyEncryptedPin(userId);
const decryptedPin = await this.encryptService.decryptString(userKeyEncryptedPin, userKey);
const isPinValid = this.cryptoFunctionService.compareFast(decryptedPin, pin);
return isPinValid;
}
/**
* Throws a custom error message if user ID is not provided.
*/
private validateUserId(userId: UserId, errorMessage: string = "") {
if (!userId) {
throw new Error(`User ID is required. ${errorMessage}`);
}
}
}

View File

@@ -0,0 +1,518 @@
import { mock } from "jest-mock-extended";
// eslint-disable-next-line no-restricted-imports
import { DEFAULT_KDF_CONFIG, KdfConfigService } from "@bitwarden/key-management";
import { FakeAccountService, FakeStateProvider, mockAccountServiceWith } from "../../../spec";
import { KeyGenerationService } from "../../platform/abstractions/key-generation.service";
import { LogService } from "../../platform/abstractions/log.service";
import { Utils } from "../../platform/misc/utils";
import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypto-key";
import { UserId } from "../../types/guid";
import { PinKey, UserKey } from "../../types/key";
import { CryptoFunctionService } from "../crypto/abstractions/crypto-function.service";
import { EncryptService } from "../crypto/abstractions/encrypt.service";
import { EncString } from "../crypto/models/enc-string";
import {
PinService,
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
USER_KEY_ENCRYPTED_PIN,
PinLockType,
} from "./pin.service.implementation";
describe("PinService", () => {
let sut: PinService;
let accountService: FakeAccountService;
let stateProvider: FakeStateProvider;
const cryptoFunctionService = mock<CryptoFunctionService>();
const encryptService = mock<EncryptService>();
const kdfConfigService = mock<KdfConfigService>();
const keyGenerationService = mock<KeyGenerationService>();
const logService = mock<LogService>();
const mockUserId = Utils.newGuid() as UserId;
const mockUserKey = new SymmetricCryptoKey(randomBytes(64)) as UserKey;
const mockPinKey = new SymmetricCryptoKey(randomBytes(32)) as PinKey;
const mockUserEmail = "user@example.com";
const mockPin = "1234";
const mockUserKeyEncryptedPin = new EncString("userKeyEncryptedPin");
// Note: both pinKeyEncryptedUserKeys use encryptionType: 2 (AesCbc256_HmacSha256_B64)
const pinKeyEncryptedUserKeyEphemeral = new EncString(
"2.gbauOANURUHqvhLTDnva1A==|nSW+fPumiuTaDB/s12+JO88uemV6rhwRSR+YR1ZzGr5j6Ei3/h+XEli2Unpz652NlZ9NTuRpHxeOqkYYJtp7J+lPMoclgteXuAzUu9kqlRc=|DeUFkhIwgkGdZA08bDnDqMMNmZk21D+H5g8IostPKAY=",
);
const pinKeyEncryptedUserKeyPersistant = new EncString(
"2.fb5kOEZvh9zPABbP8WRmSQ==|Yi6ZAJY+UtqCKMUSqp1ahY9Kf8QuneKXs6BMkpNsakLVOzTYkHHlilyGABMF7GzUO8QHyZi7V/Ovjjg+Naf3Sm8qNhxtDhibITv4k8rDnM0=|TFkq3h2VNTT1z5BFbebm37WYuxyEHXuRo0DZJI7TQnw=",
);
beforeEach(() => {
jest.clearAllMocks();
accountService = mockAccountServiceWith(mockUserId, { email: mockUserEmail });
stateProvider = new FakeStateProvider(accountService);
sut = new PinService(
accountService,
cryptoFunctionService,
encryptService,
kdfConfigService,
keyGenerationService,
logService,
stateProvider,
);
});
it("should instantiate the PinService", () => {
expect(sut).not.toBeFalsy();
});
describe("userId validation", () => {
it("should throw an error if a userId is not provided", async () => {
await expect(sut.getPinKeyEncryptedUserKeyPersistent(undefined)).rejects.toThrow(
"User ID is required. Cannot get pinKeyEncryptedUserKeyPersistent.",
);
await expect(sut.getPinKeyEncryptedUserKeyEphemeral(undefined)).rejects.toThrow(
"User ID is required. Cannot get pinKeyEncryptedUserKeyEphemeral.",
);
await expect(sut.clearPinKeyEncryptedUserKeyPersistent(undefined)).rejects.toThrow(
"User ID is required. Cannot clear pinKeyEncryptedUserKeyPersistent.",
);
await expect(sut.clearPinKeyEncryptedUserKeyEphemeral(undefined)).rejects.toThrow(
"User ID is required. Cannot clear pinKeyEncryptedUserKeyEphemeral.",
);
await expect(
sut.createPinKeyEncryptedUserKey(mockPin, mockUserKey, undefined),
).rejects.toThrow("User ID is required. Cannot create pinKeyEncryptedUserKey.");
await expect(sut.getUserKeyEncryptedPin(undefined)).rejects.toThrow(
"User ID is required. Cannot get userKeyEncryptedPin.",
);
await expect(sut.setUserKeyEncryptedPin(mockUserKeyEncryptedPin, undefined)).rejects.toThrow(
"User ID is required. Cannot set userKeyEncryptedPin.",
);
await expect(sut.clearUserKeyEncryptedPin(undefined)).rejects.toThrow(
"User ID is required. Cannot clear userKeyEncryptedPin.",
);
await expect(
sut.createPinKeyEncryptedUserKey(mockPin, mockUserKey, undefined),
).rejects.toThrow("User ID is required. Cannot create pinKeyEncryptedUserKey.");
await expect(sut.getPinLockType(undefined)).rejects.toThrow("Cannot get PinLockType.");
await expect(sut.isPinSet(undefined)).rejects.toThrow(
"User ID is required. Cannot determine if PIN is set.",
);
});
});
describe("get/clear/create/store pinKeyEncryptedUserKey methods", () => {
describe("getPinKeyEncryptedUserKeyPersistent()", () => {
it("should get the pinKeyEncryptedUserKey of the specified userId", async () => {
await sut.getPinKeyEncryptedUserKeyPersistent(mockUserId);
expect(stateProvider.mock.getUserState$).toHaveBeenCalledWith(
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
mockUserId,
);
});
});
describe("clearPinKeyEncryptedUserKeyPersistent()", () => {
it("should clear the pinKeyEncryptedUserKey of the specified userId", async () => {
await sut.clearPinKeyEncryptedUserKeyPersistent(mockUserId);
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
null,
mockUserId,
);
});
});
describe("getPinKeyEncryptedUserKeyEphemeral()", () => {
it("should get the pinKeyEncrypterUserKeyEphemeral of the specified userId", async () => {
await sut.getPinKeyEncryptedUserKeyEphemeral(mockUserId);
expect(stateProvider.mock.getUserState$).toHaveBeenCalledWith(
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
mockUserId,
);
});
});
describe("clearPinKeyEncryptedUserKeyEphemeral()", () => {
it("should clear the pinKeyEncryptedUserKey of the specified userId", async () => {
await sut.clearPinKeyEncryptedUserKeyEphemeral(mockUserId);
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
null,
mockUserId,
);
});
});
describe("createPinKeyEncryptedUserKey()", () => {
it("should throw an error if a userKey is not provided", async () => {
await expect(
sut.createPinKeyEncryptedUserKey(mockPin, undefined, mockUserId),
).rejects.toThrow("No UserKey provided. Cannot create pinKeyEncryptedUserKey.");
});
it("should create a pinKeyEncryptedUserKey", async () => {
// Arrange
sut.makePinKey = jest.fn().mockResolvedValue(mockPinKey);
// Act
await sut.createPinKeyEncryptedUserKey(mockPin, mockUserKey, mockUserId);
// Assert
expect(encryptService.wrapSymmetricKey).toHaveBeenCalledWith(mockUserKey, mockPinKey);
});
});
describe("storePinKeyEncryptedUserKey", () => {
it("should store a pinKeyEncryptedUserKey (persistent version) when 'storeAsEphemeral' is false", async () => {
// Arrange
const storeAsEphemeral = false;
// Act
await sut.storePinKeyEncryptedUserKey(
pinKeyEncryptedUserKeyPersistant,
storeAsEphemeral,
mockUserId,
);
// Assert
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
pinKeyEncryptedUserKeyPersistant.encryptedString,
mockUserId,
);
});
it("should store a pinKeyEncryptedUserKeyEphemeral when 'storeAsEphemeral' is true", async () => {
// Arrange
const storeAsEphemeral = true;
// Act
await sut.storePinKeyEncryptedUserKey(
pinKeyEncryptedUserKeyEphemeral,
storeAsEphemeral,
mockUserId,
);
// Assert
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
pinKeyEncryptedUserKeyEphemeral.encryptedString,
mockUserId,
);
});
});
});
describe("userKeyEncryptedPin methods", () => {
describe("getUserKeyEncryptedPin()", () => {
it("should get the userKeyEncryptedPin of the specified userId", async () => {
await sut.getUserKeyEncryptedPin(mockUserId);
expect(stateProvider.mock.getUserState$).toHaveBeenCalledWith(
USER_KEY_ENCRYPTED_PIN,
mockUserId,
);
});
});
describe("setUserKeyEncryptedPin()", () => {
it("should set the userKeyEncryptedPin of the specified userId", async () => {
await sut.setUserKeyEncryptedPin(mockUserKeyEncryptedPin, mockUserId);
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
USER_KEY_ENCRYPTED_PIN,
mockUserKeyEncryptedPin.encryptedString,
mockUserId,
);
});
});
describe("clearUserKeyEncryptedPin()", () => {
it("should clear the pinKeyEncryptedUserKey of the specified userId", async () => {
await sut.clearUserKeyEncryptedPin(mockUserId);
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
USER_KEY_ENCRYPTED_PIN,
null,
mockUserId,
);
});
});
describe("createUserKeyEncryptedPin()", () => {
it("should throw an error if a userKey is not provided", async () => {
await expect(sut.createUserKeyEncryptedPin(mockPin, undefined)).rejects.toThrow(
"No UserKey provided. Cannot create userKeyEncryptedPin.",
);
});
it("should create a userKeyEncryptedPin from the provided PIN and userKey", async () => {
encryptService.encryptString.mockResolvedValue(mockUserKeyEncryptedPin);
const result = await sut.createUserKeyEncryptedPin(mockPin, mockUserKey);
expect(encryptService.encryptString).toHaveBeenCalledWith(mockPin, mockUserKey);
expect(result).toEqual(mockUserKeyEncryptedPin);
});
});
});
describe("makePinKey()", () => {
it("should make a PinKey", async () => {
// Arrange
keyGenerationService.deriveKeyFromPassword.mockResolvedValue(mockPinKey);
// Act
await sut.makePinKey(mockPin, mockUserEmail, DEFAULT_KDF_CONFIG);
// Assert
expect(keyGenerationService.deriveKeyFromPassword).toHaveBeenCalledWith(
mockPin,
mockUserEmail,
DEFAULT_KDF_CONFIG,
);
expect(keyGenerationService.stretchKey).toHaveBeenCalledWith(mockPinKey);
});
});
describe("getPinLockType()", () => {
it("should return 'PERSISTENT' if a pinKeyEncryptedUserKey (persistent version) is found", async () => {
// Arrange
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(null);
sut.getPinKeyEncryptedUserKeyPersistent = jest
.fn()
.mockResolvedValue(pinKeyEncryptedUserKeyPersistant);
// Act
const result = await sut.getPinLockType(mockUserId);
// Assert
expect(result).toBe("PERSISTENT");
});
it("should return 'EPHEMERAL' if a pinKeyEncryptedUserKey (persistent version) is not found but a userKeyEncryptedPin is found", async () => {
// Arrange
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(mockUserKeyEncryptedPin);
sut.getPinKeyEncryptedUserKeyPersistent = jest.fn().mockResolvedValue(null);
// Act
const result = await sut.getPinLockType(mockUserId);
// Assert
expect(result).toBe("EPHEMERAL");
});
it("should return 'DISABLED' if both of these are NOT found: userKeyEncryptedPin, pinKeyEncryptedUserKey (persistent version)", async () => {
// Arrange
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(null);
sut.getPinKeyEncryptedUserKeyPersistent = jest.fn().mockResolvedValue(null);
// Act
const result = await sut.getPinLockType(mockUserId);
// Assert
expect(result).toBe("DISABLED");
});
});
describe("isPinSet()", () => {
it.each(["PERSISTENT", "EPHEMERAL"])(
"should return true if the user PinLockType is '%s'",
async () => {
// Arrange
sut.getPinLockType = jest.fn().mockResolvedValue("PERSISTENT");
// Act
const result = await sut.isPinSet(mockUserId);
// Assert
expect(result).toEqual(true);
},
);
it("should return false if the user PinLockType is 'DISABLED'", async () => {
// Arrange
sut.getPinLockType = jest.fn().mockResolvedValue("DISABLED");
// Act
const result = await sut.isPinSet(mockUserId);
// Assert
expect(result).toEqual(false);
});
});
describe("isPinDecryptionAvailable()", () => {
it("should return false if pinLockType is DISABLED", async () => {
// Arrange
sut.getPinLockType = jest.fn().mockResolvedValue("DISABLED");
// Act
const result = await sut.isPinDecryptionAvailable(mockUserId);
// Assert
expect(result).toBe(false);
});
it("should return true if pinLockType is PERSISTENT", async () => {
// Arrange
sut.getPinLockType = jest.fn().mockResolvedValue("PERSISTENT");
// Act
const result = await sut.isPinDecryptionAvailable(mockUserId);
// Assert
expect(result).toBe(true);
});
it("should return true if pinLockType is EPHEMERAL and we have an ephemeral PIN key encrypted user key", async () => {
// Arrange
sut.getPinLockType = jest.fn().mockResolvedValue("EPHEMERAL");
sut.getPinKeyEncryptedUserKeyEphemeral = jest
.fn()
.mockResolvedValue(pinKeyEncryptedUserKeyEphemeral);
// Act
const result = await sut.isPinDecryptionAvailable(mockUserId);
// Assert
expect(result).toBe(true);
});
it("should return false if pinLockType is EPHEMERAL and we do not have an ephemeral PIN key encrypted user key", async () => {
// Arrange
sut.getPinLockType = jest.fn().mockResolvedValue("EPHEMERAL");
sut.getPinKeyEncryptedUserKeyEphemeral = jest.fn().mockResolvedValue(null);
// Act
const result = await sut.isPinDecryptionAvailable(mockUserId);
// Assert
expect(result).toBe(false);
});
it("should throw an error if an unexpected pinLockType is returned", async () => {
// Arrange
sut.getPinLockType = jest.fn().mockResolvedValue("UNKNOWN");
// Act & Assert
await expect(sut.isPinDecryptionAvailable(mockUserId)).rejects.toThrow(
"Unexpected pinLockType: UNKNOWN",
);
});
});
describe("decryptUserKeyWithPin()", () => {
async function setupDecryptUserKeyWithPinMocks(pinLockType: PinLockType) {
sut.getPinLockType = jest.fn().mockResolvedValue(pinLockType);
mockPinEncryptedKeyDataByPinLockType(pinLockType);
kdfConfigService.getKdfConfig.mockResolvedValue(DEFAULT_KDF_CONFIG);
mockDecryptUserKeyFn();
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(mockUserKeyEncryptedPin);
encryptService.decryptString.mockResolvedValue(mockPin);
cryptoFunctionService.compareFast.calledWith(mockPin, "1234").mockResolvedValue(true);
}
function mockDecryptUserKeyFn() {
sut.getPinKeyEncryptedUserKeyPersistent = jest
.fn()
.mockResolvedValue(pinKeyEncryptedUserKeyPersistant);
sut.makePinKey = jest.fn().mockResolvedValue(mockPinKey);
encryptService.unwrapSymmetricKey.mockResolvedValue(mockUserKey);
}
function mockPinEncryptedKeyDataByPinLockType(pinLockType: PinLockType) {
switch (pinLockType) {
case "PERSISTENT":
sut.getPinKeyEncryptedUserKeyPersistent = jest
.fn()
.mockResolvedValue(pinKeyEncryptedUserKeyPersistant);
break;
case "EPHEMERAL":
sut.getPinKeyEncryptedUserKeyEphemeral = jest
.fn()
.mockResolvedValue(pinKeyEncryptedUserKeyEphemeral);
break;
case "DISABLED":
// no mocking required. Error should be thrown
break;
}
}
const testCases: { pinLockType: PinLockType }[] = [
{ pinLockType: "PERSISTENT" },
{ pinLockType: "EPHEMERAL" },
];
testCases.forEach(({ pinLockType }) => {
describe(`given a ${pinLockType} PIN)`, () => {
it(`should successfully decrypt and return user key when using a valid PIN`, async () => {
// Arrange
await setupDecryptUserKeyWithPinMocks(pinLockType);
// Act
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
// Assert
expect(result).toEqual(mockUserKey);
});
it(`should return null when PIN is incorrect and user key cannot be decrypted`, async () => {
// Arrange
await setupDecryptUserKeyWithPinMocks(pinLockType);
sut.decryptUserKeyWithPin = jest.fn().mockResolvedValue(null);
// Act
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
// Assert
expect(result).toBeNull();
});
// not sure if this is a realistic scenario but going to test it anyway
it(`should return null when PIN doesn't match after successful user key decryption`, async () => {
// Arrange
await setupDecryptUserKeyWithPinMocks(pinLockType);
encryptService.decryptString.mockResolvedValue("9999"); // non matching PIN
// Act
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
// Assert
expect(result).toBeNull();
});
});
});
it(`should return null when pin is disabled`, async () => {
// Arrange
await setupDecryptUserKeyWithPinMocks("DISABLED");
// Act
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
// Assert
expect(result).toBeNull();
});
});
});
// Test helpers
function randomBytes(length: number): Uint8Array {
return new Uint8Array(Array.from({ length }, (_, k) => k % 255));
}

View File

@@ -2,9 +2,6 @@
// @ts-strict-ignore
import { firstValueFrom, map, timeout } from "rxjs";
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
// eslint-disable-next-line no-restricted-imports
import { PinServiceAbstraction } from "@bitwarden/auth/common";
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
// eslint-disable-next-line no-restricted-imports
import { BiometricStateService } from "@bitwarden/key-management";
@@ -20,6 +17,7 @@ import { LogService } from "../../platform/abstractions/log.service";
import { MessagingService } from "../../platform/abstractions/messaging.service";
import { UserId } from "../../types/guid";
import { ProcessReloadServiceAbstraction } from "../abstractions/process-reload.service";
import { PinServiceAbstraction } from "../pin/pin.service.abstraction";
export class DefaultProcessReloadService implements ProcessReloadServiceAbstraction {
private reloadInterval: any = null;

View File

@@ -6,7 +6,6 @@ import { BehaviorSubject, firstValueFrom, map, of } from "rxjs";
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
// eslint-disable-next-line no-restricted-imports
import {
PinServiceAbstraction,
FakeUserDecryptionOptions as UserDecryptionOptions,
UserDecryptionOptionsServiceAbstraction,
} from "@bitwarden/auth/common";
@@ -21,6 +20,7 @@ import { TokenService } from "../../../auth/services/token.service";
import { LogService } from "../../../platform/abstractions/log.service";
import { Utils } from "../../../platform/misc/utils";
import { UserId } from "../../../types/guid";
import { PinServiceAbstraction } from "../../pin/pin.service.abstraction";
import { VaultTimeoutSettingsService as VaultTimeoutSettingsServiceAbstraction } from "../abstractions/vault-timeout-settings.service";
import { VaultTimeoutAction } from "../enums/vault-timeout-action.enum";
import { VaultTimeout, VaultTimeoutStringType } from "../types/vault-timeout.type";

View File

@@ -16,10 +16,7 @@ import {
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
// eslint-disable-next-line no-restricted-imports
import {
PinServiceAbstraction,
UserDecryptionOptionsServiceAbstraction,
} from "@bitwarden/auth/common";
import { UserDecryptionOptionsServiceAbstraction } from "@bitwarden/auth/common";
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
// eslint-disable-next-line no-restricted-imports
import { BiometricStateService, KeyService } from "@bitwarden/key-management";
@@ -33,6 +30,7 @@ import { TokenService } from "../../../auth/abstractions/token.service";
import { LogService } from "../../../platform/abstractions/log.service";
import { StateProvider } from "../../../platform/state";
import { UserId } from "../../../types/guid";
import { PinServiceAbstraction } from "../../pin/pin.service.abstraction";
import { VaultTimeoutSettingsService as VaultTimeoutSettingsServiceAbstraction } from "../abstractions/vault-timeout-settings.service";
import { VaultTimeoutAction } from "../enums/vault-timeout-action.enum";
import { VaultTimeout, VaultTimeoutStringType } from "../types/vault-timeout.type";

View File

@@ -143,10 +143,6 @@ export class VaultTimeoutService implements VaultTimeoutServiceAbstraction {
),
);
if (userId == null || userId === currentUserId) {
await this.collectionService.clearActiveUserCache();
}
await this.searchService.clearIndex(lockingUserId);
await this.folderService.clearDecryptedFolderState(lockingUserId);

View File

@@ -13,9 +13,9 @@ export const getById = <TId, T extends { id: TId }>(id: TId) =>
* @param id The IDs of the objects to return.
* @returns An array containing objects with matching IDs, or an empty array if there are no matching objects.
*/
export const getByIds = <TId, T extends { id: TId }>(ids: TId[]) => {
const idSet = new Set(ids);
export const getByIds = <TId, T extends { id: TId | undefined }>(ids: TId[]) => {
const idSet = new Set(ids.filter((id) => id != null));
return map<T[], T[]>((objects) => {
return objects.filter((o) => idSet.has(o.id));
return objects.filter((o) => o.id && idSet.has(o.id));
});
};

View File

@@ -252,6 +252,7 @@ export class Utils {
}
// ref: http://stackoverflow.com/a/2117523/1090359
/** @deprecated Use newGuid from @bitwarden/guid instead */
static newGuid(): string {
return "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".replace(/[xy]/g, (c) => {
const r = (Math.random() * 16) | 0;
@@ -260,8 +261,10 @@ export class Utils {
});
}
/** @deprecated Use guidRegex from @bitwarden/guid instead */
static guidRegex = /^[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12}$/;
/** @deprecated Use isGuid from @bitwarden/guid instead */
static isGuid(id: string) {
return RegExp(Utils.guidRegex, "i").test(id);
}

View File

@@ -1,131 +0,0 @@
import { mock, MockProxy } from "jest-mock-extended";
import { makeEncString, makeSymmetricCryptoKey } from "../../../../spec";
import { EncryptService } from "../../../key-management/crypto/abstractions/encrypt.service";
import { EncString } from "../../../key-management/crypto/models/enc-string";
import Domain from "./domain-base";
class TestDomain extends Domain {
plainText: string;
encToString: EncString;
encString2: EncString;
}
describe("DomainBase", () => {
let encryptService: MockProxy<EncryptService>;
const key = makeSymmetricCryptoKey(64);
beforeEach(() => {
encryptService = mock<EncryptService>();
});
function setUpCryptography() {
encryptService.encryptString.mockImplementation((value) =>
Promise.resolve(makeEncString(value)),
);
encryptService.decryptString.mockImplementation((value) => {
return Promise.resolve(value.data);
});
}
describe("decryptWithKey", () => {
it("domain property types are decryptable", async () => {
const domain = new TestDomain();
await domain["decryptObjWithKey"](
// @ts-expect-error -- clear is not of type EncString
["plainText"],
makeSymmetricCryptoKey(64),
mock<EncryptService>(),
);
await domain["decryptObjWithKey"](
// @ts-expect-error -- Clear is not of type EncString
["encToString", "encString2", "plainText"],
makeSymmetricCryptoKey(64),
mock<EncryptService>(),
);
const decrypted = await domain["decryptObjWithKey"](
["encToString"],
makeSymmetricCryptoKey(64),
mock<EncryptService>(),
);
// @ts-expect-error -- encString2 was not decrypted
// FIXME: Remove when updating file. Eslint update
// eslint-disable-next-line @typescript-eslint/no-unused-expressions
decrypted as { encToString: string; encString2: string; plainText: string };
// encString2 was not decrypted, so it's still an EncString
// FIXME: Remove when updating file. Eslint update
// eslint-disable-next-line @typescript-eslint/no-unused-expressions
decrypted as { encToString: string; encString2: EncString; plainText: string };
});
it("decrypts the encrypted properties", async () => {
setUpCryptography();
const domain = new TestDomain();
domain.encToString = await encryptService.encryptString("string", key);
const decrypted = await domain["decryptObjWithKey"](["encToString"], key, encryptService);
expect(decrypted).toEqual({
encToString: "string",
});
});
it("decrypts multiple encrypted properties", async () => {
setUpCryptography();
const domain = new TestDomain();
domain.encToString = await encryptService.encryptString("string", key);
domain.encString2 = await encryptService.encryptString("string2", key);
const decrypted = await domain["decryptObjWithKey"](
["encToString", "encString2"],
key,
encryptService,
);
expect(decrypted).toEqual({
encToString: "string",
encString2: "string2",
});
});
it("does not decrypt properties that are not encrypted", async () => {
const domain = new TestDomain();
domain.plainText = "clear";
const decrypted = await domain["decryptObjWithKey"]([], key, encryptService);
expect(decrypted).toEqual({
plainText: "clear",
});
});
it("does not decrypt properties that were not requested to be decrypted", async () => {
setUpCryptography();
const domain = new TestDomain();
domain.plainText = "clear";
domain.encToString = makeEncString("string");
domain.encString2 = makeEncString("string2");
const decrypted = await domain["decryptObjWithKey"]([], key, encryptService);
expect(decrypted).toEqual({
plainText: "clear",
encToString: makeEncString("string"),
encString2: makeEncString("string2"),
});
});
});
});

View File

@@ -1,6 +1,5 @@
import { ConditionalExcept, ConditionalKeys, Constructor } from "type-fest";
import { ConditionalExcept, ConditionalKeys } from "type-fest";
import { EncryptService } from "../../../key-management/crypto/abstractions/encrypt.service";
import { EncString } from "../../../key-management/crypto/models/enc-string";
import { View } from "../../../models/view/view";
@@ -14,7 +13,7 @@ export type DecryptedObject<
> = Record<TDecryptedKeys, string> & Omit<TEncryptedObject, TDecryptedKeys>;
// extracts shared keys from the domain and view types
type EncryptableKeys<D extends Domain, V extends View> = (keyof D &
export type EncryptableKeys<D extends Domain, V extends View> = (keyof D &
ConditionalKeys<D, EncString | null>) &
(keyof V & ConditionalKeys<V, string | null>);
@@ -89,66 +88,4 @@ export default class Domain {
return viewModel as V;
}
/**
* Decrypts the requested properties of the domain object with the provided key and encrypt service.
*
* If a property is null, the result will be null.
* @see {@link EncString.decryptWithKey} for more details on decryption behavior.
*
* @param encryptedProperties The properties to decrypt. Type restricted to EncString properties of the domain object.
* @param key The key to use for decryption.
* @param encryptService The encryption service to use for decryption.
* @param _ The constructor of the domain object. Used for type inference if the domain object is not automatically inferred.
* @returns An object with the requested properties decrypted and the rest of the domain object untouched.
*/
protected async decryptObjWithKey<
TThis extends Domain,
const TEncryptedKeys extends EncStringKeys<TThis>,
>(
this: TThis,
encryptedProperties: TEncryptedKeys[],
key: SymmetricCryptoKey,
encryptService: EncryptService,
_: Constructor<TThis> = this.constructor as Constructor<TThis>,
objectContext: string = "No Domain Context",
): Promise<DecryptedObject<TThis, TEncryptedKeys>> {
const decryptedObjects = [];
for (const prop of encryptedProperties) {
const value = this[prop] as EncString;
const decrypted = await this.decryptProperty(
prop,
value,
key,
encryptService,
`Property: ${prop.toString()}; ObjectContext: ${objectContext}`,
);
decryptedObjects.push(decrypted);
}
const decryptedObject = decryptedObjects.reduce(
(acc, obj) => {
return { ...acc, ...obj };
},
{ ...this },
);
return decryptedObject as DecryptedObject<TThis, TEncryptedKeys>;
}
private async decryptProperty<const TEncryptedKeys extends EncStringKeys<this>>(
propertyKey: TEncryptedKeys,
value: EncString,
key: SymmetricCryptoKey,
encryptService: EncryptService,
decryptTrace: string,
) {
let decrypted: string | null = null;
if (value) {
decrypted = await value.decryptWithKey(key, encryptService, decryptTrace);
}
return {
[propertyKey]: decrypted,
};
}
}

View File

@@ -1,8 +1,9 @@
import { mock } from "jest-mock-extended";
import { MigrationHelper } from "@bitwarden/state";
import { FakeStorageService } from "../../../spec/fake-storage.service";
import { ClientType } from "../../enums";
import { MigrationHelper } from "../../state-migrations/migration-helper";
import { MigrationBuilderService } from "./migration-builder.service";

View File

@@ -1,7 +1,7 @@
import { CURRENT_VERSION, currentVersion, MigrationHelper } from "@bitwarden/state";
import { ClientType } from "../../enums";
import { waitForMigrations } from "../../state-migrations";
import { CURRENT_VERSION, currentVersion } from "../../state-migrations/migrate";
import { MigrationHelper } from "../../state-migrations/migration-helper";
import { LogService } from "../abstractions/log.service";
import { AbstractStorageService } from "../abstractions/storage.service";

View File

@@ -22,6 +22,7 @@ import {
ClientSettings,
DeviceType as SdkDeviceType,
TokenProvider,
UnsignedSharedKey,
} from "@bitwarden/sdk-internal";
import { EncryptedOrganizationKeyData } from "../../../admin-console/models/data/encrypted-organization-key.data";
@@ -237,7 +238,7 @@ export class DefaultSdkService implements SdkService {
organizationKeys: new Map(
Object.entries(orgKeys ?? {})
.filter(([_, v]) => v.type === "organization")
.map(([k, v]) => [k, v.key]),
.map(([k, v]) => [k, v.key as UnsignedSharedKey]),
),
});
}

View File

@@ -0,0 +1,11 @@
import { Observable } from "rxjs";
import { UserId } from "@bitwarden/user-core";
export abstract class ActiveUserAccessor {
/**
* Returns a stream of the current active user for the application. The stream either emits the user id for that account
* or returns null if there is no current active user.
*/
abstract activeUserId$: Observable<UserId | null>;
}

View File

@@ -1,42 +0,0 @@
import { DeriveDefinition } from "./derive-definition";
import { KeyDefinition } from "./key-definition";
import { StateDefinition } from "./state-definition";
const derive: () => any = () => null;
const deserializer: any = (obj: any) => obj;
const STATE_DEFINITION = new StateDefinition("test", "disk");
const TEST_KEY = new KeyDefinition(STATE_DEFINITION, "test", {
deserializer,
});
const TEST_DERIVE = new DeriveDefinition(STATE_DEFINITION, "test", {
derive,
deserializer,
});
describe("DeriveDefinition", () => {
describe("from", () => {
it("should create a new DeriveDefinition from a KeyDefinition", () => {
const result = DeriveDefinition.from(TEST_KEY, {
derive,
deserializer,
});
expect(result).toEqual(TEST_DERIVE);
});
it("should create a new DeriveDefinition from a DeriveDefinition", () => {
const result = DeriveDefinition.from([TEST_DERIVE, "newDerive"], {
derive,
deserializer,
});
expect(result).toEqual(
new DeriveDefinition(STATE_DEFINITION, "newDerive", {
derive,
deserializer,
}),
);
});
});
});

View File

@@ -1,196 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Jsonify } from "type-fest";
import { UserId } from "../../types/guid";
import { DerivedStateDependencies, StorageKey } from "../../types/state";
import { KeyDefinition } from "./key-definition";
import { StateDefinition } from "./state-definition";
import { UserKeyDefinition } from "./user-key-definition";
declare const depShapeMarker: unique symbol;
/**
* A set of options for customizing the behavior of a {@link DeriveDefinition}
*/
type DeriveDefinitionOptions<TFrom, TTo, TDeps extends DerivedStateDependencies = never> = {
/**
* A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable
* and the resulting value will be emitted from the derived state observable.
*
* @param from Populated with the latest emission from the parent state observable.
* @param deps Populated with the dependencies passed into the constructor of the derived state.
* These are constant for the lifetime of the derived state.
* @returns The derived state value or a Promise that resolves to the derived state value.
*/
derive: (from: TFrom, deps: TDeps) => TTo | Promise<TTo>;
/**
* A function to use to safely convert your type from json to your expected type.
*
* **Important:** Your data may be serialized/deserialized at any time and this
* callback needs to be able to faithfully re-initialize from the JSON object representation of your type.
*
* @param jsonValue The JSON object representation of your state.
* @returns The fully typed version of your state.
*/
deserializer: (serialized: Jsonify<TTo>) => TTo;
/**
* An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies
* and the values are the types of the dependencies.
*
* for example:
* ```
* {
* myService: MyService,
* myOtherService: MyOtherService,
* }
* ```
*/
[depShapeMarker]?: TDeps;
/**
* The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
* Defaults to 1000ms.
*/
cleanupDelayMs?: number;
/**
* Whether or not to clear the derived state when cleanup occurs. Defaults to true.
*/
clearOnCleanup?: boolean;
};
/**
* DeriveDefinitions describe state derived from another observable, the value type of which is given by `TFrom`.
*
* The StateDefinition is used to describe the domain of the state, and the DeriveDefinition
* sub-divides that domain into specific keys. These keys are used to cache data in memory and enables derived state to
* be calculated once regardless of multiple execution contexts.
*/
export class DeriveDefinition<TFrom, TTo, TDeps extends DerivedStateDependencies> {
/**
* Creates a new instance of a DeriveDefinition. Derived state is always stored in memory, so the storage location
* defined in @link{StateDefinition} is ignored.
*
* @param stateDefinition The state definition for which this key belongs to.
* @param uniqueDerivationName The name of the key, this should be unique per domain.
* @param options A set of options to customize the behavior of {@link DeriveDefinition}.
* @param options.derive A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable
* and the resulting value will be emitted from the derived state observable.
* @param options.cleanupDelayMs The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
* Defaults to 1000ms.
* @param options.dependencyShape An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies
* and the values are the types of the dependencies.
* for example:
* ```
* {
* myService: MyService,
* myOtherService: MyOtherService,
* }
* ```
*
* @param options.deserializer A function to use to safely convert your type from json to your expected type.
* Your data may be serialized/deserialized at any time and this needs callback needs to be able to faithfully re-initialize
* from the JSON object representation of your type.
*/
constructor(
readonly stateDefinition: StateDefinition,
readonly uniqueDerivationName: string,
readonly options: DeriveDefinitionOptions<TFrom, TTo, TDeps>,
) {}
/**
* Factory that produces a {@link DeriveDefinition} from a {@link KeyDefinition} or {@link DeriveDefinition} and new name.
*
* If a `KeyDefinition` is passed in, the returned definition will have the same key as the given key definition, but
* will not collide with it in storage, even if they both reside in memory.
*
* If a `DeriveDefinition` is passed in, the returned definition will instead use the name given in the second position
* of the tuple. It is up to you to ensure this is unique within the domain of derived state.
*
* @param options A set of options to customize the behavior of {@link DeriveDefinition}.
* @param options.derive A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable
* and the resulting value will be emitted from the derived state observable.
* @param options.cleanupDelayMs The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
* Defaults to 1000ms.
* @param options.dependencyShape An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies
* and the values are the types of the dependencies.
* for example:
* ```
* {
* myService: MyService,
* myOtherService: MyOtherService,
* }
* ```
*
* @param options.deserializer A function to use to safely convert your type from json to your expected type.
* Your data may be serialized/deserialized at any time and this needs callback needs to be able to faithfully re-initialize
* from the JSON object representation of your type.
* @param definition
* @param options
* @returns
*/
static from<TFrom, TTo, TDeps extends DerivedStateDependencies = never>(
definition:
| KeyDefinition<TFrom>
| UserKeyDefinition<TFrom>
| [DeriveDefinition<unknown, TFrom, DerivedStateDependencies>, string],
options: DeriveDefinitionOptions<TFrom, TTo, TDeps>,
) {
if (isFromDeriveDefinition(definition)) {
return new DeriveDefinition(definition[0].stateDefinition, definition[1], options);
} else {
return new DeriveDefinition(definition.stateDefinition, definition.key, options);
}
}
static fromWithUserId<TKeyDef, TTo, TDeps extends DerivedStateDependencies = never>(
definition:
| KeyDefinition<TKeyDef>
| UserKeyDefinition<TKeyDef>
| [DeriveDefinition<unknown, TKeyDef, DerivedStateDependencies>, string],
options: DeriveDefinitionOptions<[UserId, TKeyDef], TTo, TDeps>,
) {
if (isFromDeriveDefinition(definition)) {
return new DeriveDefinition(definition[0].stateDefinition, definition[1], options);
} else {
return new DeriveDefinition(definition.stateDefinition, definition.key, options);
}
}
get derive() {
return this.options.derive;
}
deserialize(serialized: Jsonify<TTo>): TTo {
return this.options.deserializer(serialized);
}
get cleanupDelayMs() {
return this.options.cleanupDelayMs < 0 ? 0 : (this.options.cleanupDelayMs ?? 1000);
}
get clearOnCleanup() {
return this.options.clearOnCleanup ?? true;
}
buildCacheKey(): string {
return `derived_${this.stateDefinition.name}_${this.uniqueDerivationName}`;
}
/**
* Creates a {@link StorageKey} that points to the data for the given derived definition.
* @returns A key that is ready to be used in a storage service to get data.
*/
get storageKey(): StorageKey {
return `derived_${this.stateDefinition.name}_${this.uniqueDerivationName}` as StorageKey;
}
}
function isFromDeriveDefinition(
definition:
| KeyDefinition<unknown>
| UserKeyDefinition<unknown>
| [DeriveDefinition<unknown, unknown, DerivedStateDependencies>, string],
): definition is [DeriveDefinition<unknown, unknown, DerivedStateDependencies>, string] {
return Array.isArray(definition);
}
export { DeriveDefinition } from "@bitwarden/state";

View File

@@ -1,25 +1 @@
import { Observable } from "rxjs";
import { DerivedStateDependencies } from "../../types/state";
import { DeriveDefinition } from "./derive-definition";
import { DerivedState } from "./derived-state";
/**
* State derived from an observable and a derive function
*/
export abstract class DerivedStateProvider {
/**
* Creates a derived state observable from a parent state observable, a deriveDefinition, and the dependencies
* required by the deriveDefinition
* @param parentState$ The parent state observable
* @param deriveDefinition The deriveDefinition that defines conversion from the parent state to the derived state as
* well as some memory persistent information.
* @param dependencies The dependencies of the derive function
*/
abstract get<TFrom, TTo, TDeps extends DerivedStateDependencies>(
parentState$: Observable<TFrom>,
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
dependencies: TDeps,
): DerivedState<TTo>;
}
export { DerivedStateProvider } from "@bitwarden/state";

View File

@@ -1,23 +1 @@
import { Observable } from "rxjs";
export type StateConverter<TFrom extends Array<unknown>, TTo> = (...args: TFrom) => TTo;
/**
* State derived from an observable and a converter function
*
* Derived state is cached and persisted to memory for sychronization across execution contexts.
* For clients with multiple execution contexts, the derived state will be executed only once in the background process.
*/
export interface DerivedState<T> {
/**
* The derived state observable
*/
state$: Observable<T>;
/**
* Forces the derived state to a given value.
*
* Useful for setting an in-memory value as a side effect of some event, such as emptying state as a result of a lock.
* @param value The value to force the derived state to
*/
forceValue(value: T): Promise<T>;
}
export { DerivedState } from "@bitwarden/state";

View File

@@ -1,25 +0,0 @@
import { record } from "./deserialization-helpers";
describe("deserialization helpers", () => {
describe("record", () => {
it("deserializes a record when keys are strings", () => {
const deserializer = record((value: number) => value);
const input = {
a: 1,
b: 2,
};
const output = deserializer(input);
expect(output).toEqual(input);
});
it("deserializes a record when keys are numbers", () => {
const deserializer = record((value: number) => value);
const input = {
1: 1,
2: 2,
};
const output = deserializer(input);
expect(output).toEqual(input);
});
});
});

View File

@@ -1,40 +0,0 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Jsonify } from "type-fest";
/**
*
* @param elementDeserializer
* @returns
*/
export function array<T>(
elementDeserializer: (element: Jsonify<T>) => T,
): (array: Jsonify<T[]>) => T[] {
return (array) => {
if (array == null) {
return null;
}
return array.map((element) => elementDeserializer(element));
};
}
/**
*
* @param valueDeserializer
*/
export function record<T, TKey extends string | number = string>(
valueDeserializer: (value: Jsonify<T>) => T,
): (record: Jsonify<Record<TKey, T>>) => Record<TKey, T> {
return (jsonValue: Jsonify<Record<TKey, T> | null>) => {
if (jsonValue == null) {
return null;
}
const output: Record<TKey, T> = {} as any;
Object.entries(jsonValue).forEach(([key, value]) => {
output[key as TKey] = valueDeserializer(value);
});
return output;
};
}

View File

@@ -1,13 +1 @@
import { GlobalState } from "./global-state";
import { KeyDefinition } from "./key-definition";
/**
* A provider for getting an implementation of global state scoped to the given key.
*/
export abstract class GlobalStateProvider {
/**
* Gets a {@link GlobalState} scoped to the given {@link KeyDefinition}
* @param keyDefinition - The {@link KeyDefinition} for which you want the state for.
*/
abstract get<T>(keyDefinition: KeyDefinition<T>): GlobalState<T>;
}
export { GlobalStateProvider } from "@bitwarden/state";

View File

@@ -1,30 +1 @@
import { Observable } from "rxjs";
import { StateUpdateOptions } from "./state-update-options";
/**
* A helper object for interacting with state that is scoped to a specific domain
* but is not scoped to a user. This is application wide storage.
*/
export interface GlobalState<T> {
/**
* Method for allowing you to manipulate state in an additive way.
* @param configureState callback for how you want to manipulate this section of state
* @param options Defaults given by @see {module:state-update-options#DEFAULT_OPTIONS}
* @param options.shouldUpdate A callback for determining if you want to update state. Defaults to () => true
* @param options.combineLatestWith An observable that you want to combine with the current state for callbacks. Defaults to null
* @param options.msTimeout A timeout for how long you are willing to wait for a `combineLatestWith` option to complete. Defaults to 1000ms. Only applies if `combineLatestWith` is set.
* @returns A promise that must be awaited before your next action to ensure the update has been written to state.
* Resolves to the new state. If `shouldUpdate` returns false, the promise will resolve to the current state.
*/
update: <TCombine>(
configureState: (state: T | null, dependency: TCombine) => T | null,
options?: StateUpdateOptions<T, TCombine>,
) => Promise<T | null>;
/**
* An observable stream of this state, the first emission of this will be the current state on disk
* and subsequent updates will be from an update to that state.
*/
state$: Observable<T | null>;
}
export { GlobalState } from "@bitwarden/state";

View File

@@ -1,37 +0,0 @@
import { mock } from "jest-mock-extended";
import { mockAccountServiceWith, trackEmissions } from "../../../../spec";
import { UserId } from "../../../types/guid";
import { SingleUserStateProvider } from "../user-state.provider";
import { DefaultActiveUserStateProvider } from "./default-active-user-state.provider";
describe("DefaultActiveUserStateProvider", () => {
const singleUserStateProvider = mock<SingleUserStateProvider>();
const userId = "userId" as UserId;
const accountInfo = {
id: userId,
name: "name",
email: "email",
emailVerified: false,
};
const accountService = mockAccountServiceWith(userId, accountInfo);
let sut: DefaultActiveUserStateProvider;
beforeEach(() => {
sut = new DefaultActiveUserStateProvider(accountService, singleUserStateProvider);
});
afterEach(() => {
jest.resetAllMocks();
});
it("should track the active User id from account service", () => {
const emissions = trackEmissions(sut.activeUserId$);
accountService.activeAccountSubject.next(undefined);
accountService.activeAccountSubject.next(accountInfo);
expect(emissions).toEqual([userId, undefined, userId]);
});
});

View File

@@ -1,9 +1,9 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Observable, distinctUntilChanged, map } from "rxjs";
import { Observable, distinctUntilChanged } from "rxjs";
import { AccountService } from "../../../auth/abstractions/account.service";
import { UserId } from "../../../types/guid";
import { ActiveUserAccessor } from "../active-user.accessor";
import { UserKeyDefinition } from "../user-key-definition";
import { ActiveUserState } from "../user-state";
import { ActiveUserStateProvider, SingleUserStateProvider } from "../user-state.provider";
@@ -14,11 +14,10 @@ export class DefaultActiveUserStateProvider implements ActiveUserStateProvider {
activeUserId$: Observable<UserId | undefined>;
constructor(
private readonly accountService: AccountService,
private readonly activeAccountAccessor: ActiveUserAccessor,
private readonly singleUserStateProvider: SingleUserStateProvider,
) {
this.activeUserId$ = this.accountService.activeAccount$.pipe(
map((account) => account?.id),
this.activeUserId$ = this.activeAccountAccessor.activeUserId$.pipe(
// To avoid going to storage when we don't need to, only get updates when there is a true change.
distinctUntilChanged((a, b) => (a == null || b == null ? a == b : a === b)), // Treat null and undefined as equal
);

View File

@@ -1,772 +0,0 @@
/**
* need to update test environment so trackEmissions works appropriately
* @jest-environment ../shared/test.environment.ts
*/
import { any, mock } from "jest-mock-extended";
import { BehaviorSubject, firstValueFrom, map, of, timeout } from "rxjs";
import { Jsonify } from "type-fest";
import { StorageServiceProvider } from "@bitwarden/storage-core";
import { awaitAsync, trackEmissions } from "../../../../spec";
import { FakeStorageService } from "../../../../spec/fake-storage.service";
import { Account } from "../../../auth/abstractions/account.service";
import { UserId } from "../../../types/guid";
import { LogService } from "../../abstractions/log.service";
import { StateDefinition } from "../state-definition";
import { StateEventRegistrarService } from "../state-event-registrar.service";
import { UserKeyDefinition } from "../user-key-definition";
import { DefaultActiveUserState } from "./default-active-user-state";
import { DefaultSingleUserStateProvider } from "./default-single-user-state.provider";
class TestState {
date: Date;
array: string[];
static fromJSON(jsonState: Jsonify<TestState>) {
if (jsonState == null) {
return null;
}
return Object.assign(new TestState(), jsonState, {
date: new Date(jsonState.date),
});
}
}
const testStateDefinition = new StateDefinition("fake", "disk");
const cleanupDelayMs = 15;
const testKeyDefinition = new UserKeyDefinition<TestState>(testStateDefinition, "fake", {
deserializer: TestState.fromJSON,
cleanupDelayMs,
clearOn: [],
});
describe("DefaultActiveUserState", () => {
let diskStorageService: FakeStorageService;
const storageServiceProvider = mock<StorageServiceProvider>();
const stateEventRegistrarService = mock<StateEventRegistrarService>();
const logService = mock<LogService>();
let activeAccountSubject: BehaviorSubject<Account | null>;
let singleUserStateProvider: DefaultSingleUserStateProvider;
let userState: DefaultActiveUserState<TestState>;
beforeEach(() => {
diskStorageService = new FakeStorageService();
storageServiceProvider.get.mockReturnValue(["disk", diskStorageService]);
singleUserStateProvider = new DefaultSingleUserStateProvider(
storageServiceProvider,
stateEventRegistrarService,
logService,
);
activeAccountSubject = new BehaviorSubject<Account | null>(null);
userState = new DefaultActiveUserState(
testKeyDefinition,
activeAccountSubject.asObservable().pipe(map((a) => a?.id)),
singleUserStateProvider,
);
});
afterEach(() => {
jest.resetAllMocks();
});
const makeUserId = (id: string) => {
return id != null ? (`00000000-0000-1000-a000-00000000000${id}` as UserId) : undefined;
};
const changeActiveUser = async (id: string) => {
const userId = makeUserId(id);
activeAccountSubject.next({
id: userId,
email: `test${id}@example.com`,
emailVerified: false,
name: `Test User ${id}`,
});
await awaitAsync();
};
afterEach(() => {
jest.resetAllMocks();
});
it("emits updates for each user switch and update", async () => {
const user1 = "user_00000000-0000-1000-a000-000000000001_fake_fake";
const user2 = "user_00000000-0000-1000-a000-000000000002_fake_fake";
const state1 = {
date: new Date(2021, 0),
array: ["user1"],
};
const state2 = {
date: new Date(2022, 0),
array: ["user2"],
};
const initialState: Record<string, TestState> = {};
initialState[user1] = state1;
initialState[user2] = state2;
diskStorageService.internalUpdateStore(initialState);
const emissions = trackEmissions(userState.state$);
// User signs in
await changeActiveUser("1");
// Service does an update
const updatedState = {
date: new Date(2023, 0),
array: ["user1-update"],
};
await userState.update(() => updatedState);
await awaitAsync();
// Emulate an account switch
await changeActiveUser("2");
// #1 initial state from user1
// #2 updated state for user1
// #3 switched state to initial state for user2
expect(emissions).toEqual([state1, updatedState, state2]);
// Should be called 4 time to get state, update state for user, emitting update, and switching users
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(4);
// Initial subscribe to state$
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
1,
"user_00000000-0000-1000-a000-000000000001_fake_fake",
any(), // options
);
// The updating of state for user1
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
2,
"user_00000000-0000-1000-a000-000000000001_fake_fake",
any(), // options
);
// The emission from being actively subscribed to user1
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
3,
"user_00000000-0000-1000-a000-000000000001_fake_fake",
any(), // options
);
// Switch to user2
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
4,
"user_00000000-0000-1000-a000-000000000002_fake_fake",
any(), // options
);
// Should only have saved data for the first user
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
expect(diskStorageService.mock.save).toHaveBeenNthCalledWith(
1,
"user_00000000-0000-1000-a000-000000000001_fake_fake",
updatedState,
any(), // options
);
});
it("will not emit any value if there isn't an active user", async () => {
let resolvedValue: TestState | undefined = undefined;
let rejectedError: Error | undefined = undefined;
const promise = firstValueFrom(userState.state$.pipe(timeout(20)))
.then((value) => {
resolvedValue = value;
})
.catch((err) => {
rejectedError = err;
});
await promise;
expect(diskStorageService.mock.get).not.toHaveBeenCalled();
expect(resolvedValue).toBe(undefined);
expect(rejectedError).toBeTruthy();
expect(rejectedError.message).toBe("Timeout has occurred");
});
it("will emit value for a new active user after subscription started", async () => {
let resolvedValue: TestState | undefined = undefined;
let rejectedError: Error | undefined = undefined;
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
date: new Date(2020, 0),
array: ["testValue"],
} as TestState,
});
const promise = firstValueFrom(userState.state$.pipe(timeout(20)))
.then((value) => {
resolvedValue = value;
})
.catch((err) => {
rejectedError = err;
});
await changeActiveUser("1");
await promise;
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
expect(resolvedValue).toBeTruthy();
expect(resolvedValue.array).toHaveLength(1);
expect(resolvedValue.date.getFullYear()).toBe(2020);
expect(rejectedError).toBeFalsy();
});
it("should not emit a previous users value if that user is no longer active", async () => {
const user1Data: Jsonify<TestState> = {
date: "2020-09-21T13:14:17.648Z",
// NOTE: `as any` is here until we migrate to Nx: https://bitwarden.atlassian.net/browse/PM-6493
array: ["value"] as any,
};
const user2Data: Jsonify<TestState> = {
date: "2020-09-21T13:14:17.648Z",
array: [],
};
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": user1Data,
"user_00000000-0000-1000-a000-000000000002_fake_fake": user2Data,
});
// This starts one subscription on the observable for tracking emissions throughout
// the whole test.
const emissions = trackEmissions(userState.state$);
// Change to a user with data
await changeActiveUser("1");
// This should always return a value right await
const value = await firstValueFrom(
userState.state$.pipe(
timeout({
first: 20,
with: () => {
throw new Error("Did not emit data from newly active user.");
},
}),
),
);
expect(value).toEqual(user1Data);
// Make it such that there is no active user
await changeActiveUser(undefined);
let resolvedValue: TestState | undefined = undefined;
let rejectedError: Error | undefined = undefined;
// Even if the observable has previously emitted a value it shouldn't have
// a value for the user subscribing to it because there isn't an active user
// to get data for.
await firstValueFrom(userState.state$.pipe(timeout(20)))
.then((value) => {
resolvedValue = value;
})
.catch((err) => {
rejectedError = err;
});
expect(resolvedValue).toBeUndefined();
expect(rejectedError).not.toBeUndefined();
expect(rejectedError.message).toBe("Timeout has occurred");
// We need to figure out if something should be emitted
// when there becomes no active user, if we don't want that to emit
// this value is correct.
expect(emissions).toEqual([user1Data]);
});
it("should not emit twice if there are two listeners", async () => {
await changeActiveUser("1");
const emissions = trackEmissions(userState.state$);
const emissions2 = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
]);
expect(emissions2).toEqual([
null, // Initial value
]);
});
describe("update", () => {
const newData = { date: new Date(), array: ["test"] };
beforeEach(async () => {
await changeActiveUser("1");
});
it("should save on update", async () => {
const [setUserId, result] = await userState.update((state, dependencies) => {
return newData;
});
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
expect(result).toEqual(newData);
expect(setUserId).toEqual("00000000-0000-1000-a000-000000000001");
});
it("should emit once per update", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // Need to await for the initial value to be emitted
await userState.update((state, dependencies) => {
return newData;
});
await awaitAsync();
expect(emissions).toEqual([
null, // initial value
newData,
]);
});
it("should provide combined dependencies", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // Need to await for the initial value to be emitted
const combinedDependencies = { date: new Date() };
await userState.update(
(state, dependencies) => {
expect(dependencies).toEqual(combinedDependencies);
return newData;
},
{
combineLatestWith: of(combinedDependencies),
},
);
await awaitAsync();
expect(emissions).toEqual([
null, // initial value
newData,
]);
});
it("should not update if shouldUpdate returns false", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // Need to await for the initial value to be emitted
const [userIdResult, result] = await userState.update(
(state, dependencies) => {
return newData;
},
{
shouldUpdate: () => false,
},
);
await awaitAsync();
expect(diskStorageService.mock.save).not.toHaveBeenCalled();
expect(userIdResult).toEqual("00000000-0000-1000-a000-000000000001");
expect(result).toBeNull();
expect(emissions).toEqual([null]);
});
it("should provide the current state to the update callback", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // Need to await for the initial value to be emitted
// Seed with interesting data
const initialData = { date: new Date(2020, 0), array: ["value1", "value2"] };
await userState.update((state, dependencies) => {
return initialData;
});
await awaitAsync();
await userState.update((state, dependencies) => {
expect(state).toEqual(initialData);
return newData;
});
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
initialData,
newData,
]);
});
it("should throw on an attempted update when there is no active user", async () => {
await changeActiveUser(undefined);
await expect(async () => await userState.update(() => null)).rejects.toThrow(
"No active user at this time.",
);
});
it("should throw on an attempted update where there is no active user even if there used to be one", async () => {
// Arrange
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
date: new Date(2019, 1),
array: [],
},
});
const [userId, state] = await firstValueFrom(userState.combinedState$);
expect(userId).toBe("00000000-0000-1000-a000-000000000001");
expect(state.date.getUTCFullYear()).toBe(2019);
await changeActiveUser(undefined);
// Act
await expect(async () => await userState.update(() => null)).rejects.toThrow(
"No active user at this time.",
);
});
it.each([null, undefined])(
"should register user key definition when state transitions from null-ish (%s) to non-null",
async (startingValue: TestState | null) => {
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": startingValue,
});
await userState.update(() => ({ array: ["one"], date: new Date() }));
expect(stateEventRegistrarService.registerEvents).toHaveBeenCalledWith(testKeyDefinition);
},
);
it("should not register user key definition when state has preexisting value", async () => {
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
date: new Date(2019, 1),
array: [],
},
});
await userState.update(() => ({ array: ["one"], date: new Date() }));
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
});
it.each([null, undefined])(
"should not register user key definition when setting value to null-ish (%s) value",
async (updatedValue: TestState | null) => {
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
date: new Date(2019, 1),
array: [],
},
});
await userState.update(() => updatedValue);
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
},
);
});
describe("update races", () => {
const newData = { date: new Date(), array: ["test"] };
const userId = makeUserId("1");
beforeEach(async () => {
await changeActiveUser("1");
await awaitAsync();
});
test("subscriptions during an update should receive the current and latest", async () => {
const oldData = { date: new Date(2019, 1, 1), array: ["oldValue1"] };
await userState.update(() => {
return oldData;
});
const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
emissions2 = trackEmissions(userState.state$);
await originalSave(key, obj);
});
const [userIdResult, val] = await userState.update(() => {
return newData;
});
await awaitAsync(10);
expect(userIdResult).toEqual(userId);
expect(val).toEqual(newData);
expect(emissions).toEqual([initialData, newData]);
expect(emissions2).toEqual([initialData, newData]);
});
test("subscription during an aborted update should receive the last value", async () => {
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const [userIdResult, val] = await userState.update(
(state) => {
return newData;
},
{
shouldUpdate: () => {
emissions2 = trackEmissions(userState.state$);
return false;
},
},
);
await awaitAsync();
expect(userIdResult).toEqual(userId);
expect(val).toEqual(initialData);
expect(emissions).toEqual([initialData]);
expect(emissions2).toEqual([initialData]);
});
test("updates should wait until previous update is complete", async () => {
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async (key: string, obj: any) => {
let resolved = false;
await Promise.race([
userState.update(() => {
// deadlocks
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(false);
})
.mockImplementation((...args) => {
return originalSave(...args);
});
await userState.update(() => {
return newData;
});
});
test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => {
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
const [userIdResult, val] = await userState.update((state) => {
return newData;
});
expect(userIdResult).toEqual(userId);
expect(val).toEqual(newData);
const call = diskStorageService.mock.save.mock.calls[0];
expect(call[0]).toEqual(`user_${userId}_fake_fake`);
expect(call[1]).toEqual(newData);
});
it("does not await updates if the active user changes", async () => {
const initialUserId = (await firstValueFrom(activeAccountSubject)).id;
expect(initialUserId).toBe(userId);
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async (key: string, obj: any) => {
let resolved = false;
await changeActiveUser("2");
await Promise.race([
userState.update(() => {
// should not deadlock because we updated the user
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(true);
})
.mockImplementation((...args) => {
return originalSave(...args);
});
await userState.update(() => {
return newData;
});
});
it("stores updates for users in the correct place when active user changes mid-update", async () => {
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const user2Data = { date: new Date(), array: ["user 2 data"] };
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async (key: string, obj: any) => {
let resolved = false;
await changeActiveUser("2");
await Promise.race([
userState.update(() => {
// should not deadlock because we updated the user
resolved = true;
return user2Data;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(true);
await originalSave(key, obj);
})
.mockImplementation((...args) => {
return originalSave(...args);
});
await userState.update(() => {
return newData;
});
await awaitAsync();
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(2);
const innerCall = diskStorageService.mock.save.mock.calls[0];
expect(innerCall[0]).toEqual(`user_${makeUserId("2")}_fake_fake`);
expect(innerCall[1]).toEqual(user2Data);
const outerCall = diskStorageService.mock.save.mock.calls[1];
expect(outerCall[0]).toEqual(`user_${makeUserId("1")}_fake_fake`);
expect(outerCall[1]).toEqual(newData);
});
});
describe("cleanup", () => {
const newData = { date: new Date(), array: ["test"] };
const userId = makeUserId("1");
let userKey: string;
beforeEach(async () => {
await changeActiveUser("1");
userKey = testKeyDefinition.buildKey(userId);
});
function assertClean() {
expect(activeAccountSubject["observers"]).toHaveLength(0);
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
}
it("should cleanup after last subscriber", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync(); // storage updates are behind a promise
subscription.unsubscribe();
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
assertClean();
});
it("should not cleanup if there are still subscribers", async () => {
const subscription1 = userState.state$.subscribe();
const sub2Emissions: TestState[] = [];
const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v));
await awaitAsync(); // storage updates are behind a promise
subscription1.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
// Still be listening to storage updates
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
diskStorageService.save(userKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(sub2Emissions).toEqual([null, newData]);
subscription2.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
assertClean();
});
it("can re-initialize after cleanup", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
const emissions = trackEmissions(userState.state$);
await awaitAsync();
await diskStorageService.save(userKey, newData);
await awaitAsync();
expect(emissions).toEqual([null, newData]);
});
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Do not wait long enough for cleanup
await awaitAsync(cleanupDelayMs / 2);
const state = await firstValueFrom(userState.state$);
expect(state).toEqual(newData); // digging in to check that it hasn't been cleared
// Should be called once for the initial subscription and once from the save
// but should NOT be called for the second subscription from the `firstValueFrom`
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(2);
});
it("state$ observables are durable to cleanup", async () => {
const observable = userState.state$;
let subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
expect(await firstValueFrom(observable)).toEqual(newData);
});
});
});

View File

@@ -1,64 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Observable, map, switchMap, firstValueFrom, timeout, throwError, NEVER } from "rxjs";
import { UserId } from "../../../types/guid";
import { StateUpdateOptions } from "../state-update-options";
import { UserKeyDefinition } from "../user-key-definition";
import { ActiveUserState, CombinedState, activeMarker } from "../user-state";
import { SingleUserStateProvider } from "../user-state.provider";
export class DefaultActiveUserState<T> implements ActiveUserState<T> {
[activeMarker]: true;
combinedState$: Observable<CombinedState<T>>;
state$: Observable<T>;
constructor(
protected keyDefinition: UserKeyDefinition<T>,
private activeUserId$: Observable<UserId | null>,
private singleUserStateProvider: SingleUserStateProvider,
) {
this.combinedState$ = this.activeUserId$.pipe(
switchMap((userId) =>
userId != null
? this.singleUserStateProvider.get(userId, this.keyDefinition).combinedState$
: NEVER,
),
);
// State should just be combined state without the user id
this.state$ = this.combinedState$.pipe(map(([_userId, state]) => state));
}
async update<TCombine>(
configureState: (state: T, dependency: TCombine) => T,
options: StateUpdateOptions<T, TCombine> = {},
): Promise<[UserId, T]> {
const userId = await firstValueFrom(
this.activeUserId$.pipe(
timeout({
first: 1000,
with: () =>
throwError(
() =>
new Error(
`Timeout while retrieving active user for key ${this.keyDefinition.fullName}.`,
),
),
}),
),
);
if (userId == null) {
throw new Error(
`Error storing ${this.keyDefinition.fullName} for the active user: No active user at this time.`,
);
}
return [
userId,
await this.singleUserStateProvider
.get(userId, this.keyDefinition)
.update(configureState, options),
];
}
}
export { DefaultActiveUserState } from "@bitwarden/state";

View File

@@ -1,53 +1 @@
import { Observable } from "rxjs";
import { DerivedStateDependencies } from "../../../types/state";
import { DeriveDefinition } from "../derive-definition";
import { DerivedState } from "../derived-state";
import { DerivedStateProvider } from "../derived-state.provider";
import { DefaultDerivedState } from "./default-derived-state";
export class DefaultDerivedStateProvider implements DerivedStateProvider {
/**
* The cache uses a WeakMap to maintain separate derived states per user.
* Each user's state Observable acts as a unique key, without needing to
* pass around `userId`. Also, when a user's state Observable is cleaned up
* (like during an account swap) their cache is automatically garbage
* collected.
*/
private cache = new WeakMap<Observable<unknown>, Record<string, DerivedState<unknown>>>();
constructor() {}
get<TFrom, TTo, TDeps extends DerivedStateDependencies>(
parentState$: Observable<TFrom>,
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
dependencies: TDeps,
): DerivedState<TTo> {
let stateCache = this.cache.get(parentState$);
if (!stateCache) {
stateCache = {};
this.cache.set(parentState$, stateCache);
}
const cacheKey = deriveDefinition.buildCacheKey();
const existingDerivedState = stateCache[cacheKey];
if (existingDerivedState != null) {
// I have to cast out of the unknown generic but this should be safe if rules
// around domain token are made
return existingDerivedState as DefaultDerivedState<TFrom, TTo, TDeps>;
}
const newDerivedState = this.buildDerivedState(parentState$, deriveDefinition, dependencies);
stateCache[cacheKey] = newDerivedState;
return newDerivedState;
}
protected buildDerivedState<TFrom, TTo, TDeps extends DerivedStateDependencies>(
parentState$: Observable<TFrom>,
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
dependencies: TDeps,
): DerivedState<TTo> {
return new DefaultDerivedState<TFrom, TTo, TDeps>(parentState$, deriveDefinition, dependencies);
}
}
export { DefaultDerivedStateProvider } from "@bitwarden/state";

View File

@@ -1,211 +0,0 @@
/**
* need to update test environment so trackEmissions works appropriately
* @jest-environment ../shared/test.environment.ts
*/
import { Subject, firstValueFrom } from "rxjs";
import { awaitAsync, trackEmissions } from "../../../../spec";
import { DeriveDefinition } from "../derive-definition";
import { StateDefinition } from "../state-definition";
import { DefaultDerivedState } from "./default-derived-state";
import { DefaultDerivedStateProvider } from "./default-derived-state.provider";
let callCount = 0;
const cleanupDelayMs = 10;
const stateDefinition = new StateDefinition("test", "memory");
const deriveDefinition = new DeriveDefinition<string, Date, { date: Date }>(
stateDefinition,
"test",
{
derive: (dateString: string) => {
callCount++;
return new Date(dateString);
},
deserializer: (dateString: string) => new Date(dateString),
cleanupDelayMs,
},
);
describe("DefaultDerivedState", () => {
let parentState$: Subject<string>;
let sut: DefaultDerivedState<string, Date, { date: Date }>;
const deps = {
date: new Date(),
};
beforeEach(() => {
callCount = 0;
parentState$ = new Subject();
sut = new DefaultDerivedState(parentState$, deriveDefinition, deps);
});
afterEach(() => {
parentState$.complete();
jest.resetAllMocks();
});
it("should derive the state", async () => {
const dateString = "2020-01-01";
const emissions = trackEmissions(sut.state$);
parentState$.next(dateString);
await awaitAsync();
expect(emissions).toEqual([new Date(dateString)]);
});
it("should derive the state once", async () => {
const dateString = "2020-01-01";
trackEmissions(sut.state$);
parentState$.next(dateString);
expect(callCount).toBe(1);
});
describe("forceValue", () => {
const initialParentValue = "2020-01-01";
const forced = new Date("2020-02-02");
let emissions: Date[];
beforeEach(async () => {
emissions = trackEmissions(sut.state$);
parentState$.next(initialParentValue);
await awaitAsync();
});
it("should force the value", async () => {
await sut.forceValue(forced);
expect(emissions).toEqual([new Date(initialParentValue), forced]);
});
it("should only force the value once", async () => {
await sut.forceValue(forced);
parentState$.next(initialParentValue);
await awaitAsync();
expect(emissions).toEqual([
new Date(initialParentValue),
forced,
new Date(initialParentValue),
]);
});
});
describe("cleanup", () => {
const newDate = "2020-02-02";
it("should cleanup after last subscriber", async () => {
const subscription = sut.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(parentState$.observed).toBe(false);
});
it("should not cleanup if there are still subscribers", async () => {
const subscription1 = sut.state$.subscribe();
const sub2Emissions: Date[] = [];
const subscription2 = sut.state$.subscribe((v) => sub2Emissions.push(v));
await awaitAsync();
subscription1.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
// Still be listening to parent updates
parentState$.next(newDate);
await awaitAsync();
expect(sub2Emissions).toEqual([new Date(newDate)]);
subscription2.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(parentState$.observed).toBe(false);
});
it("can re-initialize after cleanup", async () => {
const subscription = sut.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
const emissions = trackEmissions(sut.state$);
await awaitAsync();
parentState$.next(newDate);
await awaitAsync();
expect(emissions).toEqual([new Date(newDate)]);
});
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
const subscription = sut.state$.subscribe();
await awaitAsync();
await parentState$.next(newDate);
await awaitAsync();
subscription.unsubscribe();
// Do not wait long enough for cleanup
await awaitAsync(cleanupDelayMs / 2);
expect(parentState$.observed).toBe(true); // still listening to parent
const emissions = trackEmissions(sut.state$);
expect(emissions).toEqual([new Date(newDate)]); // we didn't lose our buffered value
});
it("state$ observables are durable to cleanup", async () => {
const observable = sut.state$;
let subscription = observable.subscribe();
await parentState$.next(newDate);
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
subscription = observable.subscribe();
await parentState$.next(newDate);
await awaitAsync();
expect(await firstValueFrom(observable)).toEqual(new Date(newDate));
});
});
describe("account switching", () => {
let provider: DefaultDerivedStateProvider;
beforeEach(() => {
provider = new DefaultDerivedStateProvider();
});
it("should provide a dedicated cache for each account", async () => {
const user1State$ = new Subject<string>();
const user1Derived = provider.get(user1State$, deriveDefinition, deps);
const user1Emissions = trackEmissions(user1Derived.state$);
const user2State$ = new Subject<string>();
const user2Derived = provider.get(user2State$, deriveDefinition, deps);
const user2Emissions = trackEmissions(user2Derived.state$);
user1State$.next("2015-12-30");
user2State$.next("2020-12-29");
await awaitAsync();
expect(user1Emissions).toEqual([new Date("2015-12-30")]);
expect(user2Emissions).toEqual([new Date("2020-12-29")]);
});
});
});

View File

@@ -1,50 +1 @@
import { Observable, ReplaySubject, Subject, concatMap, merge, share, timer } from "rxjs";
import { DerivedStateDependencies } from "../../../types/state";
import { DeriveDefinition } from "../derive-definition";
import { DerivedState } from "../derived-state";
/**
* Default derived state
*/
export class DefaultDerivedState<TFrom, TTo, TDeps extends DerivedStateDependencies>
implements DerivedState<TTo>
{
private readonly storageKey: string;
private forcedValueSubject = new Subject<TTo>();
state$: Observable<TTo>;
constructor(
private parentState$: Observable<TFrom>,
protected deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
private dependencies: TDeps,
) {
this.storageKey = deriveDefinition.storageKey;
const derivedState$ = this.parentState$.pipe(
concatMap(async (state) => {
let derivedStateOrPromise = this.deriveDefinition.derive(state, this.dependencies);
if (derivedStateOrPromise instanceof Promise) {
derivedStateOrPromise = await derivedStateOrPromise;
}
const derivedState = derivedStateOrPromise;
return derivedState;
}),
);
this.state$ = merge(this.forcedValueSubject, derivedState$).pipe(
share({
connector: () => {
return new ReplaySubject<TTo>(1);
},
resetOnRefCountZero: () => timer(this.deriveDefinition.cleanupDelayMs),
}),
);
}
async forceValue(value: TTo) {
this.forcedValueSubject.next(value);
return value;
}
}
export { DefaultDerivedState } from "@bitwarden/state";

View File

@@ -1,46 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { StorageServiceProvider } from "@bitwarden/storage-core";
import { LogService } from "../../abstractions/log.service";
import { GlobalState } from "../global-state";
import { GlobalStateProvider } from "../global-state.provider";
import { KeyDefinition } from "../key-definition";
import { DefaultGlobalState } from "./default-global-state";
export class DefaultGlobalStateProvider implements GlobalStateProvider {
private globalStateCache: Record<string, GlobalState<unknown>> = {};
constructor(
private storageServiceProvider: StorageServiceProvider,
private readonly logService: LogService,
) {}
get<T>(keyDefinition: KeyDefinition<T>): GlobalState<T> {
const [location, storageService] = this.storageServiceProvider.get(
keyDefinition.stateDefinition.defaultStorageLocation,
keyDefinition.stateDefinition.storageLocationOverrides,
);
const cacheKey = this.buildCacheKey(location, keyDefinition);
const existingGlobalState = this.globalStateCache[cacheKey];
if (existingGlobalState != null) {
// The cast into the actual generic is safe because of rules around key definitions
// being unique.
return existingGlobalState as DefaultGlobalState<T>;
}
const newGlobalState = new DefaultGlobalState<T>(
keyDefinition,
storageService,
this.logService,
);
this.globalStateCache[cacheKey] = newGlobalState;
return newGlobalState;
}
private buildCacheKey(location: string, keyDefinition: KeyDefinition<unknown>) {
return `${location}_${keyDefinition.fullName}`;
}
}
export { DefaultGlobalStateProvider } from "@bitwarden/state";

View File

@@ -1,411 +0,0 @@
/**
* need to update test environment so trackEmissions works appropriately
* @jest-environment ../shared/test.environment.ts
*/
import { mock } from "jest-mock-extended";
import { firstValueFrom, of } from "rxjs";
import { Jsonify } from "type-fest";
import { trackEmissions, awaitAsync } from "../../../../spec";
import { FakeStorageService } from "../../../../spec/fake-storage.service";
import { LogService } from "../../abstractions/log.service";
import { KeyDefinition, globalKeyBuilder } from "../key-definition";
import { StateDefinition } from "../state-definition";
import { DefaultGlobalState } from "./default-global-state";
class TestState {
date: Date;
static fromJSON(jsonState: Jsonify<TestState>) {
if (jsonState == null) {
return null;
}
return Object.assign(new TestState(), jsonState, {
date: new Date(jsonState.date),
});
}
}
const testStateDefinition = new StateDefinition("fake", "disk");
const cleanupDelayMs = 10;
const testKeyDefinition = new KeyDefinition<TestState>(testStateDefinition, "fake", {
deserializer: TestState.fromJSON,
cleanupDelayMs,
});
const globalKey = globalKeyBuilder(testKeyDefinition);
describe("DefaultGlobalState", () => {
let diskStorageService: FakeStorageService;
let globalState: DefaultGlobalState<TestState>;
const logService = mock<LogService>();
const newData = { date: new Date() };
beforeEach(() => {
diskStorageService = new FakeStorageService();
globalState = new DefaultGlobalState(testKeyDefinition, diskStorageService, logService);
});
afterEach(() => {
jest.resetAllMocks();
});
describe("state$", () => {
it("should emit when storage updates", async () => {
const emissions = trackEmissions(globalState.state$);
await diskStorageService.save(globalKey, newData);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
newData,
]);
});
it("should not emit when update key does not match", async () => {
const emissions = trackEmissions(globalState.state$);
await diskStorageService.save("wrong_key", newData);
expect(emissions).toHaveLength(0);
});
it("should emit initial storage value on first subscribe", async () => {
const initialStorage: Record<string, TestState> = {};
initialStorage[globalKey] = TestState.fromJSON({
date: "2022-09-21T13:14:17.648Z",
});
diskStorageService.internalUpdateStore(initialStorage);
const state = await firstValueFrom(globalState.state$);
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
expect(diskStorageService.mock.get).toHaveBeenCalledWith("global_fake_fake", undefined);
expect(state).toBeTruthy();
});
it("should not emit twice if there are two listeners", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions2 = trackEmissions(globalState.state$);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
]);
expect(emissions2).toEqual([
null, // Initial value
]);
});
});
describe("update", () => {
it("should save on update", async () => {
const result = await globalState.update((state) => {
return newData;
});
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
expect(result).toEqual(newData);
});
it("should emit once per update", async () => {
const emissions = trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
await globalState.update((state) => {
return newData;
});
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
newData,
]);
});
it("should provided combined dependencies", async () => {
const emissions = trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
const combinedDependencies = { date: new Date() };
await globalState.update(
(state, dependencies) => {
expect(dependencies).toEqual(combinedDependencies);
return newData;
},
{
combineLatestWith: of(combinedDependencies),
},
);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
newData,
]);
});
it("should not update if shouldUpdate returns false", async () => {
const emissions = trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
const result = await globalState.update(
(state) => {
return newData;
},
{
shouldUpdate: () => false,
},
);
expect(diskStorageService.mock.save).not.toHaveBeenCalled();
expect(emissions).toEqual([null]); // Initial value
expect(result).toBeNull();
});
it("should provide the update callback with the current State", async () => {
const emissions = trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1) };
await globalState.update((state, dependencies) => {
return initialData;
});
await awaitAsync();
await globalState.update((state) => {
expect(state).toEqual(initialData);
return newData;
});
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
initialData,
newData,
]);
});
it("should give initial state for update call", async () => {
const initialStorage: Record<string, TestState> = {};
const initialState = TestState.fromJSON({
date: "2022-09-21T13:14:17.648Z",
});
initialStorage[globalKey] = initialState;
diskStorageService.internalUpdateStore(initialStorage);
const emissions = trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
const newState = {
...initialState,
date: new Date(initialState.date.getFullYear(), initialState.date.getMonth() + 1),
};
const actual = await globalState.update((existingState) => newState);
await awaitAsync();
expect(actual).toEqual(newState);
expect(emissions).toHaveLength(2);
expect(emissions).toEqual(expect.arrayContaining([initialState, newState]));
});
});
describe("update races", () => {
test("subscriptions during an update should receive the current and latest data", async () => {
const oldData = { date: new Date(2019, 1, 1) };
await globalState.update(() => {
return oldData;
});
const initialData = { date: new Date(2020, 1, 1) };
await globalState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(globalState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
emissions2 = trackEmissions(globalState.state$);
await originalSave(key, obj);
});
const val = await globalState.update(() => {
return newData;
});
await awaitAsync(10);
expect(val).toEqual(newData);
expect(emissions).toEqual([initialData, newData]);
expect(emissions2).toEqual([initialData, newData]);
});
test("subscription during an aborted update should receive the last value", async () => {
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1) };
await globalState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(globalState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const val = await globalState.update(
() => {
return newData;
},
{
shouldUpdate: () => {
emissions2 = trackEmissions(globalState.state$);
return false;
},
},
);
await awaitAsync();
expect(val).toEqual(initialData);
expect(emissions).toEqual([initialData]);
expect(emissions2).toEqual([initialData]);
});
test("updates should wait until previous update is complete", async () => {
trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async () => {
let resolved = false;
await Promise.race([
globalState.update(() => {
// deadlocks
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(false);
})
.mockImplementation(originalSave);
await globalState.update((state) => {
return newData;
});
});
});
describe("cleanup", () => {
function assertClean() {
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
}
it("should cleanup after last subscriber", async () => {
const subscription = globalState.state$.subscribe();
await awaitAsync(); // storage updates are behind a promise
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
assertClean();
});
it("should not cleanup if there are still subscribers", async () => {
const subscription1 = globalState.state$.subscribe();
const sub2Emissions: TestState[] = [];
const subscription2 = globalState.state$.subscribe((v) => sub2Emissions.push(v));
await awaitAsync(); // storage updates are behind a promise
subscription1.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
// Still be listening to storage updates
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
diskStorageService.save(globalKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(sub2Emissions).toEqual([null, newData]);
subscription2.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
assertClean();
});
it("can re-initialize after cleanup", async () => {
const subscription = globalState.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
const emissions = trackEmissions(globalState.state$);
await awaitAsync();
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
diskStorageService.save(globalKey, newData);
await awaitAsync();
expect(emissions).toEqual([null, newData]);
});
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
const subscription = globalState.state$.subscribe();
await awaitAsync();
await diskStorageService.save(globalKey, newData);
await awaitAsync();
subscription.unsubscribe();
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
// Do not wait long enough for cleanup
await awaitAsync(cleanupDelayMs / 2);
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
});
it("state$ observables are durable to cleanup", async () => {
const observable = globalState.state$;
let subscription = observable.subscribe();
await diskStorageService.save(globalKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
subscription = observable.subscribe();
await diskStorageService.save(globalKey, newData);
await awaitAsync();
expect(await firstValueFrom(observable)).toEqual(newData);
});
});
});

View File

@@ -1,20 +1 @@
import { AbstractStorageService, ObservableStorageService } from "@bitwarden/storage-core";
import { LogService } from "../../abstractions/log.service";
import { GlobalState } from "../global-state";
import { KeyDefinition, globalKeyBuilder } from "../key-definition";
import { StateBase } from "./state-base";
export class DefaultGlobalState<T>
extends StateBase<T, KeyDefinition<T>>
implements GlobalState<T>
{
constructor(
keyDefinition: KeyDefinition<T>,
chosenLocation: AbstractStorageService & ObservableStorageService,
logService: LogService,
) {
super(globalKeyBuilder(keyDefinition), chosenLocation, keyDefinition, logService);
}
}
export { DefaultGlobalState } from "@bitwarden/state";

View File

@@ -1,54 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { StorageServiceProvider } from "@bitwarden/storage-core";
import { UserId } from "../../../types/guid";
import { LogService } from "../../abstractions/log.service";
import { StateEventRegistrarService } from "../state-event-registrar.service";
import { UserKeyDefinition } from "../user-key-definition";
import { SingleUserState } from "../user-state";
import { SingleUserStateProvider } from "../user-state.provider";
import { DefaultSingleUserState } from "./default-single-user-state";
export class DefaultSingleUserStateProvider implements SingleUserStateProvider {
private cache: Record<string, SingleUserState<unknown>> = {};
constructor(
private readonly storageServiceProvider: StorageServiceProvider,
private readonly stateEventRegistrarService: StateEventRegistrarService,
private readonly logService: LogService,
) {}
get<T>(userId: UserId, keyDefinition: UserKeyDefinition<T>): SingleUserState<T> {
const [location, storageService] = this.storageServiceProvider.get(
keyDefinition.stateDefinition.defaultStorageLocation,
keyDefinition.stateDefinition.storageLocationOverrides,
);
const cacheKey = this.buildCacheKey(location, userId, keyDefinition);
const existingUserState = this.cache[cacheKey];
if (existingUserState != null) {
// I have to cast out of the unknown generic but this should be safe if rules
// around domain token are made
return existingUserState as SingleUserState<T>;
}
const newUserState = new DefaultSingleUserState<T>(
userId,
keyDefinition,
storageService,
this.stateEventRegistrarService,
this.logService,
);
this.cache[cacheKey] = newUserState;
return newUserState;
}
private buildCacheKey(
location: string,
userId: UserId,
keyDefinition: UserKeyDefinition<unknown>,
) {
return `${location}_${keyDefinition.fullName}_${userId}`;
}
}
export { DefaultSingleUserStateProvider } from "@bitwarden/state";

View File

@@ -1,596 +0,0 @@
/**
* need to update test environment so trackEmissions works appropriately
* @jest-environment ../shared/test.environment.ts
*/
import { mock } from "jest-mock-extended";
import { firstValueFrom, of } from "rxjs";
import { Jsonify } from "type-fest";
import { trackEmissions, awaitAsync } from "../../../../spec";
import { FakeStorageService } from "../../../../spec/fake-storage.service";
import { UserId } from "../../../types/guid";
import { LogService } from "../../abstractions/log.service";
import { Utils } from "../../misc/utils";
import { StateDefinition } from "../state-definition";
import { StateEventRegistrarService } from "../state-event-registrar.service";
import { UserKeyDefinition } from "../user-key-definition";
import { DefaultSingleUserState } from "./default-single-user-state";
class TestState {
date: Date;
static fromJSON(jsonState: Jsonify<TestState>) {
if (jsonState == null) {
return null;
}
return Object.assign(new TestState(), jsonState, {
date: new Date(jsonState.date),
});
}
}
const testStateDefinition = new StateDefinition("fake", "disk");
const cleanupDelayMs = 10;
const testKeyDefinition = new UserKeyDefinition<TestState>(testStateDefinition, "fake", {
deserializer: TestState.fromJSON,
cleanupDelayMs,
clearOn: [],
});
const userId = Utils.newGuid() as UserId;
const userKey = testKeyDefinition.buildKey(userId);
describe("DefaultSingleUserState", () => {
let diskStorageService: FakeStorageService;
let userState: DefaultSingleUserState<TestState>;
const stateEventRegistrarService = mock<StateEventRegistrarService>();
const logService = mock<LogService>();
const newData = { date: new Date() };
beforeEach(() => {
diskStorageService = new FakeStorageService();
userState = new DefaultSingleUserState(
userId,
testKeyDefinition,
diskStorageService,
stateEventRegistrarService,
logService,
);
});
afterEach(() => {
jest.resetAllMocks();
});
describe("state$", () => {
it("should emit when storage updates", async () => {
const emissions = trackEmissions(userState.state$);
await diskStorageService.save(userKey, newData);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
newData,
]);
});
it("should not emit when update key does not match", async () => {
const emissions = trackEmissions(userState.state$);
await diskStorageService.save("wrong_key", newData);
// Give userState a chance to emit it's initial value
// as well as wrongly emit the different key.
await awaitAsync();
// Just the initial value
expect(emissions).toEqual([null]);
});
it("should emit initial storage value on first subscribe", async () => {
const initialStorage: Record<string, TestState> = {};
initialStorage[userKey] = TestState.fromJSON({
date: "2022-09-21T13:14:17.648Z",
});
diskStorageService.internalUpdateStore(initialStorage);
const state = await firstValueFrom(userState.state$);
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
expect(diskStorageService.mock.get).toHaveBeenCalledWith(
`user_${userId}_fake_fake`,
undefined,
);
expect(state).toBeTruthy();
});
it("should go to disk each subscription if a cleanupDelayMs of 0 is given", async () => {
const state = new DefaultSingleUserState(
userId,
new UserKeyDefinition(testStateDefinition, "test", {
cleanupDelayMs: 0,
deserializer: TestState.fromJSON,
clearOn: [],
debug: {
enableRetrievalLogging: true,
},
}),
diskStorageService,
stateEventRegistrarService,
logService,
);
await firstValueFrom(state.state$);
await firstValueFrom(state.state$);
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(2);
expect(logService.info).toHaveBeenCalledTimes(2);
expect(logService.info).toHaveBeenCalledWith(
`Retrieving 'user_${userId}_fake_test' from storage, value is null`,
);
});
});
describe("combinedState$", () => {
it("should emit when storage updates", async () => {
const emissions = trackEmissions(userState.combinedState$);
await diskStorageService.save(userKey, newData);
await awaitAsync();
expect(emissions).toEqual([
[userId, null], // Initial value
[userId, newData],
]);
});
it("should not emit when update key does not match", async () => {
const emissions = trackEmissions(userState.combinedState$);
await diskStorageService.save("wrong_key", newData);
// Give userState a chance to emit it's initial value
// as well as wrongly emit the different key.
await awaitAsync();
// Just the initial value
expect(emissions).toHaveLength(1);
});
it("should emit initial storage value on first subscribe", async () => {
const initialStorage: Record<string, TestState> = {};
initialStorage[userKey] = TestState.fromJSON({
date: "2022-09-21T13:14:17.648Z",
});
diskStorageService.internalUpdateStore(initialStorage);
const combinedState = await firstValueFrom(userState.combinedState$);
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
expect(diskStorageService.mock.get).toHaveBeenCalledWith(
`user_${userId}_fake_fake`,
undefined,
);
expect(combinedState).toBeTruthy();
const [stateUserId, state] = combinedState;
expect(stateUserId).toBe(userId);
expect(state).toBe(initialStorage[userKey]);
});
});
describe("update", () => {
it("should save on update", async () => {
const result = await userState.update((state) => {
return newData;
});
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
expect(result).toEqual(newData);
});
it("should emit once per update", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
await userState.update((state) => {
return newData;
});
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
newData,
]);
});
it("should provided combined dependencies", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const combinedDependencies = { date: new Date() };
await userState.update(
(state, dependencies) => {
expect(dependencies).toEqual(combinedDependencies);
return newData;
},
{
combineLatestWith: of(combinedDependencies),
},
);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
newData,
]);
});
it("should not update if shouldUpdate returns false", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const result = await userState.update(
(state) => {
return newData;
},
{
shouldUpdate: () => false,
},
);
expect(diskStorageService.mock.save).not.toHaveBeenCalled();
expect(emissions).toEqual([null]); // Initial value
expect(result).toBeNull();
});
it("should provide the update callback with the current State", async () => {
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1) };
await userState.update((state, dependencies) => {
return initialData;
});
await awaitAsync();
await userState.update((state) => {
expect(state).toEqual(initialData);
return newData;
});
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
initialData,
newData,
]);
});
it("should give initial state for update call", async () => {
const initialStorage: Record<string, TestState> = {};
const initialState = TestState.fromJSON({
date: "2022-09-21T13:14:17.648Z",
});
initialStorage[userKey] = initialState;
diskStorageService.internalUpdateStore(initialStorage);
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const newState = {
...initialState,
date: new Date(initialState.date.getFullYear(), initialState.date.getMonth() + 1),
};
const actual = await userState.update((existingState) => newState);
await awaitAsync();
expect(actual).toEqual(newState);
expect(emissions).toHaveLength(2);
expect(emissions).toEqual(expect.arrayContaining([initialState, newState]));
});
it.each([null, undefined])(
"should register user key definition when state transitions from null-ish (%s) to non-null",
async (startingValue: TestState | null) => {
const initialState: Record<string, TestState> = {};
initialState[userKey] = startingValue;
diskStorageService.internalUpdateStore(initialState);
await userState.update(() => ({ array: ["one"], date: new Date() }));
expect(stateEventRegistrarService.registerEvents).toHaveBeenCalledWith(testKeyDefinition);
},
);
it("should not register user key definition when state has preexisting value", async () => {
const initialState: Record<string, TestState> = {};
initialState[userKey] = {
date: new Date(2019, 1),
};
diskStorageService.internalUpdateStore(initialState);
await userState.update(() => ({ array: ["one"], date: new Date() }));
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
});
it.each([null, undefined])(
"should not register user key definition when setting value to null-ish (%s) value",
async (updatedValue: TestState | null) => {
const initialState: Record<string, TestState> = {};
initialState[userKey] = {
date: new Date(2019, 1),
};
diskStorageService.internalUpdateStore(initialState);
await userState.update(() => updatedValue);
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
},
);
const logCases: { startingValue: TestState; updateValue: TestState; phrase: string }[] = [
{
startingValue: null,
updateValue: null,
phrase: "null to null",
},
{
startingValue: null,
updateValue: new TestState(),
phrase: "null to non-null",
},
{
startingValue: new TestState(),
updateValue: null,
phrase: "non-null to null",
},
{
startingValue: new TestState(),
updateValue: new TestState(),
phrase: "non-null to non-null",
},
];
it.each(logCases)(
"should log meta info about the update",
async ({ startingValue, updateValue, phrase }) => {
diskStorageService.internalUpdateStore({
[`user_${userId}_fake_fake`]: startingValue,
});
const state = new DefaultSingleUserState(
userId,
new UserKeyDefinition<TestState>(testStateDefinition, "fake", {
deserializer: TestState.fromJSON,
clearOn: [],
debug: {
enableUpdateLogging: true,
},
}),
diskStorageService,
stateEventRegistrarService,
logService,
);
await state.update(() => updateValue);
expect(logService.info).toHaveBeenCalledWith(
`Updating 'user_${userId}_fake_fake' from ${phrase}`,
);
},
);
});
describe("update races", () => {
test("subscriptions during an update should receive the current and latest data", async () => {
const oldData = { date: new Date(2019, 1, 1) };
await userState.update(() => {
return oldData;
});
const initialData = { date: new Date(2020, 1, 1) };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
emissions2 = trackEmissions(userState.state$);
await originalSave(key, obj);
});
const val = await userState.update(() => {
return newData;
});
await awaitAsync(10);
expect(val).toEqual(newData);
expect(emissions).toEqual([initialData, newData]);
expect(emissions2).toEqual([initialData, newData]);
});
test("subscription during an aborted update should receive the last value", async () => {
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1) };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const val = await userState.update(
(state) => {
return newData;
},
{
shouldUpdate: () => {
emissions2 = trackEmissions(userState.state$);
return false;
},
},
);
await awaitAsync();
expect(val).toEqual(initialData);
expect(emissions).toEqual([initialData]);
expect(emissions2).toEqual([initialData]);
});
test("updates should wait until previous update is complete", async () => {
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async () => {
let resolved = false;
await Promise.race([
userState.update(() => {
// deadlocks
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(false);
})
.mockImplementation(originalSave);
await userState.update((state) => {
return newData;
});
});
test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => {
const val = await userState.update((state) => {
return newData;
});
expect(val).toEqual(newData);
const call = diskStorageService.mock.save.mock.calls[0];
expect(call[0]).toEqual(`user_${userId}_fake_fake`);
expect(call[1]).toEqual(newData);
});
});
describe("cleanup", () => {
function assertClean() {
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
}
it("should cleanup after last subscriber", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync(); // storage updates are behind a promise
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
assertClean();
});
it("should not cleanup if there are still subscribers", async () => {
const subscription1 = userState.state$.subscribe();
const sub2Emissions: TestState[] = [];
const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v));
await awaitAsync(); // storage updates are behind a promise
subscription1.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
// Still be listening to storage updates
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
diskStorageService.save(userKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(sub2Emissions).toEqual([null, newData]);
subscription2.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
assertClean();
});
it("can re-initialize after cleanup", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
const emissions = trackEmissions(userState.state$);
await awaitAsync();
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
diskStorageService.save(userKey, newData);
await awaitAsync();
expect(emissions).toEqual([null, newData]);
});
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Do not wait long enough for cleanup
await awaitAsync(cleanupDelayMs / 2);
const value = await firstValueFrom(userState.state$);
expect(value).toEqual(newData);
// Should be called once for the initial subscription and a second time during the save
// but should not be called for a second subscription if the cleanup hasn't happened yet.
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(2);
});
it("state$ observables are durable to cleanup", async () => {
const observable = userState.state$;
let subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
expect(await firstValueFrom(observable)).toEqual(newData);
});
});
});

View File

@@ -1,36 +1 @@
import { Observable, combineLatest, of } from "rxjs";
import { AbstractStorageService, ObservableStorageService } from "@bitwarden/storage-core";
import { UserId } from "../../../types/guid";
import { LogService } from "../../abstractions/log.service";
import { StateEventRegistrarService } from "../state-event-registrar.service";
import { UserKeyDefinition } from "../user-key-definition";
import { CombinedState, SingleUserState } from "../user-state";
import { StateBase } from "./state-base";
export class DefaultSingleUserState<T>
extends StateBase<T, UserKeyDefinition<T>>
implements SingleUserState<T>
{
readonly combinedState$: Observable<CombinedState<T | null>>;
constructor(
readonly userId: UserId,
keyDefinition: UserKeyDefinition<T>,
chosenLocation: AbstractStorageService & ObservableStorageService,
private stateEventRegistrarService: StateEventRegistrarService,
logService: LogService,
) {
super(keyDefinition.buildKey(userId), chosenLocation, keyDefinition, logService);
this.combinedState$ = combineLatest([of(userId), this.state$]);
}
protected override async doStorageSave(newState: T, oldState: T): Promise<void> {
await super.doStorageSave(newState, oldState);
if (newState != null && oldState == null) {
await this.stateEventRegistrarService.registerEvents(this.keyDefinition);
}
}
}
export { DefaultSingleUserState } from "@bitwarden/state";

View File

@@ -1,265 +0,0 @@
/**
* need to update test environment so structuredClone works appropriately
* @jest-environment ../shared/test.environment.ts
*/
import { Observable, of } from "rxjs";
import { awaitAsync, trackEmissions } from "../../../../spec";
import { FakeAccountService, mockAccountServiceWith } from "../../../../spec/fake-account-service";
import {
FakeActiveUserStateProvider,
FakeDerivedStateProvider,
FakeGlobalStateProvider,
FakeSingleUserStateProvider,
} from "../../../../spec/fake-state-provider";
import { AuthenticationStatus } from "../../../auth/enums/authentication-status";
import { UserId } from "../../../types/guid";
import { DeriveDefinition } from "../derive-definition";
import { KeyDefinition } from "../key-definition";
import { StateDefinition } from "../state-definition";
import { UserKeyDefinition } from "../user-key-definition";
import { DefaultStateProvider } from "./default-state.provider";
describe("DefaultStateProvider", () => {
let sut: DefaultStateProvider;
let activeUserStateProvider: FakeActiveUserStateProvider;
let singleUserStateProvider: FakeSingleUserStateProvider;
let globalStateProvider: FakeGlobalStateProvider;
let derivedStateProvider: FakeDerivedStateProvider;
let accountService: FakeAccountService;
const userId = "fakeUserId" as UserId;
beforeEach(() => {
accountService = mockAccountServiceWith(userId);
activeUserStateProvider = new FakeActiveUserStateProvider(accountService);
singleUserStateProvider = new FakeSingleUserStateProvider();
globalStateProvider = new FakeGlobalStateProvider();
derivedStateProvider = new FakeDerivedStateProvider();
sut = new DefaultStateProvider(
activeUserStateProvider,
singleUserStateProvider,
globalStateProvider,
derivedStateProvider,
);
});
describe("activeUserId$", () => {
it("should track the active User id from active user state provider", () => {
expect(sut.activeUserId$).toBe(activeUserStateProvider.activeUserId$);
});
});
describe.each([
[
"getUserState$",
(keyDefinition: UserKeyDefinition<string>, userId?: UserId) =>
sut.getUserState$(keyDefinition, userId),
],
[
"getUserStateOrDefault$",
(keyDefinition: UserKeyDefinition<string>, userId?: UserId) =>
sut.getUserStateOrDefault$(keyDefinition, { userId: userId }),
],
])(
"Shared behavior for %s",
(
_testName: string,
methodUnderTest: (
keyDefinition: UserKeyDefinition<string>,
userId?: UserId,
) => Observable<string>,
) => {
const accountInfo = {
email: "email",
emailVerified: false,
name: "name",
status: AuthenticationStatus.LoggedOut,
};
const keyDefinition = new UserKeyDefinition<string>(
new StateDefinition("test", "disk"),
"test",
{
deserializer: (s) => s,
clearOn: [],
},
);
it("should follow the specified user if userId is provided", async () => {
const state = singleUserStateProvider.getFake(userId, keyDefinition);
state.nextState("value");
const emissions = trackEmissions(methodUnderTest(keyDefinition, userId));
state.nextState("value2");
state.nextState("value3");
expect(emissions).toEqual(["value", "value2", "value3"]);
});
it("should follow the current active user if no userId is provided", async () => {
accountService.activeAccountSubject.next({ id: userId, ...accountInfo });
const state = singleUserStateProvider.getFake(userId, keyDefinition);
state.nextState("value");
const emissions = trackEmissions(methodUnderTest(keyDefinition));
state.nextState("value2");
state.nextState("value3");
expect(emissions).toEqual(["value", "value2", "value3"]);
});
it("should continue to follow the state of the user that was active when called, even if active user changes", async () => {
const state = singleUserStateProvider.getFake(userId, keyDefinition);
state.nextState("value");
const emissions = trackEmissions(methodUnderTest(keyDefinition));
accountService.activeAccountSubject.next({ id: "newUserId" as UserId, ...accountInfo });
const newUserEmissions = trackEmissions(sut.getUserState$(keyDefinition));
state.nextState("value2");
state.nextState("value3");
expect(emissions).toEqual(["value", "value2", "value3"]);
expect(newUserEmissions).toEqual([null]);
});
},
);
describe("getUserState$", () => {
const accountInfo = {
email: "email",
emailVerified: false,
name: "name",
status: AuthenticationStatus.LoggedOut,
};
const keyDefinition = new UserKeyDefinition<string>(
new StateDefinition("test", "disk"),
"test",
{
deserializer: (s) => s,
clearOn: [],
},
);
it("should not emit any values until a truthy user id is supplied", async () => {
accountService.activeAccountSubject.next(null);
const state = singleUserStateProvider.getFake(userId, keyDefinition);
state.nextState("value");
const emissions = trackEmissions(sut.getUserState$(keyDefinition));
await awaitAsync();
expect(emissions).toHaveLength(0);
accountService.activeAccountSubject.next({ id: userId, ...accountInfo });
await awaitAsync();
expect(emissions).toEqual(["value"]);
});
});
describe("getUserStateOrDefault$", () => {
const keyDefinition = new UserKeyDefinition<string>(
new StateDefinition("test", "disk"),
"test",
{
deserializer: (s) => s,
clearOn: [],
},
);
it("should emit default value if no userId supplied and first active user id emission in falsy", async () => {
accountService.activeAccountSubject.next(null);
const emissions = trackEmissions(
sut.getUserStateOrDefault$(keyDefinition, {
userId: undefined,
defaultValue: "I'm default!",
}),
);
expect(emissions).toEqual(["I'm default!"]);
});
});
describe("setUserState", () => {
const keyDefinition = new UserKeyDefinition<string>(
new StateDefinition("test", "disk"),
"test",
{
deserializer: (s) => s,
clearOn: [],
},
);
it("should set the state for the active user if no userId is provided", async () => {
const value = "value";
await sut.setUserState(keyDefinition, value);
const state = activeUserStateProvider.getFake(keyDefinition);
expect(state.nextMock).toHaveBeenCalledWith([expect.any(String), value]);
});
it("should not set state for a single user if no userId is provided", async () => {
const value = "value";
await sut.setUserState(keyDefinition, value);
const state = singleUserStateProvider.getFake(userId, keyDefinition);
expect(state.nextMock).not.toHaveBeenCalled();
});
it("should set the state for the provided userId", async () => {
const value = "value";
await sut.setUserState(keyDefinition, value, userId);
const state = singleUserStateProvider.getFake(userId, keyDefinition);
expect(state.nextMock).toHaveBeenCalledWith(value);
});
it("should not set the active user state if userId is provided", async () => {
const value = "value";
await sut.setUserState(keyDefinition, value, userId);
const state = activeUserStateProvider.getFake(keyDefinition);
expect(state.nextMock).not.toHaveBeenCalled();
});
});
it("should bind the activeUserStateProvider", () => {
const keyDefinition = new UserKeyDefinition(new StateDefinition("test", "disk"), "test", {
deserializer: () => null,
clearOn: [],
});
const existing = activeUserStateProvider.get(keyDefinition);
const actual = sut.getActive(keyDefinition);
expect(actual).toBe(existing);
});
it("should bind the singleUserStateProvider", () => {
const userId = "user" as UserId;
const keyDefinition = new UserKeyDefinition(new StateDefinition("test", "disk"), "test", {
deserializer: () => null,
clearOn: [],
});
const existing = singleUserStateProvider.get(userId, keyDefinition);
const actual = sut.getUser(userId, keyDefinition);
expect(actual).toBe(existing);
});
it("should bind the globalStateProvider", () => {
const keyDefinition = new KeyDefinition(new StateDefinition("test", "disk"), "test", {
deserializer: () => null,
});
const existing = globalStateProvider.get(keyDefinition);
const actual = sut.getGlobal(keyDefinition);
expect(actual).toBe(existing);
});
it("should bind the derivedStateProvider", () => {
const derivedDefinition = new DeriveDefinition(new StateDefinition("test", "disk"), "test", {
derive: () => null,
deserializer: () => null,
});
const parentState$ = of(null);
const existing = derivedStateProvider.get(parentState$, derivedDefinition, {});
const actual = sut.getDerived(parentState$, derivedDefinition, {});
expect(actual).toBe(existing);
});
});

View File

@@ -1,79 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Observable, filter, of, switchMap, take } from "rxjs";
import { UserId } from "../../../types/guid";
import { DerivedStateDependencies } from "../../../types/state";
import { DeriveDefinition } from "../derive-definition";
import { DerivedState } from "../derived-state";
import { DerivedStateProvider } from "../derived-state.provider";
import { GlobalStateProvider } from "../global-state.provider";
import { StateProvider } from "../state.provider";
import { UserKeyDefinition } from "../user-key-definition";
import { ActiveUserStateProvider, SingleUserStateProvider } from "../user-state.provider";
export class DefaultStateProvider implements StateProvider {
activeUserId$: Observable<UserId>;
constructor(
private readonly activeUserStateProvider: ActiveUserStateProvider,
private readonly singleUserStateProvider: SingleUserStateProvider,
private readonly globalStateProvider: GlobalStateProvider,
private readonly derivedStateProvider: DerivedStateProvider,
) {
this.activeUserId$ = this.activeUserStateProvider.activeUserId$;
}
getUserState$<T>(userKeyDefinition: UserKeyDefinition<T>, userId?: UserId): Observable<T> {
if (userId) {
return this.getUser<T>(userId, userKeyDefinition).state$;
} else {
return this.activeUserId$.pipe(
filter((userId) => userId != null), // Filter out null-ish user ids since we can't get state for a null user id
take(1),
switchMap((userId) => this.getUser<T>(userId, userKeyDefinition).state$),
);
}
}
getUserStateOrDefault$<T>(
userKeyDefinition: UserKeyDefinition<T>,
config: { userId: UserId | undefined; defaultValue?: T },
): Observable<T> {
const { userId, defaultValue = null } = config;
if (userId) {
return this.getUser<T>(userId, userKeyDefinition).state$;
} else {
return this.activeUserId$.pipe(
take(1),
switchMap((userId) =>
userId != null ? this.getUser<T>(userId, userKeyDefinition).state$ : of(defaultValue),
),
);
}
}
async setUserState<T>(
userKeyDefinition: UserKeyDefinition<T>,
value: T | null,
userId?: UserId,
): Promise<[UserId, T | null]> {
if (userId) {
return [userId, await this.getUser<T>(userId, userKeyDefinition).update(() => value)];
} else {
return await this.getActive<T>(userKeyDefinition).update(() => value);
}
}
getActive: InstanceType<typeof ActiveUserStateProvider>["get"] =
this.activeUserStateProvider.get.bind(this.activeUserStateProvider);
getUser: InstanceType<typeof SingleUserStateProvider>["get"] =
this.singleUserStateProvider.get.bind(this.singleUserStateProvider);
getGlobal: InstanceType<typeof GlobalStateProvider>["get"] = this.globalStateProvider.get.bind(
this.globalStateProvider,
);
getDerived: <TFrom, TTo, TDeps extends DerivedStateDependencies>(
parentState$: Observable<TFrom>,
deriveDefinition: DeriveDefinition<unknown, TTo, TDeps>,
dependencies: TDeps,
) => DerivedState<TTo> = this.derivedStateProvider.get.bind(this.derivedStateProvider);
}
export { DefaultStateProvider } from "@bitwarden/state";

View File

@@ -1,62 +0,0 @@
import { Subject, firstValueFrom } from "rxjs";
import { DeriveDefinition } from "../derive-definition";
import { StateDefinition } from "../state-definition";
import { InlineDerivedState } from "./inline-derived-state";
describe("InlineDerivedState", () => {
const syncDeriveDefinition = new DeriveDefinition<boolean, boolean, Record<string, unknown>>(
new StateDefinition("test", "disk"),
"test",
{
derive: (value, deps) => !value,
deserializer: (value) => value,
},
);
const asyncDeriveDefinition = new DeriveDefinition<boolean, boolean, Record<string, unknown>>(
new StateDefinition("test", "disk"),
"test",
{
derive: async (value, deps) => Promise.resolve(!value),
deserializer: (value) => value,
},
);
const parentState = new Subject<boolean>();
describe("state", () => {
const cases = [
{
it: "works when derive function is sync",
definition: syncDeriveDefinition,
},
{
it: "works when derive function is async",
definition: asyncDeriveDefinition,
},
];
it.each(cases)("$it", async ({ definition }) => {
const sut = new InlineDerivedState(parentState.asObservable(), definition, {});
const valuePromise = firstValueFrom(sut.state$);
parentState.next(true);
const value = await valuePromise;
expect(value).toBe(false);
});
});
describe("forceValue", () => {
it("returns the force value back to the caller", async () => {
const sut = new InlineDerivedState(parentState.asObservable(), syncDeriveDefinition, {});
const value = await sut.forceValue(true);
expect(value).toBe(true);
});
});
});

View File

@@ -1,37 +1 @@
import { Observable, concatMap } from "rxjs";
import { DerivedStateDependencies } from "../../../types/state";
import { DeriveDefinition } from "../derive-definition";
import { DerivedState } from "../derived-state";
import { DerivedStateProvider } from "../derived-state.provider";
export class InlineDerivedStateProvider implements DerivedStateProvider {
get<TFrom, TTo, TDeps extends DerivedStateDependencies>(
parentState$: Observable<TFrom>,
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
dependencies: TDeps,
): DerivedState<TTo> {
return new InlineDerivedState(parentState$, deriveDefinition, dependencies);
}
}
export class InlineDerivedState<TFrom, TTo, TDeps extends DerivedStateDependencies>
implements DerivedState<TTo>
{
constructor(
parentState$: Observable<TFrom>,
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
dependencies: TDeps,
) {
this.state$ = parentState$.pipe(
concatMap(async (value) => await deriveDefinition.derive(value, dependencies)),
);
}
state$: Observable<TTo>;
forceValue(value: TTo): Promise<TTo> {
// No need to force anything, we don't keep a cache
return Promise.resolve(value);
}
}
export { InlineDerivedState, InlineDerivedStateProvider } from "@bitwarden/state";

View File

@@ -1,177 +0,0 @@
import { mock } from "jest-mock-extended";
import { StorageServiceProvider } from "@bitwarden/storage-core";
import { mockAccountServiceWith } from "../../../../spec/fake-account-service";
import { FakeStorageService } from "../../../../spec/fake-storage.service";
import { UserId } from "../../../types/guid";
import { LogService } from "../../abstractions/log.service";
import { KeyDefinition } from "../key-definition";
import { StateDefinition } from "../state-definition";
import { StateEventRegistrarService } from "../state-event-registrar.service";
import { UserKeyDefinition } from "../user-key-definition";
import { DefaultActiveUserState } from "./default-active-user-state";
import { DefaultActiveUserStateProvider } from "./default-active-user-state.provider";
import { DefaultGlobalState } from "./default-global-state";
import { DefaultGlobalStateProvider } from "./default-global-state.provider";
import { DefaultSingleUserState } from "./default-single-user-state";
import { DefaultSingleUserStateProvider } from "./default-single-user-state.provider";
describe("Specific State Providers", () => {
const storageServiceProvider = mock<StorageServiceProvider>();
const stateEventRegistrarService = mock<StateEventRegistrarService>();
const logService = mock<LogService>();
let singleSut: DefaultSingleUserStateProvider;
let activeSut: DefaultActiveUserStateProvider;
let globalSut: DefaultGlobalStateProvider;
const fakeUser1 = "00000000-0000-1000-a000-000000000001" as UserId;
beforeEach(() => {
storageServiceProvider.get.mockImplementation((location) => {
return [location, new FakeStorageService()];
});
singleSut = new DefaultSingleUserStateProvider(
storageServiceProvider,
stateEventRegistrarService,
logService,
);
activeSut = new DefaultActiveUserStateProvider(mockAccountServiceWith(null), singleSut);
globalSut = new DefaultGlobalStateProvider(storageServiceProvider, logService);
});
const fakeDiskStateDefinition = new StateDefinition("fake", "disk");
const fakeAlternateDiskStateDefinition = new StateDefinition("fakeAlternate", "disk");
const fakeMemoryStateDefinition = new StateDefinition("fake", "memory");
const makeKeyDefinition = (stateDefinition: StateDefinition, key: string) =>
new KeyDefinition<boolean>(stateDefinition, key, {
deserializer: (b) => b,
});
const makeUserKeyDefinition = (stateDefinition: StateDefinition, key: string) =>
new UserKeyDefinition<boolean>(stateDefinition, key, {
deserializer: (b) => b,
clearOn: [],
});
const keyDefinitions = {
disk: {
keyDefinition: makeKeyDefinition(fakeDiskStateDefinition, "fake"),
userKeyDefinition: makeUserKeyDefinition(fakeDiskStateDefinition, "fake"),
altKeyDefinition: makeKeyDefinition(fakeDiskStateDefinition, "fakeAlternate"),
altUserKeyDefinition: makeUserKeyDefinition(fakeDiskStateDefinition, "fakeAlternate"),
},
memory: {
keyDefinition: makeKeyDefinition(fakeMemoryStateDefinition, "fake"),
userKeyDefinition: makeUserKeyDefinition(fakeMemoryStateDefinition, "fake"),
},
alternateDisk: {
keyDefinition: makeKeyDefinition(fakeAlternateDiskStateDefinition, "fake"),
userKeyDefinition: makeUserKeyDefinition(fakeAlternateDiskStateDefinition, "fake"),
},
};
describe("active provider", () => {
it("returns a DefaultActiveUserState", () => {
const state = activeSut.get(keyDefinitions.disk.userKeyDefinition);
expect(state).toBeInstanceOf(DefaultActiveUserState);
});
it("returns different instances when the storage location differs", () => {
const stateDisk = activeSut.get(keyDefinitions.disk.userKeyDefinition);
const stateMemory = activeSut.get(keyDefinitions.memory.userKeyDefinition);
expect(stateDisk).not.toStrictEqual(stateMemory);
});
it("returns different instances when the state name differs", () => {
const state = activeSut.get(keyDefinitions.disk.userKeyDefinition);
const stateAlt = activeSut.get(keyDefinitions.alternateDisk.userKeyDefinition);
expect(state).not.toStrictEqual(stateAlt);
});
it("returns different instances when the key differs", () => {
const state = activeSut.get(keyDefinitions.disk.userKeyDefinition);
const stateAlt = activeSut.get(keyDefinitions.disk.altUserKeyDefinition);
expect(state).not.toStrictEqual(stateAlt);
});
});
describe("single provider", () => {
it("returns a DefaultSingleUserState", () => {
const state = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
expect(state).toBeInstanceOf(DefaultSingleUserState);
});
it("returns different instances when the storage location differs", () => {
const stateDisk = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
const stateMemory = singleSut.get(fakeUser1, keyDefinitions.memory.userKeyDefinition);
expect(stateDisk).not.toStrictEqual(stateMemory);
});
it("returns different instances when the state name differs", () => {
const state = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
const stateAlt = singleSut.get(fakeUser1, keyDefinitions.alternateDisk.userKeyDefinition);
expect(state).not.toStrictEqual(stateAlt);
});
it("returns different instances when the key differs", () => {
const state = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
const stateAlt = singleSut.get(fakeUser1, keyDefinitions.disk.altUserKeyDefinition);
expect(state).not.toStrictEqual(stateAlt);
});
const fakeUser2 = "00000000-0000-1000-a000-000000000002" as UserId;
it("returns different instances when the user id differs", () => {
const user1State = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
const user2State = singleSut.get(fakeUser2, keyDefinitions.disk.userKeyDefinition);
expect(user1State).not.toStrictEqual(user2State);
});
it("returns an instance with the userId property corresponding to the user id passed in", () => {
const userState = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
expect(userState.userId).toBe(fakeUser1);
});
it("returns cached instance on repeated request", () => {
const stateFirst = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
const stateCached = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
expect(stateFirst).toStrictEqual(stateCached);
});
});
describe("global provider", () => {
it("returns a DefaultGlobalState", () => {
const state = globalSut.get(keyDefinitions.disk.keyDefinition);
expect(state).toBeInstanceOf(DefaultGlobalState);
});
it("returns different instances when the storage location differs", () => {
const stateDisk = globalSut.get(keyDefinitions.disk.keyDefinition);
const stateMemory = globalSut.get(keyDefinitions.memory.keyDefinition);
expect(stateDisk).not.toStrictEqual(stateMemory);
});
it("returns different instances when the state name differs", () => {
const state = globalSut.get(keyDefinitions.disk.keyDefinition);
const stateAlt = globalSut.get(keyDefinitions.alternateDisk.keyDefinition);
expect(state).not.toStrictEqual(stateAlt);
});
it("returns different instances when the key differs", () => {
const state = globalSut.get(keyDefinitions.disk.keyDefinition);
const stateAlt = globalSut.get(keyDefinitions.disk.altKeyDefinition);
expect(state).not.toStrictEqual(stateAlt);
});
it("returns cached instance on repeated request", () => {
const stateFirst = globalSut.get(keyDefinitions.disk.keyDefinition);
const stateCached = globalSut.get(keyDefinitions.disk.keyDefinition);
expect(stateFirst).toStrictEqual(stateCached);
});
});
});

View File

@@ -1,137 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import {
defer,
filter,
firstValueFrom,
merge,
Observable,
ReplaySubject,
share,
switchMap,
tap,
timeout,
timer,
} from "rxjs";
import { Jsonify } from "type-fest";
import { AbstractStorageService, ObservableStorageService } from "@bitwarden/storage-core";
import { StorageKey } from "../../../types/state";
import { LogService } from "../../abstractions/log.service";
import { DebugOptions } from "../key-definition";
import { populateOptionsWithDefault, StateUpdateOptions } from "../state-update-options";
import { getStoredValue } from "./util";
// The parts of a KeyDefinition this class cares about to make it work
type KeyDefinitionRequirements<T> = {
deserializer: (jsonState: Jsonify<T>) => T | null;
cleanupDelayMs: number;
debug: Required<DebugOptions>;
};
export abstract class StateBase<T, KeyDef extends KeyDefinitionRequirements<T>> {
private updatePromise: Promise<T>;
readonly state$: Observable<T | null>;
constructor(
protected readonly key: StorageKey,
protected readonly storageService: AbstractStorageService & ObservableStorageService,
protected readonly keyDefinition: KeyDef,
protected readonly logService: LogService,
) {
const storageUpdate$ = storageService.updates$.pipe(
filter((storageUpdate) => storageUpdate.key === key),
switchMap(async (storageUpdate) => {
if (storageUpdate.updateType === "remove") {
return null;
}
return await getStoredValue(key, storageService, keyDefinition.deserializer);
}),
);
let state$ = merge(
defer(() => getStoredValue(key, storageService, keyDefinition.deserializer)),
storageUpdate$,
);
if (keyDefinition.debug.enableRetrievalLogging) {
state$ = state$.pipe(
tap({
next: (v) => {
this.logService.info(
`Retrieving '${key}' from storage, value is ${v == null ? "null" : "non-null"}`,
);
},
}),
);
}
// If 0 cleanup is chosen, treat this as absolutely no cache
if (keyDefinition.cleanupDelayMs !== 0) {
state$ = state$.pipe(
share({
connector: () => new ReplaySubject(1),
resetOnRefCountZero: () => timer(keyDefinition.cleanupDelayMs),
}),
);
}
this.state$ = state$;
}
async update<TCombine>(
configureState: (state: T | null, dependency: TCombine) => T | null,
options: StateUpdateOptions<T, TCombine> = {},
): Promise<T | null> {
options = populateOptionsWithDefault(options);
if (this.updatePromise != null) {
await this.updatePromise;
}
try {
this.updatePromise = this.internalUpdate(configureState, options);
return await this.updatePromise;
} finally {
this.updatePromise = null;
}
}
private async internalUpdate<TCombine>(
configureState: (state: T | null, dependency: TCombine) => T | null,
options: StateUpdateOptions<T, TCombine>,
): Promise<T | null> {
const currentState = await this.getStateForUpdate();
const combinedDependencies =
options.combineLatestWith != null
? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout)))
: null;
if (!options.shouldUpdate(currentState, combinedDependencies)) {
return currentState;
}
const newState = configureState(currentState, combinedDependencies);
await this.doStorageSave(newState, currentState);
return newState;
}
protected async doStorageSave(newState: T | null, oldState: T) {
if (this.keyDefinition.debug.enableUpdateLogging) {
this.logService.info(
`Updating '${this.key}' from ${oldState == null ? "null" : "non-null"} to ${newState == null ? "null" : "non-null"}`,
);
}
await this.storageService.save(this.key, newState);
}
/** For use in update methods, does not wait for update to complete before yielding state.
* The expectation is that that await is already done
*/
private async getStateForUpdate() {
return await getStoredValue(this.key, this.storageService, this.keyDefinition.deserializer);
}
}
export { StateBase } from "@bitwarden/state";

View File

@@ -1,56 +0,0 @@
import { FakeStorageService } from "../../../../spec/fake-storage.service";
import { getStoredValue } from "./util";
describe("getStoredValue", () => {
const key = "key";
const deserializedValue = { value: 1 };
const value = JSON.stringify(deserializedValue);
const deserializer = (v: string) => JSON.parse(v);
let storageService: FakeStorageService;
beforeEach(() => {
storageService = new FakeStorageService();
});
describe("when the storage service requires deserialization", () => {
beforeEach(() => {
storageService.internalUpdateValuesRequireDeserialization(true);
});
it("should deserialize", async () => {
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
storageService.save(key, value);
const result = await getStoredValue(key, storageService, deserializer);
expect(result).toEqual(deserializedValue);
});
});
describe("when the storage service does not require deserialization", () => {
beforeEach(() => {
storageService.internalUpdateValuesRequireDeserialization(false);
});
it("should not deserialize", async () => {
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
storageService.save(key, value);
const result = await getStoredValue(key, storageService, deserializer);
expect(result).toEqual(value);
});
it("should convert undefined to null", async () => {
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
// eslint-disable-next-line @typescript-eslint/no-floating-promises
storageService.save(key, undefined);
const result = await getStoredValue(key, storageService, deserializer);
expect(result).toEqual(null);
});
});
});

View File

@@ -1,17 +0,0 @@
import { Jsonify } from "type-fest";
import { AbstractStorageService } from "@bitwarden/storage-core";
export async function getStoredValue<T>(
key: string,
storage: AbstractStorageService,
deserializer: (jsonValue: Jsonify<T>) => T | null,
) {
if (storage.valuesRequireDeserialization) {
const jsonValue = await storage.get<Jsonify<T>>(key);
return deserializer(jsonValue);
} else {
const value = await storage.get<T>(key);
return value ?? null;
}
}

View File

@@ -1,14 +1 @@
export { DeriveDefinition } from "./derive-definition";
export { DerivedStateProvider } from "./derived-state.provider";
export { DerivedState } from "./derived-state";
export { GlobalState } from "./global-state";
export { StateProvider } from "./state.provider";
export { GlobalStateProvider } from "./global-state.provider";
export { ActiveUserState, SingleUserState, CombinedState } from "./user-state";
export { ActiveUserStateProvider, SingleUserStateProvider } from "./user-state.provider";
export { KeyDefinition, KeyDefinitionOptions } from "./key-definition";
export { StateUpdateOptions } from "./state-update-options";
export { UserKeyDefinitionOptions, UserKeyDefinition } from "./user-key-definition";
export { StateEventRunnerService } from "./state-event-runner.service";
export * from "./state-definitions";
export * from "@bitwarden/state";

View File

@@ -1,204 +0,0 @@
import { Opaque } from "type-fest";
import { DebugOptions, KeyDefinition } from "./key-definition";
import { StateDefinition } from "./state-definition";
const fakeStateDefinition = new StateDefinition("fake", "disk");
type FancyString = Opaque<string, "FancyString">;
describe("KeyDefinition", () => {
describe("constructor", () => {
it("throws on undefined deserializer", () => {
expect(() => {
new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
deserializer: undefined,
});
});
});
it("normalizes debug options set to undefined", () => {
const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", {
deserializer: (v) => v,
debug: undefined,
});
expect(keyDefinition.debug.enableUpdateLogging).toBe(false);
});
it("normalizes no debug options", () => {
const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", {
deserializer: (v) => v,
});
expect(keyDefinition.debug.enableUpdateLogging).toBe(false);
});
const cases: {
debug: DebugOptions | undefined;
expectedEnableUpdateLogging: boolean;
expectedEnableRetrievalLogging: boolean;
}[] = [
{
debug: undefined,
expectedEnableUpdateLogging: false,
expectedEnableRetrievalLogging: false,
},
{
debug: {},
expectedEnableUpdateLogging: false,
expectedEnableRetrievalLogging: false,
},
{
debug: {
enableUpdateLogging: false,
},
expectedEnableUpdateLogging: false,
expectedEnableRetrievalLogging: false,
},
{
debug: {
enableRetrievalLogging: false,
},
expectedEnableUpdateLogging: false,
expectedEnableRetrievalLogging: false,
},
{
debug: {
enableUpdateLogging: true,
},
expectedEnableUpdateLogging: true,
expectedEnableRetrievalLogging: false,
},
{
debug: {
enableRetrievalLogging: true,
},
expectedEnableUpdateLogging: false,
expectedEnableRetrievalLogging: true,
},
{
debug: {
enableRetrievalLogging: false,
enableUpdateLogging: false,
},
expectedEnableUpdateLogging: false,
expectedEnableRetrievalLogging: false,
},
{
debug: {
enableRetrievalLogging: true,
enableUpdateLogging: true,
},
expectedEnableUpdateLogging: true,
expectedEnableRetrievalLogging: true,
},
];
it.each(cases)(
"normalizes debug options to correct values when given $debug",
({ debug, expectedEnableUpdateLogging, expectedEnableRetrievalLogging }) => {
const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", {
deserializer: (v) => v,
debug: debug,
});
expect(keyDefinition.debug.enableUpdateLogging).toBe(expectedEnableUpdateLogging);
expect(keyDefinition.debug.enableRetrievalLogging).toBe(expectedEnableRetrievalLogging);
},
);
});
describe("cleanupDelayMs", () => {
it("defaults to 1000ms", () => {
const keyDefinition = new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
deserializer: (value) => value,
});
expect(keyDefinition).toBeTruthy();
expect(keyDefinition.cleanupDelayMs).toBe(1000);
});
it("can be overridden", () => {
const keyDefinition = new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
deserializer: (value) => value,
cleanupDelayMs: 500,
});
expect(keyDefinition).toBeTruthy();
expect(keyDefinition.cleanupDelayMs).toBe(500);
});
it("throws on negative", () => {
expect(
() =>
new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
deserializer: (value) => value,
cleanupDelayMs: -1,
}),
).toThrow();
});
});
describe("record", () => {
it("runs custom deserializer for each record value", () => {
const recordDefinition = KeyDefinition.record<boolean>(fakeStateDefinition, "fake", {
// Intentionally negate the value for testing
deserializer: (value) => !value,
});
expect(recordDefinition).toBeTruthy();
expect(recordDefinition.deserializer).toBeTruthy();
const deserializedValue = recordDefinition.deserializer({
test1: false,
test2: true,
});
expect(Object.keys(deserializedValue)).toHaveLength(2);
// Values should have swapped from their initial value
expect(deserializedValue["test1"]).toBeTruthy();
expect(deserializedValue["test2"]).toBeFalsy();
});
it("can handle fancy string type", () => {
// This test is more of a test that I got the typescript typing correctly than actually testing any business logic
const recordDefinition = KeyDefinition.record<boolean, FancyString>(
fakeStateDefinition,
"fake",
{
deserializer: (value) => !value,
},
);
const fancyRecord = recordDefinition.deserializer(
JSON.parse(`{ "myKey": false, "mySecondKey": true }`),
);
expect(fancyRecord).toBeTruthy();
expect(Object.keys(fancyRecord)).toHaveLength(2);
expect(fancyRecord["myKey" as FancyString]).toBeTruthy();
expect(fancyRecord["mySecondKey" as FancyString]).toBeFalsy();
});
});
describe("array", () => {
it("run custom deserializer for each array element", () => {
const arrayDefinition = KeyDefinition.array<boolean>(fakeStateDefinition, "fake", {
deserializer: (value) => !value,
});
expect(arrayDefinition).toBeTruthy();
expect(arrayDefinition.deserializer).toBeTruthy();
// NOTE: `as any` is here until we migrate to Nx: https://bitwarden.atlassian.net/browse/PM-6493
const deserializedValue = arrayDefinition.deserializer([false, true] as any);
expect(deserializedValue).toBeTruthy();
expect(deserializedValue).toHaveLength(2);
expect(deserializedValue[0]).toBeTruthy();
expect(deserializedValue[1]).toBeFalsy();
});
});
});

View File

@@ -1,182 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Jsonify } from "type-fest";
import { StorageKey } from "../../types/state";
import { array, record } from "./deserialization-helpers";
import { StateDefinition } from "./state-definition";
export type DebugOptions = {
/**
* When true, logs will be written that look like the following:
*
* ```
* "Updating 'global_myState_myKey' from null to non-null"
* "Updating 'user_32265eda-62ff-4797-9ead-22214772f888_myState_myKey' from non-null to null."
* ```
*
* It does not include the value of the data, only whether it is null or non-null.
*/
enableUpdateLogging?: boolean;
/**
* When true, logs will be written that look like the following everytime a value is retrieved from storage.
*
* "Retrieving 'global_myState_myKey' from storage, value is null."
* "Retrieving 'user_32265eda-62ff-4797-9ead-22214772f888_myState_myKey' from storage, value is non-null."
*/
enableRetrievalLogging?: boolean;
};
/**
* A set of options for customizing the behavior of a {@link KeyDefinition}
*/
export type KeyDefinitionOptions<T> = {
/**
* A function to use to safely convert your type from json to your expected type.
*
* **Important:** Your data may be serialized/deserialized at any time and this
* callback needs to be able to faithfully re-initialize from the JSON object representation of your type.
*
* @param jsonValue The JSON object representation of your state.
* @returns The fully typed version of your state.
*/
readonly deserializer: (jsonValue: Jsonify<T>) => T | null;
/**
* The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
* Defaults to 1000ms.
*/
readonly cleanupDelayMs?: number;
/**
* Options for configuring the debugging behavior, see individual options for more info.
*/
readonly debug?: DebugOptions;
};
/**
* KeyDefinitions describe the precise location to store data for a given piece of state.
* The StateDefinition is used to describe the domain of the state, and the KeyDefinition
* sub-divides that domain into specific keys.
*/
export class KeyDefinition<T> {
readonly debug: Required<DebugOptions>;
/**
* Creates a new instance of a KeyDefinition
* @param stateDefinition The state definition for which this key belongs to.
* @param key The name of the key, this should be unique per domain.
* @param options A set of options to customize the behavior of {@link KeyDefinition}. All options are required.
* @param options.deserializer A function to use to safely convert your type from json to your expected type.
* Your data may be serialized/deserialized at any time and this needs callback needs to be able to faithfully re-initialize
* from the JSON object representation of your type.
*/
constructor(
readonly stateDefinition: StateDefinition,
readonly key: string,
private readonly options: KeyDefinitionOptions<T>,
) {
if (options.deserializer == null) {
throw new Error(`'deserializer' is a required property on key ${this.errorKeyName}`);
}
if (options.cleanupDelayMs < 0) {
throw new Error(
`'cleanupDelayMs' must be greater than or equal to 0. Value of ${options.cleanupDelayMs} passed to key ${this.errorKeyName} `,
);
}
// Normalize optional debug options
const { enableUpdateLogging = false, enableRetrievalLogging = false } = options.debug ?? {};
this.debug = {
enableUpdateLogging,
enableRetrievalLogging,
};
}
/**
* Gets the deserializer configured for this {@link KeyDefinition}
*/
get deserializer() {
return this.options.deserializer;
}
/**
* Gets the number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
*/
get cleanupDelayMs() {
return this.options.cleanupDelayMs < 0 ? 0 : (this.options.cleanupDelayMs ?? 1000);
}
/**
* Creates a {@link KeyDefinition} for state that is an array.
* @param stateDefinition The state definition to be added to the KeyDefinition
* @param key The key to be added to the KeyDefinition
* @param options The options to customize the final {@link KeyDefinition}.
* @returns A {@link KeyDefinition} initialized for arrays, the options run
* the deserializer on the provided options for each element of an array.
*
* @example
* ```typescript
* const MY_KEY = KeyDefinition.array<MyArrayElement>(MY_STATE, "key", {
* deserializer: (myJsonElement) => convertToElement(myJsonElement),
* });
* ```
*/
static array<T>(
stateDefinition: StateDefinition,
key: string,
// We have them provide options for the element of the array, depending on future options we add, this could get a little weird.
options: KeyDefinitionOptions<T>, // The array helper forces an initialValue of an empty array
) {
return new KeyDefinition<T[]>(stateDefinition, key, {
...options,
deserializer: array((e) => options.deserializer(e)),
});
}
/**
* Creates a {@link KeyDefinition} for state that is a record.
* @param stateDefinition The state definition to be added to the KeyDefinition
* @param key The key to be added to the KeyDefinition
* @param options The options to customize the final {@link KeyDefinition}.
* @returns A {@link KeyDefinition} that contains a serializer that will run the provided deserializer for each
* value in a record and returns every key as a string.
*
* @example
* ```typescript
* const MY_KEY = KeyDefinition.record<MyRecordValue>(MY_STATE, "key", {
* deserializer: (myJsonValue) => convertToValue(myJsonValue),
* });
* ```
*/
static record<T, TKey extends string | number = string>(
stateDefinition: StateDefinition,
key: string,
// We have them provide options for the value of the record, depending on future options we add, this could get a little weird.
options: KeyDefinitionOptions<T>, // The array helper forces an initialValue of an empty record
) {
return new KeyDefinition<Record<TKey, T>>(stateDefinition, key, {
...options,
deserializer: record((v) => options.deserializer(v)),
});
}
get fullName() {
return `${this.stateDefinition.name}_${this.key}`;
}
protected get errorKeyName() {
return `${this.stateDefinition.name} > ${this.key}`;
}
}
/**
* Creates a {@link StorageKey}
* @param keyDefinition The key definition of which data the key should point to.
* @returns A key that is ready to be used in a storage service to get data.
*/
export function globalKeyBuilder(keyDefinition: KeyDefinition<unknown>): StorageKey {
return `global_${keyDefinition.stateDefinition.name}_${keyDefinition.key}` as StorageKey;
}
export { KeyDefinition, KeyDefinitionOptions } from "@bitwarden/state";

View File

@@ -1,24 +1,4 @@
import { StorageLocation, ClientLocations } from "@bitwarden/storage-core";
export { StateDefinition } from "@bitwarden/state";
// To be removed once references are updated to point to @bitwarden/storage-core
export { StorageLocation, ClientLocations };
/**
* Defines the base location and instruction of where this state is expected to be located.
*/
export class StateDefinition {
readonly storageLocationOverrides: Partial<ClientLocations>;
/**
* Creates a new instance of {@link StateDefinition}, the creation of which is owned by the platform team.
* @param name The name of the state, this needs to be unique from all other {@link StateDefinition}'s.
* @param defaultStorageLocation The location of where this state should be stored.
*/
constructor(
readonly name: string,
readonly defaultStorageLocation: StorageLocation,
storageLocationOverrides?: Partial<ClientLocations>,
) {
this.storageLocationOverrides = storageLocationOverrides ?? {};
}
}
export { StorageLocation, ClientLocations } from "@bitwarden/storage-core";

View File

@@ -1,60 +0,0 @@
import { ClientLocations, StateDefinition } from "./state-definition";
import * as stateDefinitionsRecord from "./state-definitions";
describe.each(["web", "cli", "desktop", "browser"])(
"state definitions follow rules for client %s",
(clientType: keyof ClientLocations) => {
const trackedNames: [string, string][] = [];
test.each(Object.entries(stateDefinitionsRecord))(
"that export %s follows all rules",
(exportName, stateDefinition) => {
// All exports from state-definitions are expected to be StateDefinition's
if (!(stateDefinition instanceof StateDefinition)) {
throw new Error(`export ${exportName} is expected to be a StateDefinition`);
}
const storageLocation =
stateDefinition.storageLocationOverrides[clientType] ??
stateDefinition.defaultStorageLocation;
const fullName = `${stateDefinition.name}_${storageLocation}`;
const exactConflictingExport = trackedNames.find(
([_, trackedName]) => trackedName === fullName,
);
if (exactConflictingExport !== undefined) {
const [conflictingExportName] = exactConflictingExport;
throw new Error(
`The export '${exportName}' has a conflicting state name and storage location with export ` +
`'${conflictingExportName}' please ensure that you choose a unique name and location for all clients.`,
);
}
const roughConflictingExport = trackedNames.find(
([_, trackedName]) => trackedName.toLowerCase() === fullName.toLowerCase(),
);
if (roughConflictingExport !== undefined) {
const [conflictingExportName] = roughConflictingExport;
throw new Error(
`The export '${exportName}' differs its state name and storage location ` +
`only by casing with export '${conflictingExportName}' please ensure it differs by more than casing.`,
);
}
const name = stateDefinition.name;
expect(name).not.toBeUndefined(); // undefined in an invalid name
expect(name).not.toBeNull(); // null is in invalid name
expect(name.length).toBeGreaterThan(3); // A 3 characters or less name is not descriptive enough
expect(name[0]).toEqual(name[0].toLowerCase()); // First character should be lower case since camelCase is required
expect(name).not.toContain(" "); // There should be no spaces in a state name
expect(name).not.toContain("_"); // We should not be doing snake_case for state name
// NOTE: We could expect some details about the export name as well
trackedNames.push([exportName, fullName]);
},
);
},
);

View File

@@ -1,215 +0,0 @@
import { StateDefinition } from "./state-definition";
/**
* `StateDefinition`s comes with some rules, to facilitate a quick review from
* platform of this file, ensure you follow these rules, the ones marked with (tested)
* have unit tests that you can run locally.
*
* 1. (tested) Names should not be null or undefined
* 2. (tested) Name and storage location should be unique
* 3. (tested) Name and storage location can't differ from another export by only casing
* 4. (tested) Name should be longer than 3 characters. It should be descriptive, but brief.
* 5. (tested) Name should not contain spaces or underscores
* 6. Name should be human readable
* 7. Name should be in camelCase format (unit tests ensure the first character is lowercase)
* 8. Teams should only use state definitions they have created
* 9. StateDefinitions should only be used for keys relating to the state name they chose
*
*/
// Admin Console
export const ORGANIZATIONS_DISK = new StateDefinition("organizations", "disk");
export const POLICIES_DISK = new StateDefinition("policies", "disk");
export const PROVIDERS_DISK = new StateDefinition("providers", "disk");
export const ORGANIZATION_MANAGEMENT_PREFERENCES_DISK = new StateDefinition(
"organizationManagementPreferences",
"disk",
{
web: "disk-local",
},
);
export const DELETE_MANAGED_USER_WARNING = new StateDefinition(
"showDeleteManagedUserWarning",
"disk",
{
web: "disk-local",
},
);
// Billing
export const BILLING_DISK = new StateDefinition("billing", "disk");
// Auth
export const ACCOUNT_DISK = new StateDefinition("account", "disk");
export const ACCOUNT_MEMORY = new StateDefinition("account", "memory");
export const AUTH_REQUEST_DISK_LOCAL = new StateDefinition("authRequestLocal", "disk", {
web: "disk-local",
});
export const AVATAR_DISK = new StateDefinition("avatar", "disk", { web: "disk-local" });
export const DEVICE_TRUST_DISK_LOCAL = new StateDefinition("deviceTrust", "disk", {
web: "disk-local",
browser: "disk-backup-local-storage",
});
export const KDF_CONFIG_DISK = new StateDefinition("kdfConfig", "disk");
export const KEY_CONNECTOR_DISK = new StateDefinition("keyConnector", "disk");
export const LOGIN_EMAIL_DISK = new StateDefinition("loginEmail", "disk", {
web: "disk-local",
});
export const LOGIN_EMAIL_MEMORY = new StateDefinition("loginEmail", "memory");
export const LOGIN_STRATEGY_MEMORY = new StateDefinition("loginStrategy", "memory");
export const MASTER_PASSWORD_DISK = new StateDefinition("masterPassword", "disk");
export const MASTER_PASSWORD_MEMORY = new StateDefinition("masterPassword", "memory");
export const PIN_DISK = new StateDefinition("pinUnlock", "disk");
export const PIN_MEMORY = new StateDefinition("pinUnlock", "memory");
export const ROUTER_DISK = new StateDefinition("router", "disk");
export const SSO_DISK = new StateDefinition("ssoLogin", "disk");
export const TOKEN_DISK = new StateDefinition("token", "disk");
export const TOKEN_DISK_LOCAL = new StateDefinition("tokenDiskLocal", "disk", {
web: "disk-local",
});
export const TOKEN_MEMORY = new StateDefinition("token", "memory");
export const TWO_FACTOR_MEMORY = new StateDefinition("twoFactor", "memory");
export const USER_DECRYPTION_OPTIONS_DISK = new StateDefinition("userDecryptionOptions", "disk");
export const ORGANIZATION_INVITE_DISK = new StateDefinition("organizationInvite", "disk");
export const VAULT_TIMEOUT_SETTINGS_DISK_LOCAL = new StateDefinition(
"vaultTimeoutSettings",
"disk",
{
web: "disk-local",
},
);
// Autofill
export const BADGE_SETTINGS_DISK = new StateDefinition("badgeSettings", "disk");
export const USER_NOTIFICATION_SETTINGS_DISK = new StateDefinition(
"userNotificationSettings",
"disk",
);
export const DOMAIN_SETTINGS_DISK = new StateDefinition("domainSettings", "disk");
export const AUTOFILL_SETTINGS_DISK = new StateDefinition("autofillSettings", "disk");
export const AUTOFILL_SETTINGS_DISK_LOCAL = new StateDefinition("autofillSettingsLocal", "disk", {
web: "disk-local",
});
export const AUTOTYPE_SETTINGS_DISK = new StateDefinition("autotypeSettings", "disk");
// Components
export const NEW_WEB_LAYOUT_BANNER_DISK = new StateDefinition("newWebLayoutBanner", "disk", {
web: "disk-local",
});
// Platform
export const APPLICATION_ID_DISK = new StateDefinition("applicationId", "disk", {
web: "disk-local",
});
export const BADGE_MEMORY = new StateDefinition("badge", "memory", {
browser: "memory-large-object",
});
export const BIOMETRIC_SETTINGS_DISK = new StateDefinition("biometricSettings", "disk");
export const CLEAR_EVENT_DISK = new StateDefinition("clearEvent", "disk");
export const CONFIG_DISK = new StateDefinition("config", "disk", {
web: "disk-local",
});
export const CRYPTO_DISK = new StateDefinition("crypto", "disk");
export const CRYPTO_MEMORY = new StateDefinition("crypto", "memory");
export const DESKTOP_SETTINGS_DISK = new StateDefinition("desktopSettings", "disk");
export const ENVIRONMENT_DISK = new StateDefinition("environment", "disk");
export const ENVIRONMENT_MEMORY = new StateDefinition("environment", "memory");
export const POPUP_VIEW_MEMORY = new StateDefinition("popupView", "memory", {
browser: "memory-large-object",
});
export const SYNC_DISK = new StateDefinition("sync", "disk", { web: "memory" });
export const THEMING_DISK = new StateDefinition("theming", "disk", { web: "disk-local" });
export const TRANSLATION_DISK = new StateDefinition("translation", "disk", { web: "disk-local" });
export const ANIMATION_DISK = new StateDefinition("animation", "disk");
export const TASK_SCHEDULER_DISK = new StateDefinition("taskScheduler", "disk");
export const EXTENSION_INITIAL_INSTALL_DISK = new StateDefinition(
"extensionInitialInstall",
"disk",
);
export const WEB_PUSH_SUBSCRIPTION = new StateDefinition("webPushSubscription", "disk", {
web: "disk-local",
});
// Design System
export const POPUP_STYLE_DISK = new StateDefinition("popupStyle", "disk");
// Secrets Manager
export const SM_ONBOARDING_DISK = new StateDefinition("smOnboarding", "disk", {
web: "disk-local",
});
// Tools
export const EXTENSION_DISK = new StateDefinition("extension", "disk");
export const GENERATOR_DISK = new StateDefinition("generator", "disk");
export const GENERATOR_MEMORY = new StateDefinition("generator", "memory");
export const BROWSER_SEND_MEMORY = new StateDefinition("sendBrowser", "memory");
export const EVENT_COLLECTION_DISK = new StateDefinition("eventCollection", "disk");
export const SEND_DISK = new StateDefinition("encryptedSend", "disk", {
web: "memory",
});
export const SEND_MEMORY = new StateDefinition("decryptedSend", "memory", {
browser: "memory-large-object",
});
export const SEND_ACCESS_AUTH_MEMORY = new StateDefinition("sendAccessAuth", "memory");
// Vault
export const COLLECTION_DATA = new StateDefinition("collection", "disk", {
web: "memory",
});
export const FOLDER_DISK = new StateDefinition("folder", "disk", { web: "memory" });
export const FOLDER_MEMORY = new StateDefinition("decryptedFolders", "memory", {
browser: "memory-large-object",
});
export const VAULT_FILTER_DISK = new StateDefinition("vaultFilter", "disk", {
web: "disk-local",
});
export const VAULT_ONBOARDING = new StateDefinition("vaultOnboarding", "disk", {
web: "disk-local",
});
export const VAULT_SETTINGS_DISK = new StateDefinition("vaultSettings", "disk", {
web: "disk-local",
});
export const VAULT_BROWSER_MEMORY = new StateDefinition("vaultBrowser", "memory", {
browser: "memory-large-object",
});
export const VAULT_SEARCH_MEMORY = new StateDefinition("vaultSearch", "memory", {
browser: "memory-large-object",
});
export const CIPHERS_DISK = new StateDefinition("ciphers", "disk", { web: "memory" });
export const CIPHERS_DISK_LOCAL = new StateDefinition("ciphersLocal", "disk", {
web: "disk-local",
});
export const CIPHERS_MEMORY = new StateDefinition("ciphersMemory", "memory", {
browser: "memory-large-object",
});
export const PREMIUM_BANNER_DISK_LOCAL = new StateDefinition("premiumBannerReprompt", "disk", {
web: "disk-local",
});
export const BANNERS_DISMISSED_DISK = new StateDefinition("bannersDismissed", "disk");
export const VAULT_APPEARANCE = new StateDefinition("vaultAppearance", "disk");
export const SECURITY_TASKS_DISK = new StateDefinition("securityTasks", "disk");
export const AT_RISK_PASSWORDS_PAGE_DISK = new StateDefinition("atRiskPasswordsPage", "disk");
export const NOTIFICATION_DISK = new StateDefinition("notifications", "disk");
export const NUDGES_DISK = new StateDefinition("nudges", "disk", { web: "disk-local" });
export const SETUP_EXTENSION_DISMISSED_DISK = new StateDefinition(
"setupExtensionDismissed",
"disk",
{
web: "disk-local",
},
);
export const VAULT_BROWSER_INTRO_CAROUSEL = new StateDefinition(
"vaultBrowserIntroCarousel",
"disk",
);

View File

@@ -1,89 +0,0 @@
import { mock } from "jest-mock-extended";
import {
AbstractStorageService,
ObservableStorageService,
StorageServiceProvider,
} from "@bitwarden/storage-core";
import { FakeGlobalStateProvider } from "../../../spec";
import { StateDefinition } from "./state-definition";
import { STATE_LOCK_EVENT, StateEventRegistrarService } from "./state-event-registrar.service";
import { UserKeyDefinition } from "./user-key-definition";
describe("StateEventRegistrarService", () => {
const globalStateProvider = new FakeGlobalStateProvider();
const lockState = globalStateProvider.getFake(STATE_LOCK_EVENT);
const storageServiceProvider = mock<StorageServiceProvider>();
const sut = new StateEventRegistrarService(globalStateProvider, storageServiceProvider);
describe("registerEvents", () => {
const fakeKeyDefinition = new UserKeyDefinition<boolean>(
new StateDefinition("fakeState", "disk"),
"fakeKey",
{
deserializer: (s) => s,
clearOn: ["lock"],
},
);
beforeEach(() => {
jest.resetAllMocks();
});
it("adds event on null storage", async () => {
storageServiceProvider.get.mockReturnValue([
"disk",
mock<AbstractStorageService & ObservableStorageService>(),
]);
await sut.registerEvents(fakeKeyDefinition);
expect(lockState.nextMock).toHaveBeenCalledWith([
{
key: "fakeKey",
location: "disk",
state: "fakeState",
},
]);
});
it("adds event on empty array in storage", async () => {
lockState.stateSubject.next([]);
storageServiceProvider.get.mockReturnValue([
"disk",
mock<AbstractStorageService & ObservableStorageService>(),
]);
await sut.registerEvents(fakeKeyDefinition);
expect(lockState.nextMock).toHaveBeenCalledWith([
{
key: "fakeKey",
location: "disk",
state: "fakeState",
},
]);
});
it("doesn't add a duplicate", async () => {
lockState.stateSubject.next([
{
key: "fakeKey",
location: "disk",
state: "fakeState",
},
]);
storageServiceProvider.get.mockReturnValue([
"disk",
mock<AbstractStorageService & ObservableStorageService>(),
]);
await sut.registerEvents(fakeKeyDefinition);
expect(lockState.nextMock).not.toHaveBeenCalled();
});
});
});

View File

@@ -1,76 +1,6 @@
import { PossibleLocation, StorageServiceProvider } from "../services/storage-service.provider";
import { GlobalState } from "./global-state";
import { GlobalStateProvider } from "./global-state.provider";
import { KeyDefinition } from "./key-definition";
import { CLEAR_EVENT_DISK } from "./state-definitions";
import { ClearEvent, UserKeyDefinition } from "./user-key-definition";
export type StateEventInfo = {
state: string;
key: string;
location: PossibleLocation;
};
export const STATE_LOCK_EVENT = KeyDefinition.array<StateEventInfo>(CLEAR_EVENT_DISK, "lock", {
deserializer: (e) => e,
});
export const STATE_LOGOUT_EVENT = KeyDefinition.array<StateEventInfo>(CLEAR_EVENT_DISK, "logout", {
deserializer: (e) => e,
});
export class StateEventRegistrarService {
private readonly stateEventStateMap: { [Prop in ClearEvent]: GlobalState<StateEventInfo[]> };
constructor(
globalStateProvider: GlobalStateProvider,
private storageServiceProvider: StorageServiceProvider,
) {
this.stateEventStateMap = {
lock: globalStateProvider.get(STATE_LOCK_EVENT),
logout: globalStateProvider.get(STATE_LOGOUT_EVENT),
};
}
async registerEvents(keyDefinition: UserKeyDefinition<unknown>) {
for (const clearEvent of keyDefinition.clearOn) {
const eventState = this.stateEventStateMap[clearEvent];
// Determine the storage location for this
const [storageLocation] = this.storageServiceProvider.get(
keyDefinition.stateDefinition.defaultStorageLocation,
keyDefinition.stateDefinition.storageLocationOverrides,
);
const newEvent: StateEventInfo = {
state: keyDefinition.stateDefinition.name,
key: keyDefinition.key,
location: storageLocation,
};
// Only update the event state if the existing list doesn't have a matching entry
await eventState.update(
(existingTickets) => {
existingTickets ??= [];
existingTickets.push(newEvent);
return existingTickets;
},
{
shouldUpdate: (currentTickets) => {
return (
// If the current tickets are null, then it will for sure be added
currentTickets == null ||
// If an existing match couldn't be found, we also need to add one
currentTickets.findIndex(
(e) =>
e.state === newEvent.state &&
e.key === newEvent.key &&
e.location === newEvent.location,
) === -1
);
},
},
);
}
}
}
export {
StateEventRegistrarService,
StateEventInfo,
STATE_LOCK_EVENT,
STATE_LOGOUT_EVENT,
} from "@bitwarden/state";

View File

@@ -1,73 +0,0 @@
import { mock } from "jest-mock-extended";
import {
AbstractStorageService,
ObservableStorageService,
StorageServiceProvider,
} from "@bitwarden/storage-core";
import { FakeGlobalStateProvider } from "../../../spec";
import { UserId } from "../../types/guid";
import { STATE_LOCK_EVENT } from "./state-event-registrar.service";
import { StateEventRunnerService } from "./state-event-runner.service";
describe("EventRunnerService", () => {
const fakeGlobalStateProvider = new FakeGlobalStateProvider();
const lockState = fakeGlobalStateProvider.getFake(STATE_LOCK_EVENT);
const storageServiceProvider = mock<StorageServiceProvider>();
const sut = new StateEventRunnerService(fakeGlobalStateProvider, storageServiceProvider);
describe("handleEvent", () => {
it("does nothing if there are no events in state", async () => {
const mockStorageService = mock<AbstractStorageService & ObservableStorageService>();
storageServiceProvider.get.mockReturnValue(["disk", mockStorageService]);
await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId);
expect(lockState.nextMock).not.toHaveBeenCalled();
});
it("loops through and acts on all events", async () => {
const mockDiskStorageService = mock<AbstractStorageService & ObservableStorageService>();
const mockMemoryStorageService = mock<AbstractStorageService & ObservableStorageService>();
lockState.stateSubject.next([
{
state: "fakeState1",
key: "fakeKey1",
location: "disk",
},
{
state: "fakeState2",
key: "fakeKey2",
location: "memory",
},
]);
storageServiceProvider.get.mockImplementation((defaultLocation, overrides) => {
if (defaultLocation === "disk") {
return [defaultLocation, mockDiskStorageService];
} else if (defaultLocation === "memory") {
return [defaultLocation, mockMemoryStorageService];
}
});
mockMemoryStorageService.get.mockResolvedValue("something");
await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId);
expect(mockDiskStorageService.get).toHaveBeenCalledTimes(1);
expect(mockDiskStorageService.get).toHaveBeenCalledWith(
"user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState1_fakeKey1",
);
expect(mockMemoryStorageService.get).toHaveBeenCalledTimes(1);
expect(mockMemoryStorageService.get).toHaveBeenCalledWith(
"user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState2_fakeKey2",
);
expect(mockMemoryStorageService.remove).toHaveBeenCalledTimes(1);
});
});
});

View File

@@ -1,83 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { firstValueFrom } from "rxjs";
import { StorageServiceProvider } from "@bitwarden/storage-core";
import { UserId } from "../../types/guid";
import { GlobalState } from "./global-state";
import { GlobalStateProvider } from "./global-state.provider";
import { StateDefinition, StorageLocation } from "./state-definition";
import {
STATE_LOCK_EVENT,
STATE_LOGOUT_EVENT,
StateEventInfo,
} from "./state-event-registrar.service";
import { ClearEvent, UserKeyDefinition } from "./user-key-definition";
export class StateEventRunnerService {
private readonly stateEventMap: { [Prop in ClearEvent]: GlobalState<StateEventInfo[]> };
constructor(
globalStateProvider: GlobalStateProvider,
private storageServiceProvider: StorageServiceProvider,
) {
this.stateEventMap = {
lock: globalStateProvider.get(STATE_LOCK_EVENT),
logout: globalStateProvider.get(STATE_LOGOUT_EVENT),
};
}
async handleEvent(event: ClearEvent, userId: UserId) {
let tickets = await firstValueFrom(this.stateEventMap[event].state$);
tickets ??= [];
const failures: string[] = [];
for (const ticket of tickets) {
try {
const [, service] = this.storageServiceProvider.get(
ticket.location,
{}, // The storage location is already the computed storage location for this client
);
const ticketStorageKey = this.storageKeyFor(userId, ticket);
// Evaluate current value so we can avoid writing to state if we don't need to
const currentValue = await service.get(ticketStorageKey);
if (currentValue != null) {
await service.remove(ticketStorageKey);
}
} catch (err: unknown) {
let errorMessage = "Unknown Error";
if (typeof err === "object" && "message" in err && typeof err.message === "string") {
errorMessage = err.message;
}
failures.push(
`${errorMessage} in ${ticket.state} > ${ticket.key} located ${ticket.location}`,
);
}
}
if (failures.length > 0) {
// Throw aggregated error
throw new Error(
`One or more errors occurred while handling event '${event}' for user ${userId}.\n${failures.join("\n")}`,
);
}
}
private storageKeyFor(userId: UserId, ticket: StateEventInfo) {
const userKey = new UserKeyDefinition<unknown>(
new StateDefinition(ticket.state, ticket.location as unknown as StorageLocation),
ticket.key,
{
deserializer: (v) => v,
clearOn: [],
},
);
return userKey.buildKey(userId);
}
}
export { StateEventRunnerService } from "@bitwarden/state";

View File

@@ -1,28 +0,0 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { Observable } from "rxjs";
export const DEFAULT_OPTIONS = {
shouldUpdate: () => true,
combineLatestWith: null as Observable<unknown>,
msTimeout: 1000,
};
type DefinitelyTypedDefault<T, TCombine> = Omit<
typeof DEFAULT_OPTIONS,
"shouldUpdate" | "combineLatestWith"
> & {
shouldUpdate: (state: T, dependency: TCombine) => boolean;
combineLatestWith?: Observable<TCombine>;
};
export type StateUpdateOptions<T, TCombine> = Partial<DefinitelyTypedDefault<T, TCombine>>;
export function populateOptionsWithDefault<T, TCombine>(
options: StateUpdateOptions<T, TCombine>,
): StateUpdateOptions<T, TCombine> {
return {
...(DEFAULT_OPTIONS as StateUpdateOptions<T, TCombine>),
...options,
};
}

View File

@@ -1,80 +1 @@
import { Observable } from "rxjs";
import { UserId } from "../../types/guid";
import { DerivedStateDependencies } from "../../types/state";
import { DeriveDefinition } from "./derive-definition";
import { DerivedState } from "./derived-state";
import { GlobalState } from "./global-state";
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- used in docs
import { GlobalStateProvider } from "./global-state.provider";
import { KeyDefinition } from "./key-definition";
import { UserKeyDefinition } from "./user-key-definition";
import { ActiveUserState, SingleUserState } from "./user-state";
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- used in docs
import { ActiveUserStateProvider, SingleUserStateProvider } from "./user-state.provider";
/** Convenience wrapper class for {@link ActiveUserStateProvider}, {@link SingleUserStateProvider},
* and {@link GlobalStateProvider}.
*/
export abstract class StateProvider {
/** @see{@link ActiveUserStateProvider.activeUserId$} */
abstract activeUserId$: Observable<UserId | undefined>;
/**
* Gets a state observable for a given key and userId.
*
* @remarks If userId is falsy the observable returned will attempt to point to the currently active user _and not update if the active user changes_.
* This is different to how `getActive` works and more similar to `getUser` for whatever user happens to be active at the time of the call.
* If no user happens to be active at the time this method is called with a falsy userId then this observable will not emit a value until
* a user becomes active. If you are not confident a user is active at the time this method is called, you may want to pipe a call to `timeout`
* or instead call {@link getUserStateOrDefault$} and supply a value you would rather have given in the case of no passed in userId and no active user.
*
* @param keyDefinition - The key definition for the state you want to get.
* @param userId - The userId for which you want the state for. If not provided, the state for the currently active user will be returned.
*/
abstract getUserState$<T>(keyDefinition: UserKeyDefinition<T>, userId?: UserId): Observable<T>;
/**
* Gets a state observable for a given key and userId
*
* @remarks If userId is falsy the observable return will first attempt to point to the currently active user but will not follow subsequent active user changes,
* if there is no immediately available active user, then it will fallback to returning a default value in an observable that immediately completes.
*
* @param keyDefinition - The key definition for the state you want to get.
* @param config.userId - The userId for which you want the state for. If not provided, the state for the currently active user will be returned.
* @param config.defaultValue - The default value that should be wrapped in an observable if no active user is immediately available and no truthy userId is passed in.
*/
abstract getUserStateOrDefault$<T>(
keyDefinition: UserKeyDefinition<T>,
config: { userId: UserId | undefined; defaultValue?: T },
): Observable<T>;
/**
* Sets the state for a given key and userId.
*
* @overload
* @param keyDefinition - The key definition for the state you want to set.
* @param value - The value to set the state to.
* @param userId - The userId for which you want to set the state for. If not provided, the state for the currently active user will be set.
*/
abstract setUserState<T>(
keyDefinition: UserKeyDefinition<T>,
value: T | null,
userId?: UserId,
): Promise<[UserId, T | null]>;
/** @see{@link ActiveUserStateProvider.get} */
abstract getActive<T>(userKeyDefinition: UserKeyDefinition<T>): ActiveUserState<T>;
/** @see{@link SingleUserStateProvider.get} */
abstract getUser<T>(userId: UserId, userKeyDefinition: UserKeyDefinition<T>): SingleUserState<T>;
/** @see{@link GlobalStateProvider.get} */
abstract getGlobal<T>(keyDefinition: KeyDefinition<T>): GlobalState<T>;
abstract getDerived<TFrom, TTo, TDeps extends DerivedStateDependencies>(
parentState$: Observable<TFrom>,
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
dependencies: TDeps,
): DerivedState<TTo>;
}
export { StateProvider } from "@bitwarden/state";

View File

@@ -1,142 +1 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { UserId } from "../../types/guid";
import { StorageKey } from "../../types/state";
import { Utils } from "../misc/utils";
import { array, record } from "./deserialization-helpers";
import { DebugOptions, KeyDefinitionOptions } from "./key-definition";
import { StateDefinition } from "./state-definition";
export type ClearEvent = "lock" | "logout";
export type UserKeyDefinitionOptions<T> = KeyDefinitionOptions<T> & {
clearOn: ClearEvent[];
};
const USER_KEY_DEFINITION_MARKER: unique symbol = Symbol("UserKeyDefinition");
export class UserKeyDefinition<T> {
readonly [USER_KEY_DEFINITION_MARKER] = true;
/**
* A unique array of events that the state stored at this key should be cleared on.
*/
readonly clearOn: ClearEvent[];
/**
* Normalized options used for debugging purposes.
*/
readonly debug: Required<DebugOptions>;
constructor(
readonly stateDefinition: StateDefinition,
readonly key: string,
private readonly options: UserKeyDefinitionOptions<T>,
) {
if (options.deserializer == null) {
throw new Error(`'deserializer' is a required property on key ${this.errorKeyName}`);
}
if (options.cleanupDelayMs < 0) {
throw new Error(
`'cleanupDelayMs' must be greater than or equal to 0. Value of ${options.cleanupDelayMs} passed to key ${this.errorKeyName} `,
);
}
// Filter out repeat values
this.clearOn = Array.from(new Set(options.clearOn));
// Normalize optional debug options
const { enableUpdateLogging = false, enableRetrievalLogging = false } = options.debug ?? {};
this.debug = {
enableUpdateLogging,
enableRetrievalLogging,
};
}
/**
* Gets the deserializer configured for this {@link KeyDefinition}
*/
get deserializer() {
return this.options.deserializer;
}
/**
* Gets the number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
*/
get cleanupDelayMs() {
return this.options.cleanupDelayMs < 0 ? 0 : (this.options.cleanupDelayMs ?? 1000);
}
/**
* Creates a {@link UserKeyDefinition} for state that is an array.
* @param stateDefinition The state definition to be added to the UserKeyDefinition
* @param key The key to be added to the KeyDefinition
* @param options The options to customize the final {@link UserKeyDefinition}.
* @returns A {@link UserKeyDefinition} initialized for arrays, the options run
* the deserializer on the provided options for each element of an array
* **unless that array is null, in which case it will return an empty list.**
*
* @example
* ```typescript
* const MY_KEY = UserKeyDefinition.array<MyArrayElement>(MY_STATE, "key", {
* deserializer: (myJsonElement) => convertToElement(myJsonElement),
* });
* ```
*/
static array<T>(
stateDefinition: StateDefinition,
key: string,
// We have them provide options for the element of the array, depending on future options we add, this could get a little weird.
options: UserKeyDefinitionOptions<T>,
) {
return new UserKeyDefinition<T[]>(stateDefinition, key, {
...options,
deserializer: array((e) => options.deserializer(e)),
});
}
/**
* Creates a {@link UserKeyDefinition} for state that is a record.
* @param stateDefinition The state definition to be added to the UserKeyDefinition
* @param key The key to be added to the KeyDefinition
* @param options The options to customize the final {@link UserKeyDefinition}.
* @returns A {@link UserKeyDefinition} that contains a serializer that will run the provided deserializer for each
* value in a record and returns every key as a string **unless that record is null, in which case it will return an record.**
*
* @example
* ```typescript
* const MY_KEY = UserKeyDefinition.record<MyRecordValue>(MY_STATE, "key", {
* deserializer: (myJsonValue) => convertToValue(myJsonValue),
* });
* ```
*/
static record<T, TKey extends string | number = string>(
stateDefinition: StateDefinition,
key: string,
// We have them provide options for the value of the record, depending on future options we add, this could get a little weird.
options: UserKeyDefinitionOptions<T>, // The array helper forces an initialValue of an empty record
) {
return new UserKeyDefinition<Record<TKey, T>>(stateDefinition, key, {
...options,
deserializer: record((v) => options.deserializer(v)),
});
}
get fullName() {
return `${this.stateDefinition.name}_${this.key}`;
}
buildKey(userId: UserId) {
if (!Utils.isGuid(userId)) {
throw new Error(
`You cannot build a user key without a valid UserId, building for key ${this.fullName}`,
);
}
return `user_${userId}_${this.stateDefinition.name}_${this.key}` as StorageKey;
}
private get errorKeyName() {
return `${this.stateDefinition.name} > ${this.key}`;
}
}
export { UserKeyDefinition, UserKeyDefinitionOptions } from "@bitwarden/state";

View File

@@ -1,35 +1 @@
import { Observable } from "rxjs";
import { UserId } from "../../types/guid";
import { UserKeyDefinition } from "./user-key-definition";
import { ActiveUserState, SingleUserState } from "./user-state";
/** A provider for getting an implementation of state scoped to a given key and userId */
export abstract class SingleUserStateProvider {
/**
* Gets a {@link SingleUserState} scoped to the given {@link UserKeyDefinition} and {@link UserId}
*
* @param userId - The {@link UserId} for which you want the user state for.
* @param userKeyDefinition - The {@link UserKeyDefinition} for which you want the user state for.
*/
abstract get<T>(userId: UserId, userKeyDefinition: UserKeyDefinition<T>): SingleUserState<T>;
}
/** A provider for getting an implementation of state scoped to a given key, but always pointing
* to the currently active user
*/
export abstract class ActiveUserStateProvider {
/**
* Convenience re-emission of active user ID from {@link AccountService.activeAccount$}
*/
abstract activeUserId$: Observable<UserId | undefined>;
/**
* Gets a {@link ActiveUserState} scoped to the given {@link KeyDefinition}, but updates when active user changes such
* that the emitted values always represents the state for the currently active user.
*
* @param keyDefinition - The {@link UserKeyDefinition} for which you want the user state for.
*/
abstract get<T>(userKeyDefinition: UserKeyDefinition<T>): ActiveUserState<T>;
}
export { ActiveUserStateProvider, SingleUserStateProvider } from "@bitwarden/state";

View File

@@ -1,64 +1 @@
import { Observable } from "rxjs";
import { UserId } from "../../types/guid";
import { StateUpdateOptions } from "./state-update-options";
export type CombinedState<T> = readonly [userId: UserId, state: T];
/** A helper object for interacting with state that is scoped to a specific user. */
export interface UserState<T> {
/** Emits a stream of data. Emits null if the user does not have specified state. */
readonly state$: Observable<T | null>;
/** Emits a stream of tuples, with the first element being a user id and the second element being the data for that user. */
readonly combinedState$: Observable<CombinedState<T | null>>;
}
export const activeMarker: unique symbol = Symbol("active");
export interface ActiveUserState<T> extends UserState<T> {
readonly [activeMarker]: true;
/**
* Emits a stream of data. Emits null if the user does not have specified state.
* Note: Will not emit if there is no active user.
*/
readonly state$: Observable<T | null>;
/**
* Updates backing stores for the active user.
* @param configureState function that takes the current state and returns the new state
* @param options Defaults to @see {module:state-update-options#DEFAULT_OPTIONS}
* @param options.shouldUpdate A callback for determining if you want to update state. Defaults to () => true
* @param options.combineLatestWith An observable that you want to combine with the current state for callbacks. Defaults to null
* @param options.msTimeout A timeout for how long you are willing to wait for a `combineLatestWith` option to complete. Defaults to 1000ms. Only applies if `combineLatestWith` is set.
*
* @returns A promise that must be awaited before your next action to ensure the update has been written to state.
* Resolves to the new state. If `shouldUpdate` returns false, the promise will resolve to the current state.
*/
readonly update: <TCombine>(
configureState: (state: T | null, dependencies: TCombine) => T | null,
options?: StateUpdateOptions<T, TCombine>,
) => Promise<[UserId, T | null]>;
}
export interface SingleUserState<T> extends UserState<T> {
readonly userId: UserId;
/**
* Updates backing stores for the active user.
* @param configureState function that takes the current state and returns the new state
* @param options Defaults to @see {module:state-update-options#DEFAULT_OPTIONS}
* @param options.shouldUpdate A callback for determining if you want to update state. Defaults to () => true
* @param options.combineLatestWith An observable that you want to combine with the current state for callbacks. Defaults to null
* @param options.msTimeout A timeout for how long you are willing to wait for a `combineLatestWith` option to complete. Defaults to 1000ms. Only applies if `combineLatestWith` is set.
*
* @returns A promise that must be awaited before your next action to ensure the update has been written to state.
* Resolves to the new state. If `shouldUpdate` returns false, the promise will resolve to the current state.
*/
readonly update: <TCombine>(
configureState: (state: T | null, dependencies: TCombine) => T | null,
options?: StateUpdateOptions<T, TCombine>,
) => Promise<T | null>;
}
export { ActiveUserState, SingleUserState, CombinedState } from "@bitwarden/state";

View File

@@ -172,7 +172,11 @@ export abstract class CoreSyncService implements SyncService {
notification.collectionIds != null &&
notification.collectionIds.length > 0
) {
const collections = await this.collectionService.getAll();
const collections = await firstValueFrom(
this.collectionService
.encryptedCollections$(userId)
.pipe(map((collections) => collections ?? [])),
);
if (collections != null) {
for (let i = 0; i < collections.length; i++) {
if (notification.collectionIds.indexOf(collections[i].id) > -1) {

View File

@@ -1 +1,2 @@
export { createMigrationBuilder, waitForMigrations, CURRENT_VERSION } from "./migrate";
// Compatibility re-export for @bitwarden/common/state-migrations
export * from "@bitwarden/state";

View File

@@ -1,40 +0,0 @@
import { mock, MockProxy } from "jest-mock-extended";
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
import { LogService } from "../platform/abstractions/log.service";
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
import { AbstractStorageService } from "../platform/abstractions/storage.service";
import { currentVersion } from "./migrate";
describe("currentVersion", () => {
let storage: MockProxy<AbstractStorageService>;
let logService: MockProxy<LogService>;
beforeEach(() => {
storage = mock();
logService = mock();
});
it("should return -1 if no version", async () => {
storage.get.mockReturnValueOnce(null);
expect(await currentVersion(storage, logService)).toEqual(-1);
});
it("should return version", async () => {
storage.get.calledWith("stateVersion").mockReturnValueOnce(1 as any);
expect(await currentVersion(storage, logService)).toEqual(1);
});
it("should return version from global", async () => {
storage.get.calledWith("stateVersion").mockReturnValueOnce(null);
storage.get.calledWith("global").mockReturnValueOnce({ stateVersion: 1 } as any);
expect(await currentVersion(storage, logService)).toEqual(1);
});
it("should prefer root version to global", async () => {
storage.get.calledWith("stateVersion").mockReturnValue(1 as any);
storage.get.calledWith("global").mockReturnValue({ stateVersion: 2 } as any);
expect(await currentVersion(storage, logService)).toEqual(1);
});
});

View File

@@ -1,218 +0,0 @@
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
import { LogService } from "../platform/abstractions/log.service";
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
import { AbstractStorageService } from "../platform/abstractions/storage.service";
import { MigrationBuilder } from "./migration-builder";
import { EverHadUserKeyMigrator } from "./migrations/10-move-ever-had-user-key-to-state-providers";
import { OrganizationKeyMigrator } from "./migrations/11-move-org-keys-to-state-providers";
import { MoveEnvironmentStateToProviders } from "./migrations/12-move-environment-state-to-providers";
import { ProviderKeyMigrator } from "./migrations/13-move-provider-keys-to-state-providers";
import { MoveBiometricClientKeyHalfToStateProviders } from "./migrations/14-move-biometric-client-key-half-state-to-providers";
import { FolderMigrator } from "./migrations/15-move-folder-state-to-state-provider";
import { LastSyncMigrator } from "./migrations/16-move-last-sync-to-state-provider";
import { EnablePasskeysMigrator } from "./migrations/17-move-enable-passkeys-to-state-providers";
import { AutofillSettingsKeyMigrator } from "./migrations/18-move-autofill-settings-to-state-providers";
import { RequirePasswordOnStartMigrator } from "./migrations/19-migrate-require-password-on-start";
import { PrivateKeyMigrator } from "./migrations/20-move-private-key-to-state-providers";
import { CollectionMigrator } from "./migrations/21-move-collections-state-to-state-provider";
import { CollapsedGroupingsMigrator } from "./migrations/22-move-collapsed-groupings-to-state-provider";
import { MoveBiometricPromptsToStateProviders } from "./migrations/23-move-biometric-prompts-to-state-providers";
import { SmOnboardingTasksMigrator } from "./migrations/24-move-sm-onboarding-key-to-state-providers";
import { ClearClipboardDelayMigrator } from "./migrations/25-move-clear-clipboard-to-autofill-settings-state-provider";
import { RevertLastSyncMigrator } from "./migrations/26-revert-move-last-sync-to-state-provider";
import { BadgeSettingsMigrator } from "./migrations/27-move-badge-settings-to-state-providers";
import { MoveBiometricUnlockToStateProviders } from "./migrations/28-move-biometric-unlock-to-state-providers";
import { UserNotificationSettingsKeyMigrator } from "./migrations/29-move-user-notification-settings-to-state-provider";
import { PolicyMigrator } from "./migrations/30-move-policy-state-to-state-provider";
import { EnableContextMenuMigrator } from "./migrations/31-move-enable-context-menu-to-autofill-settings-state-provider";
import { PreferredLanguageMigrator } from "./migrations/32-move-preferred-language";
import { AppIdMigrator } from "./migrations/33-move-app-id-to-state-providers";
import { DomainSettingsMigrator } from "./migrations/34-move-domain-settings-to-state-providers";
import { MoveThemeToStateProviderMigrator } from "./migrations/35-move-theme-to-state-providers";
import { VaultSettingsKeyMigrator } from "./migrations/36-move-show-card-and-identity-to-state-provider";
import { AvatarColorMigrator } from "./migrations/37-move-avatar-color-to-state-providers";
import { TokenServiceStateProviderMigrator } from "./migrations/38-migrate-token-svc-to-state-provider";
import { MoveBillingAccountProfileMigrator } from "./migrations/39-move-billing-account-profile-to-state-providers";
import { RemoveEverBeenUnlockedMigrator } from "./migrations/4-remove-ever-been-unlocked";
import { OrganizationMigrator } from "./migrations/40-move-organization-state-to-state-provider";
import { EventCollectionMigrator } from "./migrations/41-move-event-collection-to-state-provider";
import { EnableFaviconMigrator } from "./migrations/42-move-enable-favicon-to-domain-settings-state-provider";
import { AutoConfirmFingerPrintsMigrator } from "./migrations/43-move-auto-confirm-finger-prints-to-state-provider";
import { UserDecryptionOptionsMigrator } from "./migrations/44-move-user-decryption-options-to-state-provider";
import { MergeEnvironmentState } from "./migrations/45-merge-environment-state";
import { DeleteBiometricPromptCancelledData } from "./migrations/46-delete-orphaned-biometric-prompt-data";
import { MoveDesktopSettingsMigrator } from "./migrations/47-move-desktop-settings";
import { MoveDdgToStateProviderMigrator } from "./migrations/48-move-ddg-to-state-provider";
import { AccountServerConfigMigrator } from "./migrations/49-move-account-server-configs";
import { AddKeyTypeToOrgKeysMigrator } from "./migrations/5-add-key-type-to-org-keys";
import { KeyConnectorMigrator } from "./migrations/50-move-key-connector-to-state-provider";
import { RememberedEmailMigrator } from "./migrations/51-move-remembered-email-to-state-providers";
import { DeleteInstalledVersion } from "./migrations/52-delete-installed-version";
import { DeviceTrustServiceStateProviderMigrator } from "./migrations/53-migrate-device-trust-svc-to-state-providers";
import { SendMigrator } from "./migrations/54-move-encrypted-sends";
import { MoveMasterKeyStateToProviderMigrator } from "./migrations/55-move-master-key-state-to-provider";
import { AuthRequestMigrator } from "./migrations/56-move-auth-requests";
import { CipherServiceMigrator } from "./migrations/57-move-cipher-service-to-state-provider";
import { RemoveRefreshTokenMigratedFlagMigrator } from "./migrations/58-remove-refresh-token-migrated-state-provider-flag";
import { KdfConfigMigrator } from "./migrations/59-move-kdf-config-to-state-provider";
import { RemoveLegacyEtmKeyMigrator } from "./migrations/6-remove-legacy-etm-key";
import { KnownAccountsMigrator } from "./migrations/60-known-accounts";
import { PinStateMigrator } from "./migrations/61-move-pin-state-to-providers";
import { VaultTimeoutSettingsServiceStateProviderMigrator } from "./migrations/62-migrate-vault-timeout-settings-svc-to-state-provider";
import { PasswordOptionsMigrator } from "./migrations/63-migrate-password-settings";
import { GeneratorHistoryMigrator } from "./migrations/64-migrate-generator-history";
import { ForwarderOptionsMigrator } from "./migrations/65-migrate-forwarder-settings";
import { MoveFinalDesktopSettingsMigrator } from "./migrations/66-move-final-desktop-settings";
import { RemoveUnassignedItemsBannerDismissed } from "./migrations/67-remove-unassigned-items-banner-dismissed";
import { MoveLastSyncDate } from "./migrations/68-move-last-sync-date";
import { MigrateIncorrectFolderKey } from "./migrations/69-migrate-incorrect-folder-key";
import { MoveBiometricAutoPromptToAccount } from "./migrations/7-move-biometric-auto-prompt-to-account";
import { RemoveAcBannersDismissed } from "./migrations/70-remove-ac-banner-dismissed";
import { RemoveNewCustomizationOptionsCalloutDismissed } from "./migrations/71-remove-new-customization-options-callout-dismissed";
import { RemoveAccountDeprovisioningBannerDismissed } from "./migrations/72-remove-account-deprovisioning-banner-dismissed";
import { MoveStateVersionMigrator } from "./migrations/8-move-state-version";
import { MoveBrowserSettingsToGlobal } from "./migrations/9-move-browser-settings-to-global";
import { MinVersionMigrator } from "./migrations/min-version";
export const MIN_VERSION = 3;
export const CURRENT_VERSION = 72;
export type MinVersion = typeof MIN_VERSION;
export function createMigrationBuilder() {
return MigrationBuilder.create()
.with(MinVersionMigrator)
.with(RemoveEverBeenUnlockedMigrator, 3, 4)
.with(AddKeyTypeToOrgKeysMigrator, 4, 5)
.with(RemoveLegacyEtmKeyMigrator, 5, 6)
.with(MoveBiometricAutoPromptToAccount, 6, 7)
.with(MoveStateVersionMigrator, 7, 8)
.with(MoveBrowserSettingsToGlobal, 8, 9)
.with(EverHadUserKeyMigrator, 9, 10)
.with(OrganizationKeyMigrator, 10, 11)
.with(MoveEnvironmentStateToProviders, 11, 12)
.with(ProviderKeyMigrator, 12, 13)
.with(MoveBiometricClientKeyHalfToStateProviders, 13, 14)
.with(FolderMigrator, 14, 15)
.with(LastSyncMigrator, 15, 16)
.with(EnablePasskeysMigrator, 16, 17)
.with(AutofillSettingsKeyMigrator, 17, 18)
.with(RequirePasswordOnStartMigrator, 18, 19)
.with(PrivateKeyMigrator, 19, 20)
.with(CollectionMigrator, 20, 21)
.with(CollapsedGroupingsMigrator, 21, 22)
.with(MoveBiometricPromptsToStateProviders, 22, 23)
.with(SmOnboardingTasksMigrator, 23, 24)
.with(ClearClipboardDelayMigrator, 24, 25)
.with(RevertLastSyncMigrator, 25, 26)
.with(BadgeSettingsMigrator, 26, 27)
.with(MoveBiometricUnlockToStateProviders, 27, 28)
.with(UserNotificationSettingsKeyMigrator, 28, 29)
.with(PolicyMigrator, 29, 30)
.with(EnableContextMenuMigrator, 30, 31)
.with(PreferredLanguageMigrator, 31, 32)
.with(AppIdMigrator, 32, 33)
.with(DomainSettingsMigrator, 33, 34)
.with(MoveThemeToStateProviderMigrator, 34, 35)
.with(VaultSettingsKeyMigrator, 35, 36)
.with(AvatarColorMigrator, 36, 37)
.with(TokenServiceStateProviderMigrator, 37, 38)
.with(MoveBillingAccountProfileMigrator, 38, 39)
.with(OrganizationMigrator, 39, 40)
.with(EventCollectionMigrator, 40, 41)
.with(EnableFaviconMigrator, 41, 42)
.with(AutoConfirmFingerPrintsMigrator, 42, 43)
.with(UserDecryptionOptionsMigrator, 43, 44)
.with(MergeEnvironmentState, 44, 45)
.with(DeleteBiometricPromptCancelledData, 45, 46)
.with(MoveDesktopSettingsMigrator, 46, 47)
.with(MoveDdgToStateProviderMigrator, 47, 48)
.with(AccountServerConfigMigrator, 48, 49)
.with(KeyConnectorMigrator, 49, 50)
.with(RememberedEmailMigrator, 50, 51)
.with(DeleteInstalledVersion, 51, 52)
.with(DeviceTrustServiceStateProviderMigrator, 52, 53)
.with(SendMigrator, 53, 54)
.with(MoveMasterKeyStateToProviderMigrator, 54, 55)
.with(AuthRequestMigrator, 55, 56)
.with(CipherServiceMigrator, 56, 57)
.with(RemoveRefreshTokenMigratedFlagMigrator, 57, 58)
.with(KdfConfigMigrator, 58, 59)
.with(KnownAccountsMigrator, 59, 60)
.with(PinStateMigrator, 60, 61)
.with(VaultTimeoutSettingsServiceStateProviderMigrator, 61, 62)
.with(PasswordOptionsMigrator, 62, 63)
.with(GeneratorHistoryMigrator, 63, 64)
.with(ForwarderOptionsMigrator, 64, 65)
.with(MoveFinalDesktopSettingsMigrator, 65, 66)
.with(RemoveUnassignedItemsBannerDismissed, 66, 67)
.with(MoveLastSyncDate, 67, 68)
.with(MigrateIncorrectFolderKey, 68, 69)
.with(RemoveAcBannersDismissed, 69, 70)
.with(RemoveNewCustomizationOptionsCalloutDismissed, 70, 71)
.with(RemoveAccountDeprovisioningBannerDismissed, 71, CURRENT_VERSION);
}
export async function currentVersion(
storageService: AbstractStorageService,
logService: LogService,
) {
let state = await storageService.get<number>("stateVersion");
if (state == null) {
// Pre v8
state = (await storageService.get<{ stateVersion: number }>("global"))?.stateVersion;
}
if (state == null) {
logService.info("No state version found, assuming empty state.");
return -1;
}
logService.info(`State version: ${state}`);
return state;
}
/**
* Waits for migrations to have a chance to run and will resolve the promise once they are.
*
* @param storageService Disk storage where the `stateVersion` will or is already saved in.
* @param logService Log service
*/
export async function waitForMigrations(
storageService: AbstractStorageService,
logService: LogService,
) {
const isReady = async () => {
const version = await currentVersion(storageService, logService);
// The saved version is what we consider the latest
// migrations should be complete, the state version
// shouldn't become larger than `CURRENT_VERSION` in
// any normal usage of the application but it is common
// enough in dev scenarios where we want to consider that
// ready as well and return true in that scenario.
return version >= CURRENT_VERSION;
};
const wait = async (time: number) => {
// Wait exponentially
const nextTime = time * 2;
if (nextTime > 8192) {
// Don't wait longer than ~8 seconds in a single wait,
// if the migrations still haven't happened. They aren't
// likely to.
return;
}
return new Promise<void>((resolve) => {
setTimeout(async () => {
if (!(await isReady())) {
logService.info(`Waiting for migrations to finish, waiting for ${nextTime}ms`);
await wait(nextTime);
}
resolve();
}, time);
});
};
if (!(await isReady())) {
// Wait for 2ms to start with
await wait(2);
}
}

View File

@@ -1,143 +0,0 @@
import { mock } from "jest-mock-extended";
// eslint-disable-next-line import/no-restricted-paths
import { ClientType } from "../enums";
import { MigrationBuilder } from "./migration-builder";
import { MigrationHelper } from "./migration-helper";
import { Migrator } from "./migrator";
describe("MigrationBuilder", () => {
class TestMigrator extends Migrator<0, 1> {
async migrate(helper: MigrationHelper): Promise<void> {
await helper.set("test", "test");
return;
}
async rollback(helper: MigrationHelper): Promise<void> {
await helper.set("test", "rollback");
return;
}
}
class TestMigratorWithInstanceMethod extends Migrator<0, 1> {
private async instanceMethod(helper: MigrationHelper, value: string) {
await helper.set("test", value);
}
async migrate(helper: MigrationHelper): Promise<void> {
await this.instanceMethod(helper, "migrate");
}
async rollback(helper: MigrationHelper): Promise<void> {
await this.instanceMethod(helper, "rollback");
}
}
let sut: MigrationBuilder<number>;
beforeEach(() => {
sut = MigrationBuilder.create();
});
class TestBadMigrator extends Migrator<1, 0> {
async migrate(helper: MigrationHelper): Promise<void> {
await helper.set("test", "test");
}
async rollback(helper: MigrationHelper): Promise<void> {
await helper.set("test", "rollback");
}
}
it("should throw if instantiated incorrectly", () => {
expect(() => MigrationBuilder.create().with(TestMigrator, null, null)).toThrow();
expect(() =>
MigrationBuilder.create().with(TestMigrator, 0, 1).with(TestBadMigrator, 1, 0),
).toThrow();
});
it("should be able to create a new MigrationBuilder", () => {
expect(sut).toBeInstanceOf(MigrationBuilder);
});
it("should be able to add a migrator", () => {
const newBuilder = sut.with(TestMigrator, 0, 1);
const migrations = newBuilder["migrations"];
expect(migrations.length).toBe(1);
expect(migrations[0]).toMatchObject({ migrator: expect.any(TestMigrator), direction: "up" });
});
it("should be able to add a rollback", () => {
const newBuilder = sut.with(TestMigrator, 0, 1).rollback(TestMigrator, 1, 0);
const migrations = newBuilder["migrations"];
expect(migrations.length).toBe(2);
expect(migrations[1]).toMatchObject({ migrator: expect.any(TestMigrator), direction: "down" });
});
const clientTypes = Object.values(ClientType);
describe.each(clientTypes)("for client %s", (clientType) => {
describe("migrate", () => {
let migrator: TestMigrator;
let rollback_migrator: TestMigrator;
beforeEach(() => {
sut = sut.with(TestMigrator, 0, 1).rollback(TestMigrator, 1, 0);
migrator = (sut as any).migrations[0].migrator;
rollback_migrator = (sut as any).migrations[1].migrator;
});
it("should migrate", async () => {
const helper = new MigrationHelper(0, mock(), mock(), "general", clientType);
const spy = jest.spyOn(migrator, "migrate");
await sut.migrate(helper);
expect(spy).toBeCalledWith(helper);
});
it("should rollback", async () => {
const helper = new MigrationHelper(1, mock(), mock(), "general", clientType);
const spy = jest.spyOn(rollback_migrator, "rollback");
await sut.migrate(helper);
expect(spy).toBeCalledWith(helper);
});
it("should update version on migrate", async () => {
const helper = new MigrationHelper(0, mock(), mock(), "general", clientType);
const spy = jest.spyOn(migrator, "updateVersion");
await sut.migrate(helper);
expect(spy).toBeCalledWith(helper, "up");
});
it("should update version on rollback", async () => {
const helper = new MigrationHelper(1, mock(), mock(), "general", clientType);
const spy = jest.spyOn(rollback_migrator, "updateVersion");
await sut.migrate(helper);
expect(spy).toBeCalledWith(helper, "down");
});
it("should not run the migrator if the current version does not match the from version", async () => {
const helper = new MigrationHelper(3, mock(), mock(), "general", clientType);
const migrate = jest.spyOn(migrator, "migrate");
const rollback = jest.spyOn(rollback_migrator, "rollback");
await sut.migrate(helper);
expect(migrate).not.toBeCalled();
expect(rollback).not.toBeCalled();
});
it("should not update version if the current version does not match the from version", async () => {
const helper = new MigrationHelper(3, mock(), mock(), "general", clientType);
const migrate = jest.spyOn(migrator, "updateVersion");
const rollback = jest.spyOn(rollback_migrator, "updateVersion");
await sut.migrate(helper);
expect(migrate).not.toBeCalled();
expect(rollback).not.toBeCalled();
});
});
it("should be able to call instance methods", async () => {
const helper = new MigrationHelper(0, mock(), mock(), "general", clientType);
await sut.with(TestMigratorWithInstanceMethod, 0, 1).migrate(helper);
});
});
});

View File

@@ -1,106 +1 @@
import { MigrationHelper } from "./migration-helper";
import { Direction, Migrator, VersionFrom, VersionTo } from "./migrator";
export class MigrationBuilder<TCurrent extends number = 0> {
/** Create a new MigrationBuilder with an empty buffer of migrations to perform.
*
* Add migrations to the buffer with {@link with} and {@link rollback}.
* @returns A new MigrationBuilder.
*/
static create(): MigrationBuilder<0> {
return new MigrationBuilder([]);
}
private constructor(
private migrations: readonly { migrator: Migrator<number, number>; direction: Direction }[],
) {}
/** Add a migrator to the MigrationBuilder. Types are updated such that the chained MigrationBuilder must currently be
* at state version equal to the from version of the migrator. Return as MigrationBuilder<TTo> where TTo is the to
* version of the migrator, so that the next migrator can be chained.
*
* @param migrate A migrator class or a tuple of a migrator class, the from version, and the to version. A tuple is
* required to instantiate version numbers unless a default constructor is defined.
* @returns A new MigrationBuilder with the to version of the migrator as the current version.
*/
with<
TMigrator extends Migrator<number, number>,
TFrom extends VersionFrom<TMigrator> & TCurrent,
TTo extends VersionTo<TMigrator>,
>(
...migrate: [new () => TMigrator] | [new (from: TFrom, to: TTo) => TMigrator, TFrom, TTo]
): MigrationBuilder<TTo> {
return this.addMigrator(migrate, "up");
}
/** Add a migrator to rollback on the MigrationBuilder's list of migrations. As with {@link with}, types of
* MigrationBuilder and Migrator must align. However, this time the migration is reversed so TCurrent of the
* MigrationBuilder must be equal to the to version of the migrator. Return as MigrationBuilder<TFrom> where TFrom
* is the from version of the migrator, so that the next migrator can be chained.
*
* @param migrate A migrator class or a tuple of a migrator class, the from version, and the to version. A tuple is
* required to instantiate version numbers unless a default constructor is defined.
* @returns A new MigrationBuilder with the from version of the migrator as the current version.
*/
rollback<
TMigrator extends Migrator<number, number>,
TFrom extends VersionFrom<TMigrator>,
TTo extends VersionTo<TMigrator> & TCurrent,
>(
...migrate: [new () => TMigrator] | [new (from: TFrom, to: TTo) => TMigrator, TTo, TFrom]
): MigrationBuilder<TFrom> {
if (migrate.length === 3) {
migrate = [migrate[0], migrate[2], migrate[1]];
}
return this.addMigrator(migrate, "down");
}
/** Execute the migrations as defined in the MigrationBuilder's migrator buffer */
migrate(helper: MigrationHelper): Promise<void> {
return this.migrations.reduce(
(promise, migrator) =>
promise.then(async () => {
await this.runMigrator(migrator.migrator, helper, migrator.direction);
}),
Promise.resolve(),
);
}
private addMigrator<
TMigrator extends Migrator<number, number>,
TFrom extends VersionFrom<TMigrator> & TCurrent,
TTo extends VersionTo<TMigrator>,
>(
migrate: [new () => TMigrator] | [new (from: TFrom, to: TTo) => TMigrator, TFrom, TTo],
direction: Direction = "up",
) {
const newMigration =
migrate.length === 1
? { migrator: new migrate[0](), direction }
: { migrator: new migrate[0](migrate[1], migrate[2]), direction };
return new MigrationBuilder<TTo>([...this.migrations, newMigration]);
}
private async runMigrator(
migrator: Migrator<number, number>,
helper: MigrationHelper,
direction: Direction,
): Promise<void> {
const shouldMigrate = await migrator.shouldMigrate(helper, direction);
helper.info(
`Migrator ${migrator.constructor.name} (to version ${migrator.toVersion}) should migrate: ${shouldMigrate} - ${direction}`,
);
if (shouldMigrate) {
const method = direction === "up" ? migrator.migrate : migrator.rollback;
await method.bind(migrator)(helper);
helper.info(
`Migrator ${migrator.constructor.name} (to version ${migrator.toVersion}) migrated - ${direction}`,
);
await migrator.updateVersion(helper, direction);
helper.info(
`Migrator ${migrator.constructor.name} (to version ${migrator.toVersion}) updated version - ${direction}`,
);
}
}
}
export { MigrationBuilder } from "@bitwarden/state";

View File

@@ -1,385 +0,0 @@
import { MockProxy, mock } from "jest-mock-extended";
import { FakeStorageService } from "../../spec/fake-storage.service";
// eslint-disable-next-line import/no-restricted-paths -- Needed client type enum
import { ClientType } from "../enums";
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
import { LogService } from "../platform/abstractions/log.service";
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
import { AbstractStorageService } from "../platform/abstractions/storage.service";
// eslint-disable-next-line import/no-restricted-paths -- Needed to generate unique strings for injection
import { Utils } from "../platform/misc/utils";
import { MigrationHelper, MigrationHelperType } from "./migration-helper";
import { Migrator } from "./migrator";
const exampleJSON = {
authenticatedAccounts: [
"c493ed01-4e08-4e88-abc7-332f380ca760",
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
],
"c493ed01-4e08-4e88-abc7-332f380ca760": {
otherStuff: "otherStuff1",
},
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
otherStuff: "otherStuff2",
},
global_serviceName_key: "global_serviceName_key",
user_userId_serviceName_key: "user_userId_serviceName_key",
global_account_accounts: {
"c493ed01-4e08-4e88-abc7-332f380ca760": {
otherStuff: "otherStuff3",
},
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
otherStuff: "otherStuff4",
},
},
};
describe("RemoveLegacyEtmKeyMigrator", () => {
let storage: MockProxy<AbstractStorageService>;
let logService: MockProxy<LogService>;
let sut: MigrationHelper;
const clientTypes = Object.values(ClientType);
describe.each(clientTypes)("for client %s", (clientType) => {
beforeEach(() => {
logService = mock();
storage = mock();
storage.get.mockImplementation((key) => (exampleJSON as any)[key]);
sut = new MigrationHelper(0, storage, logService, "general", clientType);
});
describe("get", () => {
it("should delegate to storage.get", async () => {
await sut.get("key");
expect(storage.get).toHaveBeenCalledWith("key");
});
});
describe("set", () => {
it("should delegate to storage.save", async () => {
await sut.set("key", "value");
expect(storage.save).toHaveBeenCalledWith("key", "value");
});
});
describe("getAccounts", () => {
it("should return all accounts", async () => {
const accounts = await sut.getAccounts();
expect(accounts).toEqual([
{
userId: "c493ed01-4e08-4e88-abc7-332f380ca760",
account: { otherStuff: "otherStuff1" },
},
{
userId: "23e61a5f-2ece-4f5e-b499-f0bc489482a9",
account: { otherStuff: "otherStuff2" },
},
]);
});
it("should handle missing authenticatedAccounts", async () => {
storage.get.mockImplementation((key) =>
key === "authenticatedAccounts" ? undefined : (exampleJSON as any)[key],
);
const accounts = await sut.getAccounts();
expect(accounts).toEqual([]);
});
it("handles global scoped known accounts for version 60 and after", async () => {
sut.currentVersion = 60;
const accounts = await sut.getAccounts();
expect(accounts).toEqual([
// Note, still gets values stored in state service objects, just grabs user ids from global
{
userId: "c493ed01-4e08-4e88-abc7-332f380ca760",
account: { otherStuff: "otherStuff1" },
},
{
userId: "23e61a5f-2ece-4f5e-b499-f0bc489482a9",
account: { otherStuff: "otherStuff2" },
},
]);
});
});
describe("getKnownUserIds", () => {
it("returns all user ids", async () => {
const userIds = await sut.getKnownUserIds();
expect(userIds).toEqual([
"c493ed01-4e08-4e88-abc7-332f380ca760",
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
]);
});
it("returns all user ids when version is 60 or greater", async () => {
sut.currentVersion = 60;
const userIds = await sut.getKnownUserIds();
expect(userIds).toEqual([
"c493ed01-4e08-4e88-abc7-332f380ca760",
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
]);
});
});
describe("getFromGlobal", () => {
it("should return the correct value", async () => {
sut.currentVersion = 9;
const value = await sut.getFromGlobal({
stateDefinition: { name: "serviceName" },
key: "key",
});
expect(value).toEqual("global_serviceName_key");
});
it("should throw if the current version is less than 9", () => {
expect(() =>
sut.getFromGlobal({ stateDefinition: { name: "serviceName" }, key: "key" }),
).toThrowError("No key builder should be used for versions prior to 9.");
});
});
describe("setToGlobal", () => {
it("should set the correct value", async () => {
sut.currentVersion = 9;
await sut.setToGlobal(
{ stateDefinition: { name: "serviceName" }, key: "key" },
"new_value",
);
expect(storage.save).toHaveBeenCalledWith("global_serviceName_key", "new_value");
});
it("should throw if the current version is less than 9", () => {
expect(() =>
sut.setToGlobal(
{ stateDefinition: { name: "serviceName" }, key: "key" },
"global_serviceName_key",
),
).toThrowError("No key builder should be used for versions prior to 9.");
});
});
describe("getFromUser", () => {
it("should return the correct value", async () => {
sut.currentVersion = 9;
const value = await sut.getFromUser("userId", {
stateDefinition: { name: "serviceName" },
key: "key",
});
expect(value).toEqual("user_userId_serviceName_key");
});
it("should throw if the current version is less than 9", () => {
expect(() =>
sut.getFromUser("userId", { stateDefinition: { name: "serviceName" }, key: "key" }),
).toThrowError("No key builder should be used for versions prior to 9.");
});
});
describe("setToUser", () => {
it("should set the correct value", async () => {
sut.currentVersion = 9;
await sut.setToUser(
"userId",
{ stateDefinition: { name: "serviceName" }, key: "key" },
"new_value",
);
expect(storage.save).toHaveBeenCalledWith("user_userId_serviceName_key", "new_value");
});
it("should throw if the current version is less than 9", () => {
expect(() =>
sut.setToUser(
"userId",
{ stateDefinition: { name: "serviceName" }, key: "key" },
"new_value",
),
).toThrowError("No key builder should be used for versions prior to 9.");
});
});
});
});
/** Helper to create well-mocked migration helpers in migration tests */
export function mockMigrationHelper(
storageJson: any,
stateVersion = 0,
type: MigrationHelperType = "general",
clientType: ClientType = ClientType.Web,
): MockProxy<MigrationHelper> {
const logService: MockProxy<LogService> = mock();
const storage: MockProxy<AbstractStorageService> = mock();
storage.get.mockImplementation((key) => (storageJson as any)[key]);
storage.save.mockImplementation(async (key, value) => {
(storageJson as any)[key] = value;
});
const helper = new MigrationHelper(stateVersion, storage, logService, type, clientType);
const mockHelper = mock<MigrationHelper>();
mockHelper.get.mockImplementation((key) => helper.get(key));
mockHelper.set.mockImplementation((key, value) => helper.set(key, value));
mockHelper.getFromGlobal.mockImplementation((keyDefinition) =>
helper.getFromGlobal(keyDefinition),
);
mockHelper.setToGlobal.mockImplementation((keyDefinition, value) =>
helper.setToGlobal(keyDefinition, value),
);
mockHelper.getFromUser.mockImplementation((userId, keyDefinition) =>
helper.getFromUser(userId, keyDefinition),
);
mockHelper.setToUser.mockImplementation((userId, keyDefinition, value) =>
helper.setToUser(userId, keyDefinition, value),
);
mockHelper.getAccounts.mockImplementation(() => helper.getAccounts());
mockHelper.getKnownUserIds.mockImplementation(() => helper.getKnownUserIds());
mockHelper.removeFromGlobal.mockImplementation((keyDefinition) =>
helper.removeFromGlobal(keyDefinition),
);
mockHelper.remove.mockImplementation((key) => helper.remove(key));
mockHelper.type = helper.type;
mockHelper.clientType = helper.clientType;
return mockHelper;
}
export type InitialDataHint<TUsers extends readonly string[]> = {
/**
* A string array of the users id who are authenticated
*/
authenticatedAccounts?: TUsers;
/**
* Global data
*/
global?: unknown;
/**
* Other top level data
*/
[key: string]: unknown;
} & {
/**
* A users data
*/
[userData in TUsers[number]]?: unknown;
};
type InjectedData = {
propertyName: string;
propertyValue: string;
originalPath: string[];
};
// This is a slight lie, technically the type is `Record<string | symbol, unknown>
// but for the purposes of things in the migrations this is enough.
function isStringRecord(object: unknown | undefined): object is Record<string, unknown> {
return object && typeof object === "object" && !Array.isArray(object);
}
function injectData(data: Record<string, unknown>, path: string[]): InjectedData[] {
if (!data) {
return [];
}
const injectedData: InjectedData[] = [];
// Traverse keys for other objects
const keys = Object.keys(data);
for (const key of keys) {
const currentProperty = data[key];
if (isStringRecord(currentProperty)) {
injectedData.push(...injectData(currentProperty, [...path, key]));
}
}
const propertyName = `__injectedProperty__${Utils.newGuid()}`;
const propertyValue = `__injectedValue__${Utils.newGuid()}`;
injectedData.push({
propertyName: propertyName,
propertyValue: propertyValue,
// Track the path it was originally injected in just for a better error
originalPath: path,
});
data[propertyName] = propertyValue;
return injectedData;
}
function expectInjectedData(
data: Record<string, unknown>,
injectedData: InjectedData[],
): [data: Record<string, unknown>, leftoverInjectedData: InjectedData[]] {
const keys = Object.keys(data);
for (const key of keys) {
const propertyValue = data[key];
// Injected data does not have to be found exactly where it was injected,
// just that it exists at all.
const injectedIndex = injectedData.findIndex(
(d) =>
d.propertyName === key &&
typeof propertyValue === "string" &&
propertyValue === d.propertyValue,
);
if (injectedIndex !== -1) {
// We found something we injected, remove it
injectedData.splice(injectedIndex, 1);
delete data[key];
continue;
}
if (isStringRecord(propertyValue)) {
const [updatedData, leftoverInjectedData] = expectInjectedData(propertyValue, injectedData);
data[key] = updatedData;
injectedData = leftoverInjectedData;
}
}
return [data, injectedData];
}
/**
* Runs the {@link Migrator.migrate} method of your migrator. You may pass in your test data and get back the data after the migration.
* This also injects extra properties at every level of your state and makes sure that it can be found.
* @param migrator Your migrator to use to do the migration
* @param initalData The data to start with
* @returns State after your migration has ran.
*/
export async function runMigrator<
TMigrator extends Migrator<number, number>,
const TUsers extends readonly string[],
>(
migrator: TMigrator,
initalData?: InitialDataHint<TUsers>,
direction: "migrate" | "rollback" = "migrate",
): Promise<Record<string, unknown>> {
const clonedData = JSON.parse(JSON.stringify(initalData ?? {}));
// Inject fake data at every level of the object
const allInjectedData = injectData(clonedData, []);
const fakeStorageService = new FakeStorageService(clonedData);
const helper = new MigrationHelper(
migrator.fromVersion,
fakeStorageService,
mock(),
"general",
ClientType.Web,
);
// Run their migrations
if (direction === "rollback") {
await migrator.rollback(helper);
} else {
await migrator.migrate(helper);
}
const [data, leftoverInjectedData] = expectInjectedData(
fakeStorageService.internalStore,
allInjectedData,
);
expect(leftoverInjectedData).toHaveLength(0);
return data;
}

View File

@@ -1,261 +1 @@
// eslint-disable-next-line import/no-restricted-paths -- Needed to provide client type to migrations
import { ClientType } from "../enums";
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
import { LogService } from "../platform/abstractions/log.service";
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
import { AbstractStorageService } from "../platform/abstractions/storage.service";
export type StateDefinitionLike = { name: string };
export type KeyDefinitionLike = {
stateDefinition: StateDefinitionLike;
key: string;
};
export type MigrationHelperType = "general" | "web-disk-local";
export class MigrationHelper {
constructor(
public currentVersion: number,
private storageService: AbstractStorageService,
public logService: LogService,
type: MigrationHelperType,
public clientType: ClientType,
) {
this.type = type;
}
/**
* On some clients, migrations are ran multiple times without direct action from the migration writer.
*
* All clients will run through migrations at least once, this run is referred to as `"general"`. If a migration is
* ran more than that single time, they will get a unique name if that the write can make conditional logic based on which
* migration run this is.
*
* @remarks The preferrable way of writing migrations is ALWAYS to be defensive and reflect on the data you are given back. This
* should really only be used when reflecting on the data given isn't enough.
*/
type: MigrationHelperType;
/**
* Gets a value from the storage service at the given key.
*
* This is a brute force method to just get a value from the storage service. If you can use {@link getFromGlobal} or {@link getFromUser}, you should.
* @param key location
* @returns the value at the location
*/
get<T>(key: string): Promise<T> {
return this.storageService.get<T>(key);
}
/**
* Sets a value in the storage service at the given key.
*
* This is a brute force method to just set a value in the storage service. If you can use {@link setToGlobal} or {@link setToUser}, you should.
* @param key location
* @param value the value to set
* @returns
*/
set<T>(key: string, value: T): Promise<void> {
this.logService.info(`Setting ${key}`);
return this.storageService.save(key, value);
}
/**
* Remove a value in the storage service at the given key.
*
* This is a brute force method to just remove a value in the storage service. If you can use {@link removeFromGlobal} or {@link removeFromUser}, you should.
* @param key location
* @returns void
*/
remove(key: string): Promise<void> {
this.logService.info(`Removing ${key}`);
return this.storageService.remove(key);
}
/**
* Gets a globally scoped value from a location derived through the key definition
*
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
* use {@link get} for those.
* @param keyDefinition unique key definition
* @returns value from store
*/
getFromGlobal<T>(keyDefinition: KeyDefinitionLike): Promise<T> {
return this.get<T>(this.getGlobalKey(keyDefinition));
}
/**
* Sets a globally scoped value to a location derived through the key definition
*
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
* use {@link set} for those.
* @param keyDefinition unique key definition
* @param value value to store
* @returns void
*/
setToGlobal<T>(keyDefinition: KeyDefinitionLike, value: T): Promise<void> {
return this.set(this.getGlobalKey(keyDefinition), value);
}
/**
* Remove a globally scoped location derived through the key definition
*
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
* use {@link remove} for those.
* @param keyDefinition unique key definition
* @returns void
*/
removeFromGlobal(keyDefinition: KeyDefinitionLike): Promise<void> {
return this.remove(this.getGlobalKey(keyDefinition));
}
/**
* Gets a user scoped value from a location derived through the user id and key definition
*
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
* use {@link get} for those.
* @param userId userId to use in the key
* @param keyDefinition unique key definition
* @returns value from store
*/
getFromUser<T>(userId: string, keyDefinition: KeyDefinitionLike): Promise<T> {
return this.get<T>(this.getUserKey(userId, keyDefinition));
}
/**
* Sets a user scoped value to a location derived through the user id and key definition
*
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
* use {@link set} for those.
* @param userId userId to use in the key
* @param keyDefinition unique key definition
* @param value value to store
* @returns void
*/
setToUser<T>(userId: string, keyDefinition: KeyDefinitionLike, value: T): Promise<void> {
return this.set(this.getUserKey(userId, keyDefinition), value);
}
/**
* Remove a user scoped location derived through the key definition
*
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
* use {@link remove} for those.
* @param keyDefinition unique key definition
* @returns void
*/
removeFromUser(userId: string, keyDefinition: KeyDefinitionLike): Promise<void> {
return this.remove(this.getUserKey(userId, keyDefinition));
}
info(message: string): void {
this.logService.info(message);
}
/**
* Helper method to read all Account objects stored by the State Service.
*
* This is useful from creating migrations off of this paradigm, but should not be used once a value is migrated to a state provider.
*
* @returns a list of all accounts that have been authenticated with state service, cast the expected type.
*/
async getAccounts<ExpectedAccountType>(): Promise<
{ userId: string; account: ExpectedAccountType }[]
> {
const userIds = await this.getKnownUserIds();
return Promise.all(
userIds.map(async (userId) => ({
userId,
account: await this.get<ExpectedAccountType>(userId),
})),
);
}
/**
* Helper method to read known users ids.
*/
async getKnownUserIds(): Promise<string[]> {
if (this.currentVersion < 60) {
return knownAccountUserIdsBuilderPre60(this.storageService);
} else {
return knownAccountUserIdsBuilder(this.storageService);
}
}
/**
* Builds a user storage key appropriate for the current version.
*
* @param userId userId to use in the key
* @param keyDefinition state and key to use in the key
* @returns
*/
private getUserKey(userId: string, keyDefinition: KeyDefinitionLike): string {
if (this.currentVersion < 9) {
return userKeyBuilderPre9();
} else {
return userKeyBuilder(userId, keyDefinition);
}
}
/**
* Builds a global storage key appropriate for the current version.
*
* @param keyDefinition state and key to use in the key
* @returns
*/
private getGlobalKey(keyDefinition: KeyDefinitionLike): string {
if (this.currentVersion < 9) {
return globalKeyBuilderPre9();
} else {
return globalKeyBuilder(keyDefinition);
}
}
}
/**
* When this is updated, rename this function to `userKeyBuilderXToY` where `X` is the version number it
* became relevant, and `Y` prior to the version it was updated.
*
* Be sure to update the map in `MigrationHelper` to point to the appropriate function for the current version.
* @param userId The userId of the user you want the key to be for.
* @param keyDefinition the key definition of which data the key should point to.
* @returns
*/
function userKeyBuilder(userId: string, keyDefinition: KeyDefinitionLike): string {
return `user_${userId}_${keyDefinition.stateDefinition.name}_${keyDefinition.key}`;
}
function userKeyBuilderPre9(): string {
throw Error("No key builder should be used for versions prior to 9.");
}
/**
* When this is updated, rename this function to `globalKeyBuilderXToY` where `X` is the version number
* it became relevant, and `Y` prior to the version it was updated.
*
* Be sure to update the map in `MigrationHelper` to point to the appropriate function for the current version.
* @param keyDefinition the key definition of which data the key should point to.
* @returns
*/
function globalKeyBuilder(keyDefinition: KeyDefinitionLike): string {
return `global_${keyDefinition.stateDefinition.name}_${keyDefinition.key}`;
}
function globalKeyBuilderPre9(): string {
throw Error("No key builder should be used for versions prior to 9.");
}
async function knownAccountUserIdsBuilderPre60(
storageService: AbstractStorageService,
): Promise<string[]> {
return (await storageService.get<string[]>("authenticatedAccounts")) ?? [];
}
async function knownAccountUserIdsBuilder(
storageService: AbstractStorageService,
): Promise<string[]> {
const accounts = await storageService.get<Record<string, unknown>>(
globalKeyBuilder({ stateDefinition: { name: "account" }, key: "accounts" }),
);
return Object.keys(accounts ?? {});
}
export { MigrationHelper } from "@bitwarden/state";

View File

@@ -1,161 +0,0 @@
import { MockProxy, any } from "jest-mock-extended";
import { MigrationHelper } from "../migration-helper";
import { mockMigrationHelper } from "../migration-helper.spec";
import { EverHadUserKeyMigrator } from "./10-move-ever-had-user-key-to-state-providers";
function exampleJSON() {
return {
global: {
otherStuff: "otherStuff1",
},
authenticatedAccounts: [
"c493ed01-4e08-4e88-abc7-332f380ca760",
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
],
"c493ed01-4e08-4e88-abc7-332f380ca760": {
profile: {
everHadUserKey: false,
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
},
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
profile: {
everHadUserKey: true,
otherStuff: "otherStuff4",
},
otherStuff: "otherStuff5",
},
};
}
function rollbackJSON() {
return {
"user_c493ed01-4e08-4e88-abc7-332f380ca760_crypto_everHadUserKey": false,
"user_23e61a5f-2ece-4f5e-b499-f0bc489482a9_crypto_everHadUserKey": true,
"user_fd005ea6-a16a-45ef-ba4a-a194269bfd73_crypto_everHadUserKey": false,
global: {
otherStuff: "otherStuff1",
},
authenticatedAccounts: [
"c493ed01-4e08-4e88-abc7-332f380ca760",
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
],
"c493ed01-4e08-4e88-abc7-332f380ca760": {
profile: {
everHadUserKey: false,
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
},
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
profile: {
everHadUserKey: true,
otherStuff: "otherStuff4",
},
otherStuff: "otherStuff5",
},
};
}
describe("EverHadUserKeyMigrator", () => {
let helper: MockProxy<MigrationHelper>;
let sut: EverHadUserKeyMigrator;
const keyDefinitionLike = {
key: "everHadUserKey",
stateDefinition: {
name: "crypto",
},
};
describe("migrate", () => {
beforeEach(() => {
helper = mockMigrationHelper(exampleJSON(), 9);
sut = new EverHadUserKeyMigrator(9, 10);
});
it("should remove everHadUserKey from all accounts", async () => {
await sut.migrate(helper);
expect(helper.set).toHaveBeenCalledWith("c493ed01-4e08-4e88-abc7-332f380ca760", {
profile: {
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
});
expect(helper.set).toHaveBeenCalledWith("23e61a5f-2ece-4f5e-b499-f0bc489482a9", {
profile: {
otherStuff: "otherStuff4",
},
otherStuff: "otherStuff5",
});
});
it("should set everHadUserKey provider value for each account", async () => {
await sut.migrate(helper);
expect(helper.setToUser).toHaveBeenCalledWith(
"c493ed01-4e08-4e88-abc7-332f380ca760",
keyDefinitionLike,
false,
);
expect(helper.setToUser).toHaveBeenCalledWith(
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
keyDefinitionLike,
true,
);
expect(helper.setToUser).toHaveBeenCalledWith(
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
keyDefinitionLike,
false,
);
});
});
describe("rollback", () => {
beforeEach(() => {
helper = mockMigrationHelper(rollbackJSON(), 10);
sut = new EverHadUserKeyMigrator(9, 10);
});
it.each([
"c493ed01-4e08-4e88-abc7-332f380ca760",
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
])("should null out new values", async (userId) => {
await sut.rollback(helper);
expect(helper.setToUser).toHaveBeenCalledWith(userId, keyDefinitionLike, null);
});
it("should add explicit value back to accounts", async () => {
await sut.rollback(helper);
expect(helper.set).toHaveBeenCalledWith("c493ed01-4e08-4e88-abc7-332f380ca760", {
profile: {
everHadUserKey: false,
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
});
expect(helper.set).toHaveBeenCalledWith("23e61a5f-2ece-4f5e-b499-f0bc489482a9", {
profile: {
everHadUserKey: true,
otherStuff: "otherStuff4",
},
otherStuff: "otherStuff5",
});
});
it("should not try to restore values to missing accounts", async () => {
await sut.rollback(helper);
expect(helper.set).not.toHaveBeenCalledWith("fd005ea6-a16a-45ef-ba4a-a194269bfd73", any());
});
});
});

View File

@@ -1,48 +0,0 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { KeyDefinitionLike, MigrationHelper } from "../migration-helper";
import { Migrator } from "../migrator";
type ExpectedAccountType = {
profile?: {
everHadUserKey?: boolean;
};
};
const USER_EVER_HAD_USER_KEY: KeyDefinitionLike = {
key: "everHadUserKey",
stateDefinition: {
name: "crypto",
},
};
export class EverHadUserKeyMigrator extends Migrator<9, 10> {
async migrate(helper: MigrationHelper): Promise<void> {
const accounts = await helper.getAccounts<ExpectedAccountType>();
async function migrateAccount(userId: string, account: ExpectedAccountType): Promise<void> {
const value = account?.profile?.everHadUserKey;
await helper.setToUser(userId, USER_EVER_HAD_USER_KEY, value ?? false);
if (value != null) {
delete account.profile.everHadUserKey;
}
await helper.set(userId, account);
}
await Promise.all([...accounts.map(({ userId, account }) => migrateAccount(userId, account))]);
}
async rollback(helper: MigrationHelper): Promise<void> {
const accounts = await helper.getAccounts<ExpectedAccountType>();
async function rollbackAccount(userId: string, account: ExpectedAccountType): Promise<void> {
const value = await helper.getFromUser(userId, USER_EVER_HAD_USER_KEY);
if (account) {
account.profile = Object.assign(account.profile ?? {}, {
everHadUserKey: value,
});
await helper.set(userId, account);
}
await helper.setToUser(userId, USER_EVER_HAD_USER_KEY, null);
}
await Promise.all([...accounts.map(({ userId, account }) => rollbackAccount(userId, account))]);
}
}

View File

@@ -1,163 +0,0 @@
import { MockProxy, any } from "jest-mock-extended";
import { MigrationHelper } from "../migration-helper";
import { mockMigrationHelper } from "../migration-helper.spec";
import { OrganizationKeyMigrator } from "./11-move-org-keys-to-state-providers";
function exampleJSON() {
return {
global: {
otherStuff: "otherStuff1",
},
authenticatedAccounts: ["user-1", "user-2", "user-3"],
"user-1": {
keys: {
organizationKeys: {
encrypted: {
"org-id-1": {
type: "organization",
key: "org-key-1",
},
"org-id-2": {
type: "provider",
key: "org-key-2",
providerId: "provider-id-2",
},
},
},
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
},
"user-2": {
keys: {
otherStuff: "otherStuff4",
},
otherStuff: "otherStuff5",
},
};
}
function rollbackJSON() {
return {
"user_user-1_crypto_organizationKeys": {
"org-id-1": {
type: "organization",
key: "org-key-1",
},
"org-id-2": {
type: "provider",
key: "org-key-2",
providerId: "provider-id-2",
},
},
"user_user-2_crypto_organizationKeys": null as any,
global: {
otherStuff: "otherStuff1",
},
authenticatedAccounts: ["user-1", "user-2", "user-3"],
"user-1": {
keys: {
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
},
"user-2": {
keys: {
otherStuff: "otherStuff4",
},
otherStuff: "otherStuff5",
},
};
}
describe("OrganizationKeysMigrator", () => {
let helper: MockProxy<MigrationHelper>;
let sut: OrganizationKeyMigrator;
const keyDefinitionLike = {
key: "organizationKeys",
stateDefinition: {
name: "crypto",
},
};
describe("migrate", () => {
beforeEach(() => {
helper = mockMigrationHelper(exampleJSON(), 10);
sut = new OrganizationKeyMigrator(10, 11);
});
it("should remove organizationKeys from all accounts", async () => {
await sut.migrate(helper);
expect(helper.set).toHaveBeenCalledTimes(1);
expect(helper.set).toHaveBeenCalledWith("user-1", {
keys: {
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
});
});
it("should set organizationKeys value for each account", async () => {
await sut.migrate(helper);
expect(helper.setToUser).toHaveBeenCalledTimes(1);
expect(helper.setToUser).toHaveBeenCalledWith("user-1", keyDefinitionLike, {
"org-id-1": {
type: "organization",
key: "org-key-1",
},
"org-id-2": {
type: "provider",
key: "org-key-2",
providerId: "provider-id-2",
},
});
});
});
describe("rollback", () => {
beforeEach(() => {
helper = mockMigrationHelper(rollbackJSON(), 11);
sut = new OrganizationKeyMigrator(10, 11);
});
it.each(["user-1", "user-2", "user-3"])("should null out new values %s", async (userId) => {
await sut.rollback(helper);
expect(helper.setToUser).toHaveBeenCalledWith(userId, keyDefinitionLike, null);
});
it("should add explicit value back to accounts", async () => {
await sut.rollback(helper);
expect(helper.set).toHaveBeenCalledTimes(1);
expect(helper.set).toHaveBeenCalledWith("user-1", {
keys: {
organizationKeys: {
encrypted: {
"org-id-1": {
type: "organization",
key: "org-key-1",
},
"org-id-2": {
type: "provider",
key: "org-key-2",
providerId: "provider-id-2",
},
},
},
otherStuff: "overStuff2",
},
otherStuff: "otherStuff3",
});
});
it("should not try to restore values to missing accounts", async () => {
await sut.rollback(helper);
expect(helper.set).not.toHaveBeenCalledWith("user-3", any());
});
});
});

View File

@@ -1,61 +0,0 @@
// FIXME: Update this file to be type safe and remove this and next line
// @ts-strict-ignore
import { KeyDefinitionLike, MigrationHelper } from "../migration-helper";
import { Migrator } from "../migrator";
type OrgKeyDataType = {
type: "organization" | "provider";
key: string;
providerId?: string;
};
type ExpectedAccountType = {
keys?: {
organizationKeys?: {
encrypted?: Record<string, OrgKeyDataType>;
};
};
};
const USER_ENCRYPTED_ORGANIZATION_KEYS: KeyDefinitionLike = {
key: "organizationKeys",
stateDefinition: {
name: "crypto",
},
};
export class OrganizationKeyMigrator extends Migrator<10, 11> {
async migrate(helper: MigrationHelper): Promise<void> {
const accounts = await helper.getAccounts<ExpectedAccountType>();
async function migrateAccount(userId: string, account: ExpectedAccountType): Promise<void> {
const value = account?.keys?.organizationKeys?.encrypted;
if (value != null) {
await helper.setToUser(userId, USER_ENCRYPTED_ORGANIZATION_KEYS, value);
delete account.keys.organizationKeys;
await helper.set(userId, account);
}
}
await Promise.all([...accounts.map(({ userId, account }) => migrateAccount(userId, account))]);
}
async rollback(helper: MigrationHelper): Promise<void> {
const accounts = await helper.getAccounts<ExpectedAccountType>();
async function rollbackAccount(userId: string, account: ExpectedAccountType): Promise<void> {
const value = await helper.getFromUser<Record<string, OrgKeyDataType>>(
userId,
USER_ENCRYPTED_ORGANIZATION_KEYS,
);
if (account && value) {
account.keys = Object.assign(account.keys ?? {}, {
organizationKeys: {
encrypted: value,
},
});
await helper.set(userId, account);
}
await helper.setToUser(userId, USER_ENCRYPTED_ORGANIZATION_KEYS, null);
}
await Promise.all([...accounts.map(({ userId, account }) => rollbackAccount(userId, account))]);
}
}

View File

@@ -1,157 +0,0 @@
import { runMigrator } from "../migration-helper.spec";
import { MoveEnvironmentStateToProviders } from "./12-move-environment-state-to-providers";
describe("MoveEnvironmentStateToProviders", () => {
const migrator = new MoveEnvironmentStateToProviders(11, 12);
it("can migrate all data", async () => {
const output = await runMigrator(migrator, {
authenticatedAccounts: ["user1", "user2"] as const,
global: {
region: "US",
environmentUrls: {
base: "example.com",
},
extra: "data",
},
user1: {
extra: "data",
settings: {
extra: "data",
region: "US",
environmentUrls: {
base: "example.com",
},
},
},
user2: {
extra: "data",
settings: {
region: "EU",
environmentUrls: {
base: "other.example.com",
},
extra: "data",
},
},
extra: "data",
});
expect(output).toEqual({
authenticatedAccounts: ["user1", "user2"],
global: {
extra: "data",
},
global_environment_region: "US",
global_environment_urls: {
base: "example.com",
},
user1: {
extra: "data",
settings: {
extra: "data",
},
},
user2: {
extra: "data",
settings: {
extra: "data",
},
},
extra: "data",
user_user1_environment_region: "US",
user_user2_environment_region: "EU",
user_user1_environment_urls: {
base: "example.com",
},
user_user2_environment_urls: {
base: "other.example.com",
},
});
});
it("handles missing parts", async () => {
const output = await runMigrator(migrator, {
authenticatedAccounts: ["user1", "user2"],
global: {
extra: "data",
},
user1: {
extra: "data",
settings: {
extra: "data",
},
},
user2: null,
});
expect(output).toEqual({
authenticatedAccounts: ["user1", "user2"],
global: {
extra: "data",
},
user1: {
extra: "data",
settings: {
extra: "data",
},
},
user2: null,
});
});
it("can migrate only global data", async () => {
const output = await runMigrator(migrator, {
authenticatedAccounts: [] as const,
global: {
region: "Self-Hosted",
},
});
expect(output).toEqual({
authenticatedAccounts: [],
global_environment_region: "Self-Hosted",
global: {},
});
});
it("can migrate only user state", async () => {
const output = await runMigrator(migrator, {
authenticatedAccounts: ["user1"] as const,
global: null,
user1: {
settings: {
region: "Self-Hosted",
environmentUrls: {
base: "some-base-url",
api: "some-api-url",
identity: "some-identity-url",
icons: "some-icons-url",
notifications: "some-notifications-url",
events: "some-events-url",
webVault: "some-webVault-url",
keyConnector: "some-keyConnector-url",
},
},
},
});
expect(output).toEqual({
authenticatedAccounts: ["user1"] as const,
global: null,
user1: { settings: {} },
user_user1_environment_region: "Self-Hosted",
user_user1_environment_urls: {
base: "some-base-url",
api: "some-api-url",
identity: "some-identity-url",
icons: "some-icons-url",
notifications: "some-notifications-url",
events: "some-events-url",
webVault: "some-webVault-url",
keyConnector: "some-keyConnector-url",
},
});
});
});

Some files were not shown because too many files have changed in this diff Show More