mirror of
https://github.com/bitwarden/browser
synced 2026-02-07 12:13:45 +00:00
Merge remote-tracking branch 'origin' into auth/pm-19877/notification-processing
This commit is contained in:
@@ -89,8 +89,7 @@ export class DefaultPolicyService implements PolicyService {
|
||||
const policies$ = policies ? of(policies) : this.policies$(userId);
|
||||
return policies$.pipe(
|
||||
map((obsPolicies) => {
|
||||
// TODO: replace with this.combinePoliciesIntoMasterPasswordPolicyOptions(obsPolicies)) once
|
||||
// FeatureFlag.PM16117_ChangeExistingPasswordRefactor is removed.
|
||||
// TODO ([PM-23777]): replace with this.combinePoliciesIntoMasterPasswordPolicyOptions(obsPolicies))
|
||||
let enforcedOptions: MasterPasswordPolicyOptions | undefined = undefined;
|
||||
const filteredPolicies =
|
||||
obsPolicies.filter((p) => p.type === PolicyType.MasterPassword) ?? [];
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import { map, Observable } from "rxjs";
|
||||
|
||||
import { UserId } from "@bitwarden/user-core";
|
||||
|
||||
import { ActiveUserAccessor } from "../../platform/state";
|
||||
import { AccountService } from "../abstractions/account.service";
|
||||
|
||||
/**
|
||||
* Implementation for Platform so they can avoid a direct dependency on AccountService. Not for general consumption.
|
||||
*/
|
||||
export class DefaultActiveUserAccessor implements ActiveUserAccessor {
|
||||
constructor(private readonly accountService: AccountService) {
|
||||
this.activeUserId$ = this.accountService.activeAccount$.pipe(
|
||||
map((a) => (a != null ? a.id : null)),
|
||||
);
|
||||
}
|
||||
|
||||
activeUserId$: Observable<UserId | null>;
|
||||
}
|
||||
@@ -4,8 +4,6 @@ import { of } from "rxjs";
|
||||
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import {
|
||||
PinLockType,
|
||||
PinServiceAbstraction,
|
||||
UserDecryptionOptions,
|
||||
UserDecryptionOptionsServiceAbstraction,
|
||||
} from "@bitwarden/auth/common";
|
||||
@@ -21,6 +19,8 @@ import {
|
||||
|
||||
import { FakeAccountService, mockAccountServiceWith } from "../../../../spec";
|
||||
import { InternalMasterPasswordServiceAbstraction } from "../../../key-management/master-password/abstractions/master-password.service.abstraction";
|
||||
import { PinServiceAbstraction } from "../../../key-management/pin/pin.service.abstraction";
|
||||
import { PinLockType } from "../../../key-management/pin/pin.service.implementation";
|
||||
import { VaultTimeoutSettingsService } from "../../../key-management/vault-timeout";
|
||||
import { I18nService } from "../../../platform/abstractions/i18n.service";
|
||||
import { HashPurpose } from "../../../platform/enums";
|
||||
|
||||
@@ -14,10 +14,8 @@ import {
|
||||
KeyService,
|
||||
} from "@bitwarden/key-management";
|
||||
|
||||
// FIXME: remove `src` and fix import
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { PinServiceAbstraction } from "../../../../../auth/src/common/abstractions/pin.service.abstraction";
|
||||
import { InternalMasterPasswordServiceAbstraction } from "../../../key-management/master-password/abstractions/master-password.service.abstraction";
|
||||
import { PinServiceAbstraction } from "../../../key-management/pin/pin.service.abstraction";
|
||||
import { I18nService } from "../../../platform/abstractions/i18n.service";
|
||||
import { HashPurpose } from "../../../platform/enums";
|
||||
import { UserId } from "../../../types/guid";
|
||||
|
||||
@@ -3,7 +3,6 @@ import { PreviewIndividualInvoiceRequest } from "../models/request/preview-indiv
|
||||
import { PreviewOrganizationInvoiceRequest } from "../models/request/preview-organization-invoice.request";
|
||||
import { PreviewTaxAmountForOrganizationTrialRequest } from "../models/request/tax";
|
||||
import { PreviewInvoiceResponse } from "../models/response/preview-invoice.response";
|
||||
import { PreviewTaxAmountResponse } from "../models/response/tax";
|
||||
|
||||
export abstract class TaxServiceAbstraction {
|
||||
abstract getCountries(): CountryListItem[];
|
||||
@@ -20,5 +19,5 @@ export abstract class TaxServiceAbstraction {
|
||||
|
||||
abstract previewTaxAmountForOrganizationTrial: (
|
||||
request: PreviewTaxAmountForOrganizationTrialRequest,
|
||||
) => Promise<PreviewTaxAmountResponse>;
|
||||
) => Promise<number>;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { PreviewTaxAmountForOrganizationTrialRequest } from "@bitwarden/common/billing/models/request/tax";
|
||||
import { PreviewTaxAmountResponse } from "@bitwarden/common/billing/models/response/tax";
|
||||
|
||||
import { ApiService } from "../../abstractions/api.service";
|
||||
import { TaxServiceAbstraction } from "../abstractions/tax.service.abstraction";
|
||||
@@ -306,13 +305,14 @@ export class TaxService implements TaxServiceAbstraction {
|
||||
|
||||
async previewTaxAmountForOrganizationTrial(
|
||||
request: PreviewTaxAmountForOrganizationTrialRequest,
|
||||
): Promise<PreviewTaxAmountResponse> {
|
||||
return await this.apiService.send(
|
||||
): Promise<number> {
|
||||
const response = await this.apiService.send(
|
||||
"POST",
|
||||
"/tax/preview-amount/organization-trial",
|
||||
request,
|
||||
true,
|
||||
true,
|
||||
);
|
||||
return response as number;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1 @@
|
||||
// FIXME: update to use a const object instead of a typescript enum
|
||||
// eslint-disable-next-line @bitwarden/platform/no-enums
|
||||
export enum ClientType {
|
||||
Web = "web",
|
||||
Browser = "browser",
|
||||
Desktop = "desktop",
|
||||
// Mobile = "mobile",
|
||||
Cli = "cli",
|
||||
// DirectoryConnector = "connector",
|
||||
}
|
||||
export { ClientType } from "@bitwarden/client-type";
|
||||
|
||||
@@ -14,8 +14,6 @@ export enum FeatureFlag {
|
||||
CreateDefaultLocation = "pm-19467-create-default-location",
|
||||
|
||||
/* Auth */
|
||||
PM16117_SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor",
|
||||
PM16117_ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor",
|
||||
PM14938_BrowserExtensionLoginApproval = "pm-14938-browser-extension-login-approvals",
|
||||
|
||||
/* Autofill */
|
||||
@@ -34,17 +32,19 @@ export enum FeatureFlag {
|
||||
UseOrganizationWarningsService = "use-organization-warnings-service",
|
||||
AllowTrialLengthZero = "pm-20322-allow-trial-length-0",
|
||||
PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout",
|
||||
PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover",
|
||||
|
||||
/* Key Management */
|
||||
PrivateKeyRegeneration = "pm-12241-private-key-regeneration",
|
||||
PM4154_BulkEncryptionService = "PM-4154-bulk-encryption-service",
|
||||
UseSDKForDecryption = "use-sdk-for-decryption",
|
||||
PM17987_BlockType0 = "pm-17987-block-type-0",
|
||||
EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation",
|
||||
ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings",
|
||||
|
||||
/* Tools */
|
||||
DesktopSendUIRefresh = "desktop-send-ui-refresh",
|
||||
UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators",
|
||||
|
||||
/* DIRT */
|
||||
EventBasedOrganizationIntegrations = "event-based-organization-integrations",
|
||||
|
||||
/* Vault */
|
||||
PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge",
|
||||
@@ -88,6 +88,10 @@ export const DefaultFeatureFlagValue = {
|
||||
|
||||
/* Tools */
|
||||
[FeatureFlag.DesktopSendUIRefresh]: FALSE,
|
||||
[FeatureFlag.UseSdkPasswordGenerators]: FALSE,
|
||||
|
||||
/* DIRT */
|
||||
[FeatureFlag.EventBasedOrganizationIntegrations]: FALSE,
|
||||
|
||||
/* Vault */
|
||||
[FeatureFlag.PM8851_BrowserOnboardingNudge]: FALSE,
|
||||
@@ -101,8 +105,6 @@ export const DefaultFeatureFlagValue = {
|
||||
[FeatureFlag.PM22136_SdkCipherEncryption]: FALSE,
|
||||
|
||||
/* Auth */
|
||||
[FeatureFlag.PM16117_SetInitialPasswordRefactor]: FALSE,
|
||||
[FeatureFlag.PM16117_ChangeExistingPasswordRefactor]: FALSE,
|
||||
[FeatureFlag.PM14938_BrowserExtensionLoginApproval]: FALSE,
|
||||
|
||||
/* Billing */
|
||||
@@ -113,12 +115,10 @@ export const DefaultFeatureFlagValue = {
|
||||
[FeatureFlag.UseOrganizationWarningsService]: FALSE,
|
||||
[FeatureFlag.AllowTrialLengthZero]: FALSE,
|
||||
[FeatureFlag.PM21881_ManagePaymentDetailsOutsideCheckout]: FALSE,
|
||||
[FeatureFlag.PM21821_ProviderPortalTakeover]: FALSE,
|
||||
|
||||
/* Key Management */
|
||||
[FeatureFlag.PrivateKeyRegeneration]: FALSE,
|
||||
[FeatureFlag.PM4154_BulkEncryptionService]: FALSE,
|
||||
[FeatureFlag.UseSDKForDecryption]: FALSE,
|
||||
[FeatureFlag.PM17987_BlockType0]: FALSE,
|
||||
[FeatureFlag.EnrollAeadOnKeyRotation]: FALSE,
|
||||
[FeatureFlag.ForceUpdateKDFSettings]: FALSE,
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
|
||||
import { InitializerMetadata } from "../../../platform/interfaces/initializer-metadata.interface";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
|
||||
export abstract class BulkEncryptService {
|
||||
abstract decryptItems<T extends InitializerMetadata>(
|
||||
items: Decryptable<T>[],
|
||||
key: SymmetricCryptoKey,
|
||||
): Promise<T[]>;
|
||||
|
||||
abstract onServerConfigChange(newConfig: ServerConfig): void;
|
||||
}
|
||||
@@ -6,12 +6,20 @@ import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-cr
|
||||
import { CsprngArray } from "../../../types/csprng";
|
||||
|
||||
export abstract class CryptoFunctionService {
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract pbkdf2(
|
||||
password: string | Uint8Array,
|
||||
salt: string | Uint8Array,
|
||||
algorithm: "sha256" | "sha512",
|
||||
iterations: number,
|
||||
): Promise<Uint8Array>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract hkdf(
|
||||
ikm: Uint8Array,
|
||||
salt: string | Uint8Array,
|
||||
@@ -19,51 +27,76 @@ export abstract class CryptoFunctionService {
|
||||
outputByteSize: number,
|
||||
algorithm: "sha256" | "sha512",
|
||||
): Promise<Uint8Array>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract hkdfExpand(
|
||||
prk: Uint8Array,
|
||||
info: string | Uint8Array,
|
||||
outputByteSize: number,
|
||||
algorithm: "sha256" | "sha512",
|
||||
): Promise<Uint8Array>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract hash(
|
||||
value: string | Uint8Array,
|
||||
algorithm: "sha1" | "sha256" | "sha512" | "md5",
|
||||
): Promise<Uint8Array>;
|
||||
abstract hmac(
|
||||
value: Uint8Array,
|
||||
key: Uint8Array,
|
||||
algorithm: "sha1" | "sha256" | "sha512",
|
||||
): Promise<Uint8Array>;
|
||||
abstract compare(a: Uint8Array, b: Uint8Array): Promise<boolean>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract hmacFast(
|
||||
value: Uint8Array | string,
|
||||
key: Uint8Array | string,
|
||||
algorithm: "sha1" | "sha256" | "sha512",
|
||||
): Promise<Uint8Array | string>;
|
||||
abstract compareFast(a: Uint8Array | string, b: Uint8Array | string): Promise<boolean>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract aesDecryptFastParameters(
|
||||
data: string,
|
||||
iv: string,
|
||||
mac: string,
|
||||
key: SymmetricCryptoKey,
|
||||
): CbcDecryptParameters<Uint8Array | string>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract aesDecryptFast({
|
||||
mode,
|
||||
parameters,
|
||||
}:
|
||||
| { mode: "cbc"; parameters: CbcDecryptParameters<Uint8Array | string> }
|
||||
| { mode: "ecb"; parameters: EcbDecryptParameters<Uint8Array | string> }): Promise<string>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Only used by DDG integration until DDG uses PKCS#7 padding, and by lastpass importer.
|
||||
*/
|
||||
abstract aesDecrypt(
|
||||
data: Uint8Array,
|
||||
iv: Uint8Array,
|
||||
key: Uint8Array,
|
||||
mode: "cbc" | "ecb",
|
||||
): Promise<Uint8Array>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract rsaEncrypt(
|
||||
data: Uint8Array,
|
||||
publicKey: Uint8Array,
|
||||
algorithm: "sha1" | "sha256",
|
||||
): Promise<Uint8Array>;
|
||||
/**
|
||||
* @deprecated HAZMAT WARNING: DO NOT USE THIS FOR NEW CODE. Implement low-level crypto operations
|
||||
* in the SDK instead. Further, you should probably never find yourself using this low-level crypto function.
|
||||
*/
|
||||
abstract rsaDecrypt(
|
||||
data: Uint8Array,
|
||||
privateKey: Uint8Array,
|
||||
@@ -77,7 +110,6 @@ export abstract class CryptoFunctionService {
|
||||
abstract aesGenerateKey(bitLength: 128 | 192 | 256 | 512): Promise<CsprngArray>;
|
||||
/**
|
||||
* Generates a random array of bytes of the given length. Uses a cryptographically secure random number generator.
|
||||
*
|
||||
* Do not use this for generating encryption keys. Use aesGenerateKey or rsaGenerateKeyPair instead.
|
||||
*/
|
||||
abstract randomBytes(length: number): Promise<CsprngArray>;
|
||||
|
||||
@@ -1,51 +1,8 @@
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
|
||||
import { Encrypted } from "../../../platform/interfaces/encrypted";
|
||||
import { InitializerMetadata } from "../../../platform/interfaces/initializer-metadata.interface";
|
||||
import { EncArrayBuffer } from "../../../platform/models/domain/enc-array-buffer";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
import { EncString } from "../models/enc-string";
|
||||
|
||||
export abstract class EncryptService {
|
||||
/**
|
||||
* @deprecated
|
||||
* Decrypts an EncString to a string
|
||||
* @param encString - The EncString to decrypt
|
||||
* @param key - The key to decrypt the EncString with
|
||||
* @param decryptTrace - A string to identify the context of the object being decrypted. This can include: field name, encryption type, cipher id, key type, but should not include
|
||||
* sensitive information like encryption keys or data. This is used for logging when decryption errors occur in order to identify what failed to decrypt
|
||||
* @returns The decrypted string
|
||||
*/
|
||||
abstract decryptToUtf8(
|
||||
encString: EncString,
|
||||
key: SymmetricCryptoKey,
|
||||
decryptTrace?: string,
|
||||
): Promise<string>;
|
||||
/**
|
||||
* @deprecated
|
||||
* Decrypts an Encrypted object to a Uint8Array
|
||||
* @param encThing - The Encrypted object to decrypt
|
||||
* @param key - The key to decrypt the Encrypted object with
|
||||
* @param decryptTrace - A string to identify the context of the object being decrypted. This can include: field name, encryption type, cipher id, key type, but should not include
|
||||
* sensitive information like encryption keys or data. This is used for logging when decryption errors occur in order to identify what failed to decrypt
|
||||
* @returns The decrypted Uint8Array
|
||||
*/
|
||||
abstract decryptToBytes(
|
||||
encThing: Encrypted,
|
||||
key: SymmetricCryptoKey,
|
||||
decryptTrace?: string,
|
||||
): Promise<Uint8Array | null>;
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by BulkEncryptService, remove once the feature is tested and the featureflag PM-4154-multi-worker-encryption-service is removed
|
||||
* @param items The items to decrypt
|
||||
* @param key The key to decrypt the items with
|
||||
*/
|
||||
abstract decryptItems<T extends InitializerMetadata>(
|
||||
items: Decryptable<T>[],
|
||||
key: SymmetricCryptoKey,
|
||||
): Promise<T[]>;
|
||||
|
||||
/**
|
||||
* Encrypts a string to an EncString
|
||||
* @param plainValue - The value to encrypt
|
||||
@@ -188,12 +145,6 @@ export abstract class EncryptService {
|
||||
decapsulationKey: Uint8Array,
|
||||
): Promise<SymmetricCryptoKey>;
|
||||
|
||||
/**
|
||||
* @deprecated Use @see {@link encapsulateKeyUnsigned} instead
|
||||
* @param data - The data to encrypt
|
||||
* @param publicKey - The public key to encrypt with
|
||||
*/
|
||||
abstract rsaEncrypt(data: Uint8Array, publicKey: Uint8Array): Promise<EncString>;
|
||||
/**
|
||||
* @deprecated Use @see {@link decapsulateKeyUnsigned} instead
|
||||
* @param data - The ciphertext to decrypt
|
||||
@@ -210,6 +161,4 @@ export abstract class EncryptService {
|
||||
value: string | Uint8Array,
|
||||
algorithm: "sha1" | "sha256" | "sha512",
|
||||
): Promise<string>;
|
||||
|
||||
abstract onServerConfigChange(newConfig: ServerConfig): void;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { mock, MockProxy } from "jest-mock-extended";
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KeyService } from "@bitwarden/key-management";
|
||||
|
||||
import { makeEncString, makeStaticByteArray } from "../../../../spec";
|
||||
import { makeStaticByteArray } from "../../../../spec";
|
||||
import { EncryptionType } from "../../../platform/enums";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
import { ContainerService } from "../../../platform/services/container.service";
|
||||
@@ -83,7 +83,7 @@ describe("EncString", () => {
|
||||
|
||||
const keyService = mock<KeyService>();
|
||||
keyService.hasUserKey.mockResolvedValue(true);
|
||||
keyService.getUserKeyWithLegacySupport.mockResolvedValue(
|
||||
keyService.getUserKey.mockResolvedValue(
|
||||
new SymmetricCryptoKey(makeStaticByteArray(32)) as UserKey,
|
||||
);
|
||||
|
||||
@@ -114,67 +114,6 @@ describe("EncString", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("decryptWithKey", () => {
|
||||
const encString = new EncString(EncryptionType.Rsa2048_OaepSha256_B64, "data");
|
||||
|
||||
const keyService = mock<KeyService>();
|
||||
const encryptService = mock<EncryptService>();
|
||||
encryptService.decryptString
|
||||
.calledWith(encString, expect.anything())
|
||||
.mockResolvedValue("decrypted");
|
||||
|
||||
function setupEncryption() {
|
||||
encryptService.encryptString.mockImplementation(async (data, key) => {
|
||||
return makeEncString(data);
|
||||
});
|
||||
encryptService.decryptString.mockImplementation(async (encString, key) => {
|
||||
return encString.data;
|
||||
});
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
(window as any).bitwardenContainerService = new ContainerService(keyService, encryptService);
|
||||
});
|
||||
|
||||
it("decrypts using the provided key and encryptService", async () => {
|
||||
setupEncryption();
|
||||
|
||||
const key = new SymmetricCryptoKey(makeStaticByteArray(32));
|
||||
await encString.decryptWithKey(key, encryptService);
|
||||
|
||||
expect(encryptService.decryptString).toHaveBeenCalledWith(encString, key);
|
||||
});
|
||||
|
||||
it("fails to decrypt when key is null", async () => {
|
||||
const decrypted = await encString.decryptWithKey(null, encryptService);
|
||||
|
||||
expect(decrypted).toBe("[error: cannot decrypt]");
|
||||
expect(encString.decryptedValue).toBe("[error: cannot decrypt]");
|
||||
});
|
||||
|
||||
it("fails to decrypt when encryptService is null", async () => {
|
||||
const decrypted = await encString.decryptWithKey(
|
||||
new SymmetricCryptoKey(makeStaticByteArray(32)),
|
||||
null,
|
||||
);
|
||||
|
||||
expect(decrypted).toBe("[error: cannot decrypt]");
|
||||
expect(encString.decryptedValue).toBe("[error: cannot decrypt]");
|
||||
});
|
||||
|
||||
it("fails to decrypt when encryptService throws", async () => {
|
||||
encryptService.decryptString.mockRejectedValue("error");
|
||||
|
||||
const decrypted = await encString.decryptWithKey(
|
||||
new SymmetricCryptoKey(makeStaticByteArray(32)),
|
||||
encryptService,
|
||||
);
|
||||
|
||||
expect(decrypted).toBe("[error: cannot decrypt]");
|
||||
expect(encString.decryptedValue).toBe("[error: cannot decrypt]");
|
||||
});
|
||||
});
|
||||
|
||||
describe("AesCbc256_B64", () => {
|
||||
it("constructor", () => {
|
||||
const encString = new EncString(EncryptionType.AesCbc256_B64, "data", "iv");
|
||||
@@ -343,7 +282,7 @@ describe("EncString", () => {
|
||||
|
||||
await encString.decrypt(null, key);
|
||||
|
||||
expect(keyService.getUserKeyWithLegacySupport).not.toHaveBeenCalled();
|
||||
expect(keyService.getUserKey).not.toHaveBeenCalled();
|
||||
expect(encryptService.decryptString).toHaveBeenCalledWith(encString, key);
|
||||
});
|
||||
|
||||
@@ -361,11 +300,11 @@ describe("EncString", () => {
|
||||
it("gets the user's decryption key if required", async () => {
|
||||
const userKey = mock<UserKey>();
|
||||
|
||||
keyService.getUserKeyWithLegacySupport.mockResolvedValue(userKey);
|
||||
keyService.getUserKey.mockResolvedValue(userKey);
|
||||
|
||||
await encString.decrypt(null, null);
|
||||
|
||||
expect(keyService.getUserKeyWithLegacySupport).toHaveBeenCalledWith();
|
||||
expect(keyService.getUserKey).toHaveBeenCalledWith();
|
||||
expect(encryptService.decryptString).toHaveBeenCalledWith(encString, userKey);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Jsonify, Opaque } from "type-fest";
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { EncString as SdkEncString } from "@bitwarden/sdk-internal";
|
||||
|
||||
import { EncryptionType, EXPECTED_NUM_PARTS_BY_ENCRYPTION_TYPE } from "../../../platform/enums";
|
||||
import { Encrypted } from "../../../platform/interfaces/encrypted";
|
||||
import { Utils } from "../../../platform/misc/utils";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
import { EncryptService } from "../abstractions/encrypt.service";
|
||||
|
||||
export const DECRYPT_ERROR = "[error: cannot decrypt]";
|
||||
|
||||
export class EncString implements Encrypted {
|
||||
encryptedString?: EncryptedString;
|
||||
encryptedString?: SdkEncString;
|
||||
encryptionType?: EncryptionType;
|
||||
decryptedValue?: string;
|
||||
data?: string;
|
||||
@@ -43,7 +44,11 @@ export class EncString implements Encrypted {
|
||||
return this.data == null ? null : Utils.fromB64ToArray(this.data);
|
||||
}
|
||||
|
||||
toJSON() {
|
||||
toSdk(): SdkEncString {
|
||||
return this.encryptedString;
|
||||
}
|
||||
|
||||
toJSON(): string {
|
||||
return this.encryptedString as string;
|
||||
}
|
||||
|
||||
@@ -57,14 +62,14 @@ export class EncString implements Encrypted {
|
||||
|
||||
private initFromData(encType: EncryptionType, data: string, iv: string, mac: string) {
|
||||
if (iv != null) {
|
||||
this.encryptedString = (encType + "." + iv + "|" + data) as EncryptedString;
|
||||
this.encryptedString = (encType + "." + iv + "|" + data) as SdkEncString;
|
||||
} else {
|
||||
this.encryptedString = (encType + "." + data) as EncryptedString;
|
||||
this.encryptedString = (encType + "." + data) as SdkEncString;
|
||||
}
|
||||
|
||||
// mac
|
||||
if (mac != null) {
|
||||
this.encryptedString = (this.encryptedString + "|" + mac) as EncryptedString;
|
||||
this.encryptedString = (this.encryptedString + "|" + mac) as SdkEncString;
|
||||
}
|
||||
|
||||
this.encryptionType = encType;
|
||||
@@ -74,7 +79,7 @@ export class EncString implements Encrypted {
|
||||
}
|
||||
|
||||
private initFromEncryptedString(encryptedString: string) {
|
||||
this.encryptedString = encryptedString as EncryptedString;
|
||||
this.encryptedString = encryptedString as SdkEncString;
|
||||
if (!this.encryptedString) {
|
||||
return;
|
||||
}
|
||||
@@ -184,31 +189,14 @@ export class EncString implements Encrypted {
|
||||
return this.decryptedValue;
|
||||
}
|
||||
|
||||
async decryptWithKey(
|
||||
key: SymmetricCryptoKey,
|
||||
encryptService: EncryptService,
|
||||
decryptTrace: string = "domain-withkey",
|
||||
): Promise<string> {
|
||||
try {
|
||||
if (key == null) {
|
||||
throw new Error("No key to decrypt EncString");
|
||||
}
|
||||
|
||||
this.decryptedValue = await encryptService.decryptString(this, key);
|
||||
// FIXME: Remove when updating file. Eslint update
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
} catch (e) {
|
||||
this.decryptedValue = DECRYPT_ERROR;
|
||||
}
|
||||
|
||||
return this.decryptedValue;
|
||||
}
|
||||
private async getKeyForDecryption(orgId: string) {
|
||||
const keyService = Utils.getContainerService().getKeyService();
|
||||
return orgId != null
|
||||
? await keyService.getOrgKey(orgId)
|
||||
: await keyService.getUserKeyWithLegacySupport();
|
||||
return orgId != null ? await keyService.getOrgKey(orgId) : await keyService.getUserKey();
|
||||
}
|
||||
}
|
||||
|
||||
export type EncryptedString = Opaque<string, "EncString">;
|
||||
/**
|
||||
* Temporary type mapping until consumers are moved over.
|
||||
* @deprecated - Use SdkEncString directly
|
||||
*/
|
||||
export type EncryptedString = SdkEncString;
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
import { BulkEncryptService } from "@bitwarden/common/key-management/crypto/abstractions/bulk-encrypt.service";
|
||||
import { CryptoFunctionService } from "@bitwarden/common/key-management/crypto/abstractions/crypto-function.service";
|
||||
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
|
||||
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
|
||||
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
|
||||
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
|
||||
|
||||
import { DefaultFeatureFlagValue, FeatureFlag } from "../../../enums/feature-flag.enum";
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
|
||||
/**
|
||||
* @deprecated Will be deleted in an immediate subsequent PR
|
||||
*/
|
||||
export class BulkEncryptServiceImplementation implements BulkEncryptService {
|
||||
protected useSDKForDecryption: boolean = DefaultFeatureFlagValue[FeatureFlag.UseSDKForDecryption];
|
||||
|
||||
constructor(
|
||||
protected cryptoFunctionService: CryptoFunctionService,
|
||||
protected logService: LogService,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Decrypts items using a web worker if the environment supports it.
|
||||
* Will fall back to the main thread if the window object is not available.
|
||||
*/
|
||||
async decryptItems<T extends InitializerMetadata>(
|
||||
items: Decryptable<T>[],
|
||||
key: SymmetricCryptoKey,
|
||||
): Promise<T[]> {
|
||||
if (key == null) {
|
||||
throw new Error("No encryption key provided.");
|
||||
}
|
||||
|
||||
if (items == null || items.length < 1) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const results = [];
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
results.push(await items[i].decrypt(key));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
onServerConfigChange(newConfig: ServerConfig): void {}
|
||||
}
|
||||
@@ -5,15 +5,11 @@ import { EncString } from "@bitwarden/common/key-management/crypto/models/enc-st
|
||||
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
|
||||
import { SdkLoadService } from "@bitwarden/common/platform/abstractions/sdk/sdk-load.service";
|
||||
import { EncryptionType } from "@bitwarden/common/platform/enums";
|
||||
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
|
||||
import { Encrypted } from "@bitwarden/common/platform/interfaces/encrypted";
|
||||
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
|
||||
import { Utils } from "@bitwarden/common/platform/misc/utils";
|
||||
import { EncArrayBuffer } from "@bitwarden/common/platform/models/domain/enc-array-buffer";
|
||||
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
|
||||
import { PureCrypto } from "@bitwarden/sdk-internal";
|
||||
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
import { EncryptService } from "../abstractions/encrypt.service";
|
||||
|
||||
export class EncryptServiceImplementation implements EncryptService {
|
||||
@@ -23,7 +19,6 @@ export class EncryptServiceImplementation implements EncryptService {
|
||||
protected logMacFailures: boolean,
|
||||
) {}
|
||||
|
||||
// Proxy functions; Their implementation are temporary before moving at this level to the SDK
|
||||
async encryptString(plainValue: string, key: SymmetricCryptoKey): Promise<EncString> {
|
||||
if (plainValue == null) {
|
||||
this.logService.warning(
|
||||
@@ -171,36 +166,6 @@ export class EncryptServiceImplementation implements EncryptService {
|
||||
return Utils.fromBufferToB64(hashArray);
|
||||
}
|
||||
|
||||
// Handle updating private properties to turn on/off feature flags.
|
||||
onServerConfigChange(newConfig: ServerConfig): void {}
|
||||
|
||||
async decryptToUtf8(
|
||||
encString: EncString,
|
||||
key: SymmetricCryptoKey,
|
||||
_decryptContext: string = "no context",
|
||||
): Promise<string> {
|
||||
await SdkLoadService.Ready;
|
||||
return PureCrypto.symmetric_decrypt(encString.encryptedString, key.toEncoded());
|
||||
}
|
||||
|
||||
async decryptToBytes(
|
||||
encThing: Encrypted,
|
||||
key: SymmetricCryptoKey,
|
||||
_decryptContext: string = "no context",
|
||||
): Promise<Uint8Array | null> {
|
||||
if (encThing.encryptionType == null || encThing.ivBytes == null || encThing.dataBytes == null) {
|
||||
throw new Error("Cannot decrypt, missing type, IV, or data bytes.");
|
||||
}
|
||||
const buffer = EncArrayBuffer.fromParts(
|
||||
encThing.encryptionType,
|
||||
encThing.ivBytes,
|
||||
encThing.dataBytes,
|
||||
encThing.macBytes,
|
||||
).buffer;
|
||||
await SdkLoadService.Ready;
|
||||
return PureCrypto.symmetric_decrypt_array_buffer(buffer, key.toEncoded());
|
||||
}
|
||||
|
||||
async encapsulateKeyUnsigned(
|
||||
sharedKey: SymmetricCryptoKey,
|
||||
encapsulationKey: Uint8Array,
|
||||
@@ -228,45 +193,14 @@ export class EncryptServiceImplementation implements EncryptService {
|
||||
throw new Error("No decapsulationKey provided for decapsulation");
|
||||
}
|
||||
|
||||
await SdkLoadService.Ready;
|
||||
const keyBytes = PureCrypto.decapsulate_key_unsigned(
|
||||
encryptedSharedKey.encryptedString,
|
||||
decapsulationKey,
|
||||
);
|
||||
await SdkLoadService.Ready;
|
||||
return new SymmetricCryptoKey(keyBytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by BulkEncryptService (PM-4154)
|
||||
*/
|
||||
async decryptItems<T extends InitializerMetadata>(
|
||||
items: Decryptable<T>[],
|
||||
key: SymmetricCryptoKey,
|
||||
): Promise<T[]> {
|
||||
if (items == null || items.length < 1) {
|
||||
return [];
|
||||
}
|
||||
|
||||
// don't use promise.all because this task is not io bound
|
||||
const results = [];
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
results.push(await items[i].decrypt(key));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
async rsaEncrypt(data: Uint8Array, publicKey: Uint8Array): Promise<EncString> {
|
||||
if (data == null) {
|
||||
throw new Error("No data provided for encryption.");
|
||||
}
|
||||
|
||||
if (publicKey == null) {
|
||||
throw new Error("No public key provided for encryption.");
|
||||
}
|
||||
const encrypted = await this.cryptoFunctionService.rsaEncrypt(data, publicKey, "sha1");
|
||||
return new EncString(EncryptionType.Rsa2048_OaepSha1_B64, Utils.fromBufferToB64(encrypted));
|
||||
}
|
||||
|
||||
async rsaDecrypt(data: EncString, privateKey: Uint8Array): Promise<Uint8Array> {
|
||||
if (data == null) {
|
||||
throw new Error("[Encrypt service] rsaDecrypt: No data provided for decryption.");
|
||||
|
||||
@@ -303,12 +303,6 @@ describe("EncryptService", () => {
|
||||
const actual = await encryptService.encapsulateKeyUnsigned(testKey, publicKey);
|
||||
expect(actual).toEqual(new EncString("encapsulated_key_unsigned"));
|
||||
});
|
||||
|
||||
it("throws if no data was provided", () => {
|
||||
return expect(encryptService.rsaEncrypt(null, new Uint8Array(32))).rejects.toThrow(
|
||||
"No data provided for encryption",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("decapsulateKeyUnsigned", () => {
|
||||
@@ -338,23 +332,4 @@ describe("EncryptService", () => {
|
||||
expect(cryptoFunctionService.hash).toHaveBeenCalledWith("test", "sha256");
|
||||
});
|
||||
});
|
||||
|
||||
describe("decryptItems", () => {
|
||||
it("returns empty array if no items are provided", async () => {
|
||||
const key = mock<SymmetricCryptoKey>();
|
||||
const actual = await encryptService.decryptItems(null, key);
|
||||
expect(actual).toEqual([]);
|
||||
});
|
||||
|
||||
it("returns items decrypted with provided key", async () => {
|
||||
const key = mock<SymmetricCryptoKey>();
|
||||
const decryptable = {
|
||||
decrypt: jest.fn().mockResolvedValue("decrypted"),
|
||||
};
|
||||
const items = [decryptable];
|
||||
const actual = await encryptService.decryptItems(items as any, key);
|
||||
expect(actual).toEqual(["decrypted"]);
|
||||
expect(decryptable.decrypt).toHaveBeenCalledWith(key);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
import { LogService } from "../../../platform/abstractions/log.service";
|
||||
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
import { ConsoleLogService } from "../../../platform/services/console-log.service";
|
||||
import { ContainerService } from "../../../platform/services/container.service";
|
||||
import { getClassInitializer } from "../../../platform/services/cryptography/get-class-initializer";
|
||||
import {
|
||||
DECRYPT_COMMAND,
|
||||
SET_CONFIG_COMMAND,
|
||||
ParsedDecryptCommandData,
|
||||
} from "../types/worker-command.type";
|
||||
|
||||
import { EncryptServiceImplementation } from "./encrypt.service.implementation";
|
||||
import { WebCryptoFunctionService } from "./web-crypto-function.service";
|
||||
|
||||
const workerApi: Worker = self as any;
|
||||
|
||||
let inited = false;
|
||||
let encryptService: EncryptServiceImplementation;
|
||||
let logService: LogService;
|
||||
|
||||
/**
|
||||
* Bootstrap the worker environment with services required for decryption
|
||||
*/
|
||||
export function init() {
|
||||
const cryptoFunctionService = new WebCryptoFunctionService(self);
|
||||
logService = new ConsoleLogService(false);
|
||||
encryptService = new EncryptServiceImplementation(cryptoFunctionService, logService, true);
|
||||
|
||||
const bitwardenContainerService = new ContainerService(null, encryptService);
|
||||
bitwardenContainerService.attachToGlobal(self);
|
||||
|
||||
inited = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Listen for messages and decrypt their contents
|
||||
*/
|
||||
workerApi.addEventListener("message", async (event: { data: string }) => {
|
||||
if (!inited) {
|
||||
init();
|
||||
}
|
||||
|
||||
const request: {
|
||||
command: string;
|
||||
} = JSON.parse(event.data);
|
||||
|
||||
switch (request.command) {
|
||||
case DECRYPT_COMMAND:
|
||||
return await handleDecrypt(request as unknown as ParsedDecryptCommandData);
|
||||
case SET_CONFIG_COMMAND: {
|
||||
const newConfig = (request as unknown as { newConfig: Jsonify<ServerConfig> }).newConfig;
|
||||
return await handleSetConfig(newConfig);
|
||||
}
|
||||
default:
|
||||
logService.error(`[EncryptWorker] unknown worker command`, request.command, request);
|
||||
}
|
||||
});
|
||||
|
||||
async function handleDecrypt(request: ParsedDecryptCommandData) {
|
||||
const key = SymmetricCryptoKey.fromJSON(request.key);
|
||||
const items = request.items.map((jsonItem) => {
|
||||
const initializer = getClassInitializer<Decryptable<any>>(jsonItem.initializerKey);
|
||||
return initializer(jsonItem);
|
||||
});
|
||||
const result = await encryptService.decryptItems(items, key);
|
||||
|
||||
workerApi.postMessage({
|
||||
id: request.id,
|
||||
items: JSON.stringify(result),
|
||||
});
|
||||
}
|
||||
|
||||
async function handleSetConfig(newConfig: Jsonify<ServerConfig>) {
|
||||
encryptService.onServerConfigChange(ServerConfig.fromJSON(newConfig));
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { BulkEncryptService } from "@bitwarden/common/key-management/crypto/abstractions/bulk-encrypt.service";
|
||||
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
|
||||
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
|
||||
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
|
||||
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
import { EncryptService } from "../abstractions/encrypt.service";
|
||||
|
||||
/**
|
||||
* @deprecated Will be deleted in an immediate subsequent PR
|
||||
*/
|
||||
export class FallbackBulkEncryptService implements BulkEncryptService {
|
||||
private featureFlagEncryptService: BulkEncryptService;
|
||||
private currentServerConfig: ServerConfig | undefined = undefined;
|
||||
|
||||
constructor(protected encryptService: EncryptService) {}
|
||||
|
||||
/**
|
||||
* Decrypts items using a web worker if the environment supports it.
|
||||
* Will fall back to the main thread if the window object is not available.
|
||||
*/
|
||||
async decryptItems<T extends InitializerMetadata>(
|
||||
items: Decryptable<T>[],
|
||||
key: SymmetricCryptoKey,
|
||||
): Promise<T[]> {
|
||||
return await this.encryptService.decryptItems(items, key);
|
||||
}
|
||||
|
||||
async setFeatureFlagEncryptService(featureFlagEncryptService: BulkEncryptService) {}
|
||||
|
||||
onServerConfigChange(newConfig: ServerConfig): void {}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
import { Decryptable } from "@bitwarden/common/platform/interfaces/decryptable.interface";
|
||||
import { InitializerMetadata } from "@bitwarden/common/platform/interfaces/initializer-metadata.interface";
|
||||
import { SymmetricCryptoKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
|
||||
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
|
||||
import { EncryptServiceImplementation } from "./encrypt.service.implementation";
|
||||
|
||||
/**
|
||||
* @deprecated Will be deleted in an immediate subsequent PR
|
||||
*/
|
||||
export class MultithreadEncryptServiceImplementation extends EncryptServiceImplementation {
|
||||
protected useSDKForDecryption: boolean = true;
|
||||
|
||||
/**
|
||||
* Sends items to a web worker to decrypt them.
|
||||
* This utilises multithreading to decrypt items faster without interrupting other operations (e.g. updating UI).
|
||||
*/
|
||||
async decryptItems<T extends InitializerMetadata>(
|
||||
items: Decryptable<T>[],
|
||||
key: SymmetricCryptoKey,
|
||||
): Promise<T[]> {
|
||||
return await super.decryptItems(items, key);
|
||||
}
|
||||
|
||||
override onServerConfigChange(newConfig: ServerConfig): void {}
|
||||
}
|
||||
@@ -154,46 +154,6 @@ describe("WebCrypto Function Service", () => {
|
||||
testHmac("sha512", Sha512Mac);
|
||||
});
|
||||
|
||||
describe("compare", () => {
|
||||
it("should successfully compare two of the same values", async () => {
|
||||
const cryptoFunctionService = getWebCryptoFunctionService();
|
||||
const a = new Uint8Array(2);
|
||||
a[0] = 1;
|
||||
a[1] = 2;
|
||||
const equal = await cryptoFunctionService.compare(a, a);
|
||||
expect(equal).toBe(true);
|
||||
});
|
||||
|
||||
it("should successfully compare two different values of the same length", async () => {
|
||||
const cryptoFunctionService = getWebCryptoFunctionService();
|
||||
const a = new Uint8Array(2);
|
||||
a[0] = 1;
|
||||
a[1] = 2;
|
||||
const b = new Uint8Array(2);
|
||||
b[0] = 3;
|
||||
b[1] = 4;
|
||||
const equal = await cryptoFunctionService.compare(a, b);
|
||||
expect(equal).toBe(false);
|
||||
});
|
||||
|
||||
it("should successfully compare two different values of different lengths", async () => {
|
||||
const cryptoFunctionService = getWebCryptoFunctionService();
|
||||
const a = new Uint8Array(2);
|
||||
a[0] = 1;
|
||||
a[1] = 2;
|
||||
const b = new Uint8Array(2);
|
||||
b[0] = 3;
|
||||
const equal = await cryptoFunctionService.compare(a, b);
|
||||
expect(equal).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hmacFast", () => {
|
||||
testHmacFast("sha1", Sha1Mac);
|
||||
testHmacFast("sha256", Sha256Mac);
|
||||
testHmacFast("sha512", Sha512Mac);
|
||||
});
|
||||
|
||||
describe("compareFast", () => {
|
||||
it("should successfully compare two of the same values", async () => {
|
||||
const cryptoFunctionService = getWebCryptoFunctionService();
|
||||
@@ -523,20 +483,6 @@ function testHmac(algorithm: "sha1" | "sha256" | "sha512", mac: string) {
|
||||
});
|
||||
}
|
||||
|
||||
function testHmacFast(algorithm: "sha1" | "sha256" | "sha512", mac: string) {
|
||||
it("should create valid " + algorithm + " hmac", async () => {
|
||||
const cryptoFunctionService = getWebCryptoFunctionService();
|
||||
const keyByteString = Utils.fromBufferToByteString(Utils.fromUtf8ToArray("secretkey"));
|
||||
const dataByteString = Utils.fromBufferToByteString(Utils.fromUtf8ToArray("SignMe!!"));
|
||||
const computedMac = await cryptoFunctionService.hmacFast(
|
||||
dataByteString,
|
||||
keyByteString,
|
||||
algorithm,
|
||||
);
|
||||
expect(Utils.fromBufferToHex(Utils.fromByteStringToArray(computedMac))).toBe(mac);
|
||||
});
|
||||
}
|
||||
|
||||
function testRsaGenerateKeyPair(length: 1024 | 2048 | 4096) {
|
||||
it(
|
||||
"should successfully generate a " + length + " bit key pair",
|
||||
|
||||
@@ -146,34 +146,6 @@ export class WebCryptoFunctionService implements CryptoFunctionService {
|
||||
return new Uint8Array(buffer);
|
||||
}
|
||||
|
||||
// Safely compare two values in a way that protects against timing attacks (Double HMAC Verification).
|
||||
// ref: https://www.nccgroup.trust/us/about-us/newsroom-and-events/blog/2011/february/double-hmac-verification/
|
||||
// ref: https://paragonie.com/blog/2015/11/preventing-timing-attacks-on-string-comparison-with-double-hmac-strategy
|
||||
async compare(a: Uint8Array, b: Uint8Array): Promise<boolean> {
|
||||
const macKey = await this.randomBytes(32);
|
||||
const signingAlgorithm = {
|
||||
name: "HMAC",
|
||||
hash: { name: "SHA-256" },
|
||||
};
|
||||
const impKey = await this.subtle.importKey("raw", macKey, signingAlgorithm, false, ["sign"]);
|
||||
const mac1 = await this.subtle.sign(signingAlgorithm, impKey, a);
|
||||
const mac2 = await this.subtle.sign(signingAlgorithm, impKey, b);
|
||||
|
||||
if (mac1.byteLength !== mac2.byteLength) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const arr1 = new Uint8Array(mac1);
|
||||
const arr2 = new Uint8Array(mac2);
|
||||
for (let i = 0; i < arr2.length; i++) {
|
||||
if (arr1[i] !== arr2[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
hmacFast(value: string, key: string, algorithm: "sha1" | "sha256" | "sha512"): Promise<string> {
|
||||
const hmac = forge.hmac.create();
|
||||
hmac.start(algorithm, key);
|
||||
@@ -182,6 +154,9 @@ export class WebCryptoFunctionService implements CryptoFunctionService {
|
||||
return Promise.resolve(bytes);
|
||||
}
|
||||
|
||||
// Safely compare two values in a way that protects against timing attacks (Double HMAC Verification).
|
||||
// ref: https://www.nccgroup.trust/us/about-us/newsroom-and-events/blog/2011/february/double-hmac-verification/
|
||||
// ref: https://paragonie.com/blog/2015/11/preventing-timing-attacks-on-string-comparison-with-double-hmac-strategy
|
||||
async compareFast(a: string, b: string): Promise<boolean> {
|
||||
const rand = await this.randomBytes(32);
|
||||
const bytes = new Uint32Array(rand);
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
import { makeStaticByteArray } from "../../../../spec";
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
|
||||
import {
|
||||
DECRYPT_COMMAND,
|
||||
DecryptCommandData,
|
||||
SET_CONFIG_COMMAND,
|
||||
buildDecryptMessage,
|
||||
buildSetConfigMessage,
|
||||
} from "./worker-command.type";
|
||||
|
||||
describe("Worker command types", () => {
|
||||
describe("buildDecryptMessage", () => {
|
||||
it("builds a message with the correct command", () => {
|
||||
const commandData = createDecryptCommandData();
|
||||
|
||||
const result = buildDecryptMessage(commandData);
|
||||
|
||||
const parsedResult = JSON.parse(result);
|
||||
expect(parsedResult.command).toBe(DECRYPT_COMMAND);
|
||||
});
|
||||
|
||||
it("includes the provided data in the message", () => {
|
||||
const mockItems = [{ encrypted: "test-encrypted" } as unknown as Decryptable<any>];
|
||||
const commandData = createDecryptCommandData(mockItems);
|
||||
|
||||
const result = buildDecryptMessage(commandData);
|
||||
|
||||
const parsedResult = JSON.parse(result);
|
||||
expect(parsedResult.command).toBe(DECRYPT_COMMAND);
|
||||
expect(parsedResult.id).toBe("test-id");
|
||||
expect(parsedResult.items).toEqual(mockItems);
|
||||
expect(SymmetricCryptoKey.fromJSON(parsedResult.key)).toEqual(commandData.key);
|
||||
});
|
||||
});
|
||||
|
||||
describe("buildSetConfigMessage", () => {
|
||||
it("builds a message with the correct command", () => {
|
||||
const result = buildSetConfigMessage({ newConfig: mock<ServerConfig>() });
|
||||
|
||||
const parsedResult = JSON.parse(result);
|
||||
expect(parsedResult.command).toBe(SET_CONFIG_COMMAND);
|
||||
});
|
||||
|
||||
it("includes the provided data in the message", () => {
|
||||
const serverConfig = { version: "test-version" } as unknown as ServerConfig;
|
||||
|
||||
const result = buildSetConfigMessage({ newConfig: serverConfig });
|
||||
|
||||
const parsedResult = JSON.parse(result);
|
||||
expect(parsedResult.command).toBe(SET_CONFIG_COMMAND);
|
||||
expect(ServerConfig.fromJSON(parsedResult.newConfig).version).toEqual(serverConfig.version);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
function createDecryptCommandData(items?: Decryptable<any>[]): DecryptCommandData {
|
||||
return {
|
||||
id: "test-id",
|
||||
items: items ?? [],
|
||||
key: new SymmetricCryptoKey(makeStaticByteArray(64)),
|
||||
};
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { ServerConfig } from "../../../platform/abstractions/config/server-config";
|
||||
import { Decryptable } from "../../../platform/interfaces/decryptable.interface";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
|
||||
export const DECRYPT_COMMAND = "decrypt";
|
||||
export const SET_CONFIG_COMMAND = "updateConfig";
|
||||
|
||||
export type DecryptCommandData = {
|
||||
id: string;
|
||||
items: Decryptable<any>[];
|
||||
key: SymmetricCryptoKey;
|
||||
};
|
||||
|
||||
export type ParsedDecryptCommandData = {
|
||||
id: string;
|
||||
items: Jsonify<Decryptable<any>>[];
|
||||
key: Jsonify<SymmetricCryptoKey>;
|
||||
};
|
||||
|
||||
type SetConfigCommandData = { newConfig: ServerConfig };
|
||||
|
||||
export function buildDecryptMessage(data: DecryptCommandData): string {
|
||||
return JSON.stringify({
|
||||
command: DECRYPT_COMMAND,
|
||||
...data,
|
||||
});
|
||||
}
|
||||
|
||||
export function buildSetConfigMessage(data: SetConfigCommandData): string {
|
||||
return JSON.stringify({
|
||||
command: SET_CONFIG_COMMAND,
|
||||
...data,
|
||||
});
|
||||
}
|
||||
@@ -1,9 +1,17 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KdfConfig } from "@bitwarden/key-management";
|
||||
|
||||
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { MasterKey, UserKey } from "../../../types/key";
|
||||
import { EncString } from "../../crypto/models/enc-string";
|
||||
import {
|
||||
MasterPasswordAuthenticationData,
|
||||
MasterPasswordSalt,
|
||||
MasterPasswordUnlockData,
|
||||
} from "../types/master-password.types";
|
||||
|
||||
export abstract class MasterPasswordServiceAbstraction {
|
||||
/**
|
||||
@@ -12,14 +20,23 @@ export abstract class MasterPasswordServiceAbstraction {
|
||||
* @throws If the user ID is missing.
|
||||
*/
|
||||
abstract forceSetPasswordReason$: (userId: UserId) => Observable<ForceSetPasswordReason>;
|
||||
/**
|
||||
* An observable that emits the master password salt for the user.
|
||||
* @param userId The user ID.
|
||||
* @throws If the user ID is missing.
|
||||
* @throws If the user ID is provided, but the user is not found.
|
||||
*/
|
||||
abstract saltForUser$: (userId: UserId) => Observable<MasterPasswordSalt>;
|
||||
/**
|
||||
* An observable that emits the master key for the user.
|
||||
* @deprecated Interacting with the master-key directly is deprecated. Please use {@link makeMasterPasswordUnlockData}, {@link makeMasterPasswordAuthenticationData} or {@link unwrapUserKeyFromMasterPasswordUnlockData} instead.
|
||||
* @param userId The user ID.
|
||||
* @throws If the user ID is missing.
|
||||
*/
|
||||
abstract masterKey$: (userId: UserId) => Observable<MasterKey>;
|
||||
/**
|
||||
* An observable that emits the master key hash for the user.
|
||||
* @deprecated Interacting with the master-key directly is deprecated. Please use {@link makeMasterPasswordAuthenticationData}.
|
||||
* @param userId The user ID.
|
||||
* @throws If the user ID is missing.
|
||||
*/
|
||||
@@ -32,6 +49,7 @@ export abstract class MasterPasswordServiceAbstraction {
|
||||
abstract getMasterKeyEncryptedUserKey: (userId: UserId) => Promise<EncString>;
|
||||
/**
|
||||
* Decrypts the user key with the provided master key
|
||||
* @deprecated Interacting with the master-key directly is deprecated. Please use {@link unwrapUserKeyFromMasterPasswordUnlockData} instead.
|
||||
* @param masterKey The user's master key
|
||||
* * @param userId The desired user
|
||||
* @param userKey The user's encrypted symmetric key
|
||||
@@ -44,12 +62,52 @@ export abstract class MasterPasswordServiceAbstraction {
|
||||
userId: string,
|
||||
userKey?: EncString,
|
||||
) => Promise<UserKey | null>;
|
||||
|
||||
/**
|
||||
* Makes the authentication hash for authenticating to the server with the master password.
|
||||
* @param password The master password.
|
||||
* @param kdf The KDF configuration.
|
||||
* @param salt The master password salt to use. See {@link saltForUser$} for current salt.
|
||||
* @throws If password, KDF or salt are null or undefined.
|
||||
*/
|
||||
abstract makeMasterPasswordAuthenticationData: (
|
||||
password: string,
|
||||
kdf: KdfConfig,
|
||||
salt: MasterPasswordSalt,
|
||||
) => Promise<MasterPasswordAuthenticationData>;
|
||||
|
||||
/**
|
||||
* Creates a MasterPasswordUnlockData bundle that encrypts the user-key with a key derived from the password. The
|
||||
* bundle also contains the KDF settings and salt used to derive the key, which are required to decrypt the user-key later.
|
||||
* @param password The master password.
|
||||
* @param kdf The KDF configuration.
|
||||
* @param salt The master password salt to use. See {@link saltForUser$} for current salt.
|
||||
* @param userKey The user's userKey to encrypt.
|
||||
* @throws If password, KDF, salt, or userKey are null or undefined.
|
||||
*/
|
||||
abstract makeMasterPasswordUnlockData: (
|
||||
password: string,
|
||||
kdf: KdfConfig,
|
||||
salt: MasterPasswordSalt,
|
||||
userKey: UserKey,
|
||||
) => Promise<MasterPasswordUnlockData>;
|
||||
|
||||
/**
|
||||
* Unwraps a user-key that was wrapped with a password provided KDF settings. The same KDF settings and salt must be provided to unwrap the user-key, otherwise it will fail to decrypt.
|
||||
* @throws If the encryption type is not supported.
|
||||
* @throws If the password, KDF, or salt don't match the original wrapping parameters.
|
||||
*/
|
||||
abstract unwrapUserKeyFromMasterPasswordUnlockData: (
|
||||
password: string,
|
||||
masterPasswordUnlockData: MasterPasswordUnlockData,
|
||||
) => Promise<UserKey>;
|
||||
}
|
||||
|
||||
export abstract class InternalMasterPasswordServiceAbstraction extends MasterPasswordServiceAbstraction {
|
||||
/**
|
||||
* Set the master key for the user.
|
||||
* Note: Use {@link clearMasterKey} to clear the master key.
|
||||
* @deprecated Interacting with the master-key directly is deprecated.
|
||||
* @param masterKey The master key.
|
||||
* @param userId The user ID.
|
||||
* @throws If the user ID or master key is missing.
|
||||
@@ -57,6 +115,7 @@ export abstract class InternalMasterPasswordServiceAbstraction extends MasterPas
|
||||
abstract setMasterKey: (masterKey: MasterKey, userId: UserId) => Promise<void>;
|
||||
/**
|
||||
* Clear the master key for the user.
|
||||
* @deprecated Interacting with the master-key directly is deprecated.
|
||||
* @param userId The user ID.
|
||||
* @throws If the user ID is missing.
|
||||
*/
|
||||
@@ -64,6 +123,7 @@ export abstract class InternalMasterPasswordServiceAbstraction extends MasterPas
|
||||
/**
|
||||
* Set the master key hash for the user.
|
||||
* Note: Use {@link clearMasterKeyHash} to clear the master key hash.
|
||||
* @deprecated Interacting with the master-key directly is deprecated.
|
||||
* @param masterKeyHash The master key hash.
|
||||
* @param userId The user ID.
|
||||
* @throws If the user ID or master key hash is missing.
|
||||
@@ -71,6 +131,7 @@ export abstract class InternalMasterPasswordServiceAbstraction extends MasterPas
|
||||
abstract setMasterKeyHash: (masterKeyHash: string, userId: UserId) => Promise<void>;
|
||||
/**
|
||||
* Clear the master key hash for the user.
|
||||
* @deprecated Interacting with the master-key directly is deprecated.
|
||||
* @param userId The user ID.
|
||||
* @throws If the user ID is missing.
|
||||
*/
|
||||
|
||||
@@ -3,11 +3,20 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
import { ReplaySubject, Observable } from "rxjs";
|
||||
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KdfConfig } from "@bitwarden/key-management";
|
||||
|
||||
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { MasterKey, UserKey } from "../../../types/key";
|
||||
import { EncString } from "../../crypto/models/enc-string";
|
||||
import { InternalMasterPasswordServiceAbstraction } from "../abstractions/master-password.service.abstraction";
|
||||
import {
|
||||
MasterPasswordAuthenticationData,
|
||||
MasterPasswordSalt,
|
||||
MasterPasswordUnlockData,
|
||||
} from "../types/master-password.types";
|
||||
|
||||
export class FakeMasterPasswordService implements InternalMasterPasswordServiceAbstraction {
|
||||
mock = mock<InternalMasterPasswordServiceAbstraction>();
|
||||
@@ -24,6 +33,10 @@ export class FakeMasterPasswordService implements InternalMasterPasswordServiceA
|
||||
this.masterKeyHashSubject.next(initialMasterKeyHash);
|
||||
}
|
||||
|
||||
saltForUser$(userId: UserId): Observable<MasterPasswordSalt> {
|
||||
return this.mock.saltForUser$(userId);
|
||||
}
|
||||
|
||||
masterKey$(userId: UserId): Observable<MasterKey> {
|
||||
return this.masterKeySubject.asObservable();
|
||||
}
|
||||
@@ -71,4 +84,28 @@ export class FakeMasterPasswordService implements InternalMasterPasswordServiceA
|
||||
): Promise<UserKey> {
|
||||
return this.mock.decryptUserKeyWithMasterKey(masterKey, userId, userKey);
|
||||
}
|
||||
|
||||
makeMasterPasswordAuthenticationData(
|
||||
password: string,
|
||||
kdf: KdfConfig,
|
||||
salt: MasterPasswordSalt,
|
||||
): Promise<MasterPasswordAuthenticationData> {
|
||||
return this.mock.makeMasterPasswordAuthenticationData(password, kdf, salt);
|
||||
}
|
||||
|
||||
makeMasterPasswordUnlockData(
|
||||
password: string,
|
||||
kdf: KdfConfig,
|
||||
salt: MasterPasswordSalt,
|
||||
userKey: UserKey,
|
||||
): Promise<MasterPasswordUnlockData> {
|
||||
return this.mock.makeMasterPasswordUnlockData(password, kdf, salt, userKey);
|
||||
}
|
||||
|
||||
unwrapUserKeyFromMasterPasswordUnlockData(
|
||||
password: string,
|
||||
masterPasswordUnlockData: MasterPasswordUnlockData,
|
||||
): Promise<UserKey> {
|
||||
return this.mock.unwrapUserKeyFromMasterPasswordUnlockData(password, masterPasswordUnlockData);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
import { mock, MockProxy } from "jest-mock-extended";
|
||||
import { of } from "rxjs";
|
||||
import { firstValueFrom, of } from "rxjs";
|
||||
import * as rxjs from "rxjs";
|
||||
|
||||
import { makeSymmetricCryptoKey } from "../../../../spec";
|
||||
import { SdkLoadService } from "@bitwarden/common/platform/abstractions/sdk/sdk-load.service";
|
||||
import { Utils } from "@bitwarden/common/platform/misc/utils";
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KdfConfig, PBKDF2KdfConfig } from "@bitwarden/key-management";
|
||||
|
||||
import {
|
||||
FakeAccountService,
|
||||
makeSymmetricCryptoKey,
|
||||
mockAccountServiceWith,
|
||||
} from "../../../../spec";
|
||||
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
|
||||
import { KeyGenerationService } from "../../../platform/abstractions/key-generation.service";
|
||||
import { LogService } from "../../../platform/abstractions/log.service";
|
||||
@@ -10,9 +19,11 @@ import { StateService } from "../../../platform/abstractions/state.service";
|
||||
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
|
||||
import { StateProvider } from "../../../platform/state";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { MasterKey } from "../../../types/key";
|
||||
import { MasterKey, UserKey } from "../../../types/key";
|
||||
import { CryptoFunctionService } from "../../crypto/abstractions/crypto-function.service";
|
||||
import { EncryptService } from "../../crypto/abstractions/encrypt.service";
|
||||
import { EncString } from "../../crypto/models/enc-string";
|
||||
import { MasterPasswordSalt } from "../types/master-password.types";
|
||||
|
||||
import { MasterPasswordService } from "./master-password.service";
|
||||
|
||||
@@ -24,8 +35,10 @@ describe("MasterPasswordService", () => {
|
||||
let keyGenerationService: MockProxy<KeyGenerationService>;
|
||||
let encryptService: MockProxy<EncryptService>;
|
||||
let logService: MockProxy<LogService>;
|
||||
let cryptoFunctionService: MockProxy<CryptoFunctionService>;
|
||||
let accountService: FakeAccountService;
|
||||
|
||||
const userId = "user-id" as UserId;
|
||||
const userId = "00000000-0000-0000-0000-000000000000" as UserId;
|
||||
const mockUserState = {
|
||||
state$: of(null),
|
||||
update: jest.fn().mockResolvedValue(null),
|
||||
@@ -45,6 +58,8 @@ describe("MasterPasswordService", () => {
|
||||
keyGenerationService = mock<KeyGenerationService>();
|
||||
encryptService = mock<EncryptService>();
|
||||
logService = mock<LogService>();
|
||||
cryptoFunctionService = mock<CryptoFunctionService>();
|
||||
accountService = mockAccountServiceWith(userId);
|
||||
|
||||
stateProvider.getUser.mockReturnValue(mockUserState as any);
|
||||
|
||||
@@ -56,10 +71,33 @@ describe("MasterPasswordService", () => {
|
||||
keyGenerationService,
|
||||
encryptService,
|
||||
logService,
|
||||
cryptoFunctionService,
|
||||
accountService,
|
||||
);
|
||||
|
||||
encryptService.unwrapSymmetricKey.mockResolvedValue(makeSymmetricCryptoKey(64, 1));
|
||||
keyGenerationService.stretchKey.mockResolvedValue(makeSymmetricCryptoKey(64, 3));
|
||||
Object.defineProperty(SdkLoadService, "Ready", {
|
||||
value: Promise.resolve(),
|
||||
configurable: true,
|
||||
});
|
||||
});
|
||||
|
||||
describe("saltForUser$", () => {
|
||||
it("throws when userid not present", async () => {
|
||||
expect(() => {
|
||||
sut.saltForUser$(null as unknown as UserId);
|
||||
}).toThrow("userId is null or undefined.");
|
||||
});
|
||||
it("throws when userid present but not in account service", async () => {
|
||||
await expect(
|
||||
firstValueFrom(sut.saltForUser$("00000000-0000-0000-0000-000000000001" as UserId)),
|
||||
).rejects.toThrow("Cannot read properties of undefined (reading 'email')");
|
||||
});
|
||||
it("returns salt", async () => {
|
||||
const salt = await firstValueFrom(sut.saltForUser$(userId));
|
||||
expect(salt).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("setForceSetPasswordReason", () => {
|
||||
@@ -190,4 +228,97 @@ describe("MasterPasswordService", () => {
|
||||
expect(updateFn(null)).toEqual(encryptedKey.toJSON());
|
||||
});
|
||||
});
|
||||
|
||||
describe("makeMasterPasswordAuthenticationData", () => {
|
||||
const password = "test-password";
|
||||
const kdf: KdfConfig = new PBKDF2KdfConfig(600_000);
|
||||
const salt = "test@bitwarden.com" as MasterPasswordSalt;
|
||||
const masterKey = makeSymmetricCryptoKey(32, 2);
|
||||
const masterKeyHash = makeSymmetricCryptoKey(32, 3).toEncoded();
|
||||
|
||||
beforeEach(() => {
|
||||
keyGenerationService.deriveKeyFromPassword.mockResolvedValue(masterKey);
|
||||
cryptoFunctionService.pbkdf2.mockResolvedValue(masterKeyHash);
|
||||
});
|
||||
|
||||
it("derives master key and creates authentication hash", async () => {
|
||||
const result = await sut.makeMasterPasswordAuthenticationData(password, kdf, salt);
|
||||
|
||||
expect(keyGenerationService.deriveKeyFromPassword).toHaveBeenCalledWith(password, salt, kdf);
|
||||
expect(cryptoFunctionService.pbkdf2).toHaveBeenCalledWith(
|
||||
masterKey.toEncoded(),
|
||||
password,
|
||||
"sha256",
|
||||
1,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
kdf,
|
||||
salt,
|
||||
masterPasswordAuthenticationHash: Utils.fromBufferToB64(masterKeyHash),
|
||||
});
|
||||
});
|
||||
|
||||
it("throws if password is null", async () => {
|
||||
await expect(
|
||||
sut.makeMasterPasswordAuthenticationData(null as unknown as string, kdf, salt),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
it("throws if kdf is null", async () => {
|
||||
await expect(
|
||||
sut.makeMasterPasswordAuthenticationData(password, null as unknown as KdfConfig, salt),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
it("throws if salt is null", async () => {
|
||||
await expect(
|
||||
sut.makeMasterPasswordAuthenticationData(
|
||||
password,
|
||||
kdf,
|
||||
null as unknown as MasterPasswordSalt,
|
||||
),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("wrapUnwrapUserKeyWithPassword", () => {
|
||||
const password = "test-password";
|
||||
const kdf: KdfConfig = new PBKDF2KdfConfig(600_000);
|
||||
const salt = "test@bitwarden.com" as MasterPasswordSalt;
|
||||
const userKey = makeSymmetricCryptoKey(64, 2) as UserKey;
|
||||
|
||||
it("wraps and unwraps user key with password", async () => {
|
||||
const unlockData = await sut.makeMasterPasswordUnlockData(password, kdf, salt, userKey);
|
||||
const unwrappedUserkey = await sut.unwrapUserKeyFromMasterPasswordUnlockData(
|
||||
password,
|
||||
unlockData,
|
||||
);
|
||||
expect(unwrappedUserkey).toEqual(userKey);
|
||||
});
|
||||
|
||||
it("throws if password is null", async () => {
|
||||
await expect(
|
||||
sut.makeMasterPasswordUnlockData(null as unknown as string, kdf, salt, userKey),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
it("throws if kdf is null", async () => {
|
||||
await expect(
|
||||
sut.makeMasterPasswordUnlockData(password, null as unknown as KdfConfig, salt, userKey),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
it("throws if salt is null", async () => {
|
||||
await expect(
|
||||
sut.makeMasterPasswordUnlockData(
|
||||
password,
|
||||
kdf,
|
||||
null as unknown as MasterPasswordSalt,
|
||||
userKey,
|
||||
),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
it("throws if userKey is null", async () => {
|
||||
await expect(
|
||||
sut.makeMasterPasswordUnlockData(password, kdf, salt, null as unknown as UserKey),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,6 +2,14 @@
|
||||
// @ts-strict-ignore
|
||||
import { firstValueFrom, map, Observable } from "rxjs";
|
||||
|
||||
import { AccountService } from "@bitwarden/common/auth/abstractions/account.service";
|
||||
import { assertNonNullish } from "@bitwarden/common/auth/utils";
|
||||
import { SdkLoadService } from "@bitwarden/common/platform/abstractions/sdk/sdk-load.service";
|
||||
import { Utils } from "@bitwarden/common/platform/misc/utils";
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KdfConfig } from "@bitwarden/key-management";
|
||||
import { PureCrypto } from "@bitwarden/sdk-internal";
|
||||
|
||||
import { ForceSetPasswordReason } from "../../../auth/models/domain/force-set-password-reason";
|
||||
import { KeyGenerationService } from "../../../platform/abstractions/key-generation.service";
|
||||
import { LogService } from "../../../platform/abstractions/log.service";
|
||||
@@ -16,9 +24,17 @@ import {
|
||||
} from "../../../platform/state";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { MasterKey, UserKey } from "../../../types/key";
|
||||
import { CryptoFunctionService } from "../../crypto/abstractions/crypto-function.service";
|
||||
import { EncryptService } from "../../crypto/abstractions/encrypt.service";
|
||||
import { EncryptedString, EncString } from "../../crypto/models/enc-string";
|
||||
import { InternalMasterPasswordServiceAbstraction } from "../abstractions/master-password.service.abstraction";
|
||||
import {
|
||||
MasterKeyWrappedUserKey,
|
||||
MasterPasswordAuthenticationData,
|
||||
MasterPasswordAuthenticationHash,
|
||||
MasterPasswordSalt,
|
||||
MasterPasswordUnlockData,
|
||||
} from "../types/master-password.types";
|
||||
|
||||
/** Memory since master key shouldn't be available on lock */
|
||||
const MASTER_KEY = new UserKeyDefinition<MasterKey>(MASTER_PASSWORD_MEMORY, "masterKey", {
|
||||
@@ -59,8 +75,18 @@ export class MasterPasswordService implements InternalMasterPasswordServiceAbstr
|
||||
private keyGenerationService: KeyGenerationService,
|
||||
private encryptService: EncryptService,
|
||||
private logService: LogService,
|
||||
private cryptoFunctionService: CryptoFunctionService,
|
||||
private accountService: AccountService,
|
||||
) {}
|
||||
|
||||
saltForUser$(userId: UserId): Observable<MasterPasswordSalt> {
|
||||
assertNonNullish(userId, "userId");
|
||||
return this.accountService.accounts$.pipe(
|
||||
map((accounts) => accounts[userId].email),
|
||||
map((email) => this.emailToSalt(email)),
|
||||
);
|
||||
}
|
||||
|
||||
masterKey$(userId: UserId): Observable<MasterKey> {
|
||||
if (userId == null) {
|
||||
throw new Error("User ID is required.");
|
||||
@@ -95,6 +121,10 @@ export class MasterPasswordService implements InternalMasterPasswordServiceAbstr
|
||||
return EncString.fromJSON(key);
|
||||
}
|
||||
|
||||
private emailToSalt(email: string): MasterPasswordSalt {
|
||||
return email.toLowerCase().trim() as MasterPasswordSalt;
|
||||
}
|
||||
|
||||
async setMasterKey(masterKey: MasterKey, userId: UserId): Promise<void> {
|
||||
if (masterKey == null) {
|
||||
throw new Error("Master key is required.");
|
||||
@@ -202,4 +232,89 @@ export class MasterPasswordService implements InternalMasterPasswordServiceAbstr
|
||||
|
||||
return decUserKey as UserKey;
|
||||
}
|
||||
|
||||
async makeMasterPasswordAuthenticationData(
|
||||
password: string,
|
||||
kdf: KdfConfig,
|
||||
salt: MasterPasswordSalt,
|
||||
): Promise<MasterPasswordAuthenticationData> {
|
||||
assertNonNullish(password, "password");
|
||||
assertNonNullish(kdf, "kdf");
|
||||
assertNonNullish(salt, "salt");
|
||||
|
||||
// We don't trust callers to use masterpasswordsalt correctly. They may type assert incorrectly.
|
||||
salt = salt.toLowerCase().trim() as MasterPasswordSalt;
|
||||
|
||||
const SERVER_AUTHENTICATION_HASH_ITERATIONS = 1;
|
||||
|
||||
const masterKey = (await this.keyGenerationService.deriveKeyFromPassword(
|
||||
password,
|
||||
salt,
|
||||
kdf,
|
||||
)) as MasterKey;
|
||||
|
||||
const masterPasswordAuthenticationHash = Utils.fromBufferToB64(
|
||||
await this.cryptoFunctionService.pbkdf2(
|
||||
masterKey.toEncoded(),
|
||||
password,
|
||||
"sha256",
|
||||
SERVER_AUTHENTICATION_HASH_ITERATIONS,
|
||||
),
|
||||
) as MasterPasswordAuthenticationHash;
|
||||
|
||||
return {
|
||||
salt,
|
||||
kdf,
|
||||
masterPasswordAuthenticationHash,
|
||||
} as MasterPasswordAuthenticationData;
|
||||
}
|
||||
|
||||
async makeMasterPasswordUnlockData(
|
||||
password: string,
|
||||
kdf: KdfConfig,
|
||||
salt: MasterPasswordSalt,
|
||||
userKey: UserKey,
|
||||
): Promise<MasterPasswordUnlockData> {
|
||||
assertNonNullish(password, "password");
|
||||
assertNonNullish(kdf, "kdf");
|
||||
assertNonNullish(salt, "salt");
|
||||
assertNonNullish(userKey, "userKey");
|
||||
|
||||
// We don't trust callers to use masterpasswordsalt correctly. They may type assert incorrectly.
|
||||
salt = salt.toLowerCase().trim() as MasterPasswordSalt;
|
||||
|
||||
await SdkLoadService.Ready;
|
||||
const masterKeyWrappedUserKey = new EncString(
|
||||
PureCrypto.encrypt_user_key_with_master_password(
|
||||
userKey.toEncoded(),
|
||||
password,
|
||||
salt,
|
||||
kdf.toSdkConfig(),
|
||||
),
|
||||
) as MasterKeyWrappedUserKey;
|
||||
return {
|
||||
salt,
|
||||
kdf,
|
||||
masterKeyWrappedUserKey,
|
||||
};
|
||||
}
|
||||
|
||||
async unwrapUserKeyFromMasterPasswordUnlockData(
|
||||
password: string,
|
||||
masterPasswordUnlockData: MasterPasswordUnlockData,
|
||||
): Promise<UserKey> {
|
||||
assertNonNullish(password, "password");
|
||||
assertNonNullish(masterPasswordUnlockData, "masterPasswordUnlockData");
|
||||
|
||||
await SdkLoadService.Ready;
|
||||
const userKey = new SymmetricCryptoKey(
|
||||
PureCrypto.decrypt_user_key_with_master_password(
|
||||
masterPasswordUnlockData.masterKeyWrappedUserKey.encryptedString,
|
||||
password,
|
||||
masterPasswordUnlockData.salt,
|
||||
masterPasswordUnlockData.kdf.toSdkConfig(),
|
||||
),
|
||||
);
|
||||
return userKey as UserKey;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import { Opaque } from "type-fest";
|
||||
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KdfConfig } from "@bitwarden/key-management";
|
||||
|
||||
import { EncString } from "../../crypto/models/enc-string";
|
||||
|
||||
/**
|
||||
* The Base64-encoded master password authentication hash, that is sent to the server for authentication.
|
||||
*/
|
||||
export type MasterPasswordAuthenticationHash = Opaque<string, "MasterPasswordAuthenticationHash">;
|
||||
/**
|
||||
* You MUST obtain this through the emailToSalt function in MasterPasswordService
|
||||
*/
|
||||
export type MasterPasswordSalt = Opaque<string, "MasterPasswordSalt">;
|
||||
export type MasterKeyWrappedUserKey = Opaque<EncString, "MasterPasswordSalt">;
|
||||
|
||||
/**
|
||||
* The data required to unlock with the master password.
|
||||
*/
|
||||
export type MasterPasswordUnlockData = {
|
||||
salt: MasterPasswordSalt;
|
||||
kdf: KdfConfig;
|
||||
masterKeyWrappedUserKey: MasterKeyWrappedUserKey;
|
||||
};
|
||||
|
||||
/**
|
||||
* The data required to authenticate with the master password.
|
||||
*/
|
||||
export type MasterPasswordAuthenticationData = {
|
||||
salt: MasterPasswordSalt;
|
||||
kdf: KdfConfig;
|
||||
masterPasswordAuthenticationHash: MasterPasswordAuthenticationHash;
|
||||
};
|
||||
128
libs/common/src/key-management/pin/pin.service.abstraction.ts
Normal file
128
libs/common/src/key-management/pin/pin.service.abstraction.ts
Normal file
@@ -0,0 +1,128 @@
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KdfConfig } from "@bitwarden/key-management";
|
||||
|
||||
import { EncString } from "../../key-management/crypto/models/enc-string";
|
||||
import { UserId } from "../../types/guid";
|
||||
import { PinKey, UserKey } from "../../types/key";
|
||||
|
||||
import { PinLockType } from "./pin.service.implementation";
|
||||
|
||||
/**
|
||||
* The PinService is used for PIN-based unlocks. Below is a very basic overview of the PIN flow:
|
||||
*
|
||||
* -- Setting the PIN via {@link SetPinComponent} --
|
||||
*
|
||||
* When the user submits the setPinForm:
|
||||
|
||||
* 1. We encrypt the PIN with the UserKey and store it on disk as `userKeyEncryptedPin`.
|
||||
*
|
||||
* 2. We create a PinKey from the PIN, and then use that PinKey to encrypt the UserKey, resulting in
|
||||
* a `pinKeyEncryptedUserKey`, which can be stored in one of two ways depending on what the user selects
|
||||
* for the `requireMasterPasswordOnClientReset` checkbox.
|
||||
*
|
||||
* If `requireMasterPasswordOnClientReset` is:
|
||||
* - TRUE, store in memory as `pinKeyEncryptedUserKeyEphemeral` (does NOT persist through a client reset)
|
||||
* - FALSE, store on disk as `pinKeyEncryptedUserKeyPersistent` (persists through a client reset)
|
||||
*
|
||||
* -- Unlocking with the PIN via {@link LockComponent} --
|
||||
*
|
||||
* When the user enters their PIN, we decrypt their UserKey with the PIN and set that UserKey to state.
|
||||
*/
|
||||
export abstract class PinServiceAbstraction {
|
||||
/**
|
||||
* Gets the persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
|
||||
*/
|
||||
abstract getPinKeyEncryptedUserKeyPersistent: (userId: UserId) => Promise<EncString | null>;
|
||||
|
||||
/**
|
||||
* Clears the persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
|
||||
*/
|
||||
abstract clearPinKeyEncryptedUserKeyPersistent(userId: UserId): Promise<void>;
|
||||
|
||||
/**
|
||||
* Gets the ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
|
||||
*/
|
||||
abstract getPinKeyEncryptedUserKeyEphemeral: (userId: UserId) => Promise<EncString | null>;
|
||||
|
||||
/**
|
||||
* Clears the ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
|
||||
*/
|
||||
abstract clearPinKeyEncryptedUserKeyEphemeral(userId: UserId): Promise<void>;
|
||||
|
||||
/**
|
||||
* Creates a pinKeyEncryptedUserKey from the provided PIN and UserKey.
|
||||
*/
|
||||
abstract createPinKeyEncryptedUserKey: (
|
||||
pin: string,
|
||||
userKey: UserKey,
|
||||
userId: UserId,
|
||||
) => Promise<EncString>;
|
||||
|
||||
/**
|
||||
* Stores the UserKey, encrypted by the PinKey.
|
||||
* @param storeEphemeralVersion If true, stores an ephemeral version via the private {@link setPinKeyEncryptedUserKeyEphemeral} method.
|
||||
* If false, stores a persistent version via the private {@link setPinKeyEncryptedUserKeyPersistent} method.
|
||||
*/
|
||||
abstract storePinKeyEncryptedUserKey: (
|
||||
pinKeyEncryptedUserKey: EncString,
|
||||
storeEphemeralVersion: boolean,
|
||||
userId: UserId,
|
||||
) => Promise<void>;
|
||||
|
||||
/**
|
||||
* Gets the user's PIN, encrypted by the UserKey.
|
||||
*/
|
||||
abstract getUserKeyEncryptedPin: (userId: UserId) => Promise<EncString | null>;
|
||||
|
||||
/**
|
||||
* Sets the user's PIN, encrypted by the UserKey.
|
||||
*/
|
||||
abstract setUserKeyEncryptedPin: (
|
||||
userKeyEncryptedPin: EncString,
|
||||
userId: UserId,
|
||||
) => Promise<void>;
|
||||
|
||||
/**
|
||||
* Creates a PIN, encrypted by the UserKey.
|
||||
*/
|
||||
abstract createUserKeyEncryptedPin: (pin: string, userKey: UserKey) => Promise<EncString>;
|
||||
|
||||
/**
|
||||
* Clears the user's PIN, encrypted by the UserKey.
|
||||
*/
|
||||
abstract clearUserKeyEncryptedPin(userId: UserId): Promise<void>;
|
||||
|
||||
/**
|
||||
* Makes a PinKey from the provided PIN.
|
||||
*/
|
||||
abstract makePinKey: (pin: string, salt: string, kdfConfig: KdfConfig) => Promise<PinKey>;
|
||||
|
||||
/**
|
||||
* Gets the user's PinLockType {@link PinLockType}.
|
||||
*/
|
||||
abstract getPinLockType: (userId: UserId) => Promise<PinLockType>;
|
||||
|
||||
/**
|
||||
* Declares whether or not the user has a PIN set (either persistent or ephemeral).
|
||||
* Note: for ephemeral, this does not check if we actual have an ephemeral PIN-encrypted UserKey stored in memory.
|
||||
* Decryption might not be possible even if this returns true. Use {@link isPinDecryptionAvailable} if decryption is required.
|
||||
*/
|
||||
abstract isPinSet: (userId: UserId) => Promise<boolean>;
|
||||
|
||||
/**
|
||||
* Checks if PIN-encrypted keys are stored for the user.
|
||||
* Used for unlock / user verification scenarios where we will need to decrypt the UserKey with the PIN.
|
||||
*/
|
||||
abstract isPinDecryptionAvailable: (userId: UserId) => Promise<boolean>;
|
||||
|
||||
/**
|
||||
* Decrypts the UserKey with the provided PIN.
|
||||
*
|
||||
* @remarks - If the user has an old pinKeyEncryptedMasterKey (formerly called `pinProtected`), the UserKey
|
||||
* will be obtained via the private {@link decryptAndMigrateOldPinKeyEncryptedMasterKey} method.
|
||||
* - If the user does not have an old pinKeyEncryptedMasterKey, the UserKey will be obtained via the
|
||||
* private {@link decryptUserKey} method.
|
||||
* @returns UserKey
|
||||
*/
|
||||
abstract decryptUserKeyWithPin: (pin: string, userId: UserId) => Promise<UserKey | null>;
|
||||
}
|
||||
390
libs/common/src/key-management/pin/pin.service.implementation.ts
Normal file
390
libs/common/src/key-management/pin/pin.service.implementation.ts
Normal file
@@ -0,0 +1,390 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { firstValueFrom, map } from "rxjs";
|
||||
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { KdfConfig, KdfConfigService } from "@bitwarden/key-management";
|
||||
|
||||
import { AccountService } from "../../auth/abstractions/account.service";
|
||||
import { CryptoFunctionService } from "../../key-management/crypto/abstractions/crypto-function.service";
|
||||
import { EncryptService } from "../../key-management/crypto/abstractions/encrypt.service";
|
||||
import { EncString, EncryptedString } from "../../key-management/crypto/models/enc-string";
|
||||
import { KeyGenerationService } from "../../platform/abstractions/key-generation.service";
|
||||
import { LogService } from "../../platform/abstractions/log.service";
|
||||
import { PIN_DISK, PIN_MEMORY, StateProvider, UserKeyDefinition } from "../../platform/state";
|
||||
import { UserId } from "../../types/guid";
|
||||
import { PinKey, UserKey } from "../../types/key";
|
||||
|
||||
import { PinServiceAbstraction } from "./pin.service.abstraction";
|
||||
|
||||
/**
|
||||
* - DISABLED : No PIN set.
|
||||
* - PERSISTENT : PIN is set and persists through client reset.
|
||||
* - EPHEMERAL : PIN is set, but does NOT persist through client reset. This means that
|
||||
* after client reset the master password is required to unlock.
|
||||
*/
|
||||
export type PinLockType = "DISABLED" | "PERSISTENT" | "EPHEMERAL";
|
||||
|
||||
/**
|
||||
* The persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
|
||||
*
|
||||
* @remarks Persists through a client reset. Used when `requireMasterPasswordOnClientRestart` is disabled.
|
||||
* @see SetPinComponent.setPinForm.requireMasterPasswordOnClientRestart
|
||||
*/
|
||||
export const PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT = new UserKeyDefinition<EncryptedString>(
|
||||
PIN_DISK,
|
||||
"pinKeyEncryptedUserKeyPersistent",
|
||||
{
|
||||
deserializer: (jsonValue) => jsonValue,
|
||||
clearOn: ["logout"],
|
||||
},
|
||||
);
|
||||
|
||||
/**
|
||||
* The ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
|
||||
*
|
||||
* @remarks Does NOT persist through a client reset. Used when `requireMasterPasswordOnClientRestart` is enabled.
|
||||
* @see SetPinComponent.setPinForm.requireMasterPasswordOnClientRestart
|
||||
*/
|
||||
export const PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL = new UserKeyDefinition<EncryptedString>(
|
||||
PIN_MEMORY,
|
||||
"pinKeyEncryptedUserKeyEphemeral",
|
||||
{
|
||||
deserializer: (jsonValue) => jsonValue,
|
||||
clearOn: ["logout"],
|
||||
},
|
||||
);
|
||||
|
||||
/**
|
||||
* The PIN, encrypted by the UserKey.
|
||||
*/
|
||||
export const USER_KEY_ENCRYPTED_PIN = new UserKeyDefinition<EncryptedString>(
|
||||
PIN_DISK,
|
||||
"userKeyEncryptedPin",
|
||||
{
|
||||
deserializer: (jsonValue) => jsonValue,
|
||||
clearOn: ["logout"],
|
||||
},
|
||||
);
|
||||
|
||||
export class PinService implements PinServiceAbstraction {
|
||||
constructor(
|
||||
private accountService: AccountService,
|
||||
private cryptoFunctionService: CryptoFunctionService,
|
||||
private encryptService: EncryptService,
|
||||
private kdfConfigService: KdfConfigService,
|
||||
private keyGenerationService: KeyGenerationService,
|
||||
private logService: LogService,
|
||||
private stateProvider: StateProvider,
|
||||
) {}
|
||||
|
||||
async getPinKeyEncryptedUserKeyPersistent(userId: UserId): Promise<EncString | null> {
|
||||
this.validateUserId(userId, "Cannot get pinKeyEncryptedUserKeyPersistent.");
|
||||
|
||||
return EncString.fromJSON(
|
||||
await firstValueFrom(
|
||||
this.stateProvider.getUserState$(PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT, userId),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the persistent (stored on disk) version of the UserKey, encrypted by the PinKey.
|
||||
*/
|
||||
private async setPinKeyEncryptedUserKeyPersistent(
|
||||
pinKeyEncryptedUserKey: EncString,
|
||||
userId: UserId,
|
||||
): Promise<void> {
|
||||
this.validateUserId(userId, "Cannot set pinKeyEncryptedUserKeyPersistent.");
|
||||
|
||||
if (pinKeyEncryptedUserKey == null) {
|
||||
throw new Error(
|
||||
"No pinKeyEncryptedUserKey provided. Cannot set pinKeyEncryptedUserKeyPersistent.",
|
||||
);
|
||||
}
|
||||
|
||||
await this.stateProvider.setUserState(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
|
||||
pinKeyEncryptedUserKey?.encryptedString,
|
||||
userId,
|
||||
);
|
||||
}
|
||||
|
||||
async clearPinKeyEncryptedUserKeyPersistent(userId: UserId): Promise<void> {
|
||||
this.validateUserId(userId, "Cannot clear pinKeyEncryptedUserKeyPersistent.");
|
||||
|
||||
await this.stateProvider.setUserState(PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT, null, userId);
|
||||
}
|
||||
|
||||
async getPinKeyEncryptedUserKeyEphemeral(userId: UserId): Promise<EncString | null> {
|
||||
this.validateUserId(userId, "Cannot get pinKeyEncryptedUserKeyEphemeral.");
|
||||
|
||||
return EncString.fromJSON(
|
||||
await firstValueFrom(
|
||||
this.stateProvider.getUserState$(PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL, userId),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the ephemeral (stored in memory) version of the UserKey, encrypted by the PinKey.
|
||||
*/
|
||||
private async setPinKeyEncryptedUserKeyEphemeral(
|
||||
pinKeyEncryptedUserKey: EncString,
|
||||
userId: UserId,
|
||||
): Promise<void> {
|
||||
this.validateUserId(userId, "Cannot set pinKeyEncryptedUserKeyEphemeral.");
|
||||
|
||||
if (pinKeyEncryptedUserKey == null) {
|
||||
throw new Error(
|
||||
"No pinKeyEncryptedUserKey provided. Cannot set pinKeyEncryptedUserKeyEphemeral.",
|
||||
);
|
||||
}
|
||||
|
||||
await this.stateProvider.setUserState(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
|
||||
pinKeyEncryptedUserKey?.encryptedString,
|
||||
userId,
|
||||
);
|
||||
}
|
||||
|
||||
async clearPinKeyEncryptedUserKeyEphemeral(userId: UserId): Promise<void> {
|
||||
this.validateUserId(userId, "Cannot clear pinKeyEncryptedUserKeyEphemeral.");
|
||||
|
||||
await this.stateProvider.setUserState(PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL, null, userId);
|
||||
}
|
||||
|
||||
async createPinKeyEncryptedUserKey(
|
||||
pin: string,
|
||||
userKey: UserKey,
|
||||
userId: UserId,
|
||||
): Promise<EncString> {
|
||||
this.validateUserId(userId, "Cannot create pinKeyEncryptedUserKey.");
|
||||
|
||||
if (!userKey) {
|
||||
throw new Error("No UserKey provided. Cannot create pinKeyEncryptedUserKey.");
|
||||
}
|
||||
|
||||
const email = await firstValueFrom(
|
||||
this.accountService.accounts$.pipe(map((accounts) => accounts[userId].email)),
|
||||
);
|
||||
const kdfConfig = await this.kdfConfigService.getKdfConfig(userId);
|
||||
const pinKey = await this.makePinKey(pin, email, kdfConfig);
|
||||
|
||||
return await this.encryptService.wrapSymmetricKey(userKey, pinKey);
|
||||
}
|
||||
|
||||
async storePinKeyEncryptedUserKey(
|
||||
pinKeyEncryptedUserKey: EncString,
|
||||
storeAsEphemeral: boolean,
|
||||
userId: UserId,
|
||||
): Promise<void> {
|
||||
this.validateUserId(userId, "Cannot store pinKeyEncryptedUserKey.");
|
||||
|
||||
if (storeAsEphemeral) {
|
||||
await this.setPinKeyEncryptedUserKeyEphemeral(pinKeyEncryptedUserKey, userId);
|
||||
} else {
|
||||
await this.setPinKeyEncryptedUserKeyPersistent(pinKeyEncryptedUserKey, userId);
|
||||
}
|
||||
}
|
||||
|
||||
async getUserKeyEncryptedPin(userId: UserId): Promise<EncString | null> {
|
||||
this.validateUserId(userId, "Cannot get userKeyEncryptedPin.");
|
||||
|
||||
return EncString.fromJSON(
|
||||
await firstValueFrom(this.stateProvider.getUserState$(USER_KEY_ENCRYPTED_PIN, userId)),
|
||||
);
|
||||
}
|
||||
|
||||
async setUserKeyEncryptedPin(userKeyEncryptedPin: EncString, userId: UserId): Promise<void> {
|
||||
this.validateUserId(userId, "Cannot set userKeyEncryptedPin.");
|
||||
|
||||
await this.stateProvider.setUserState(
|
||||
USER_KEY_ENCRYPTED_PIN,
|
||||
userKeyEncryptedPin?.encryptedString,
|
||||
userId,
|
||||
);
|
||||
}
|
||||
|
||||
async clearUserKeyEncryptedPin(userId: UserId): Promise<void> {
|
||||
this.validateUserId(userId, "Cannot clear userKeyEncryptedPin.");
|
||||
|
||||
await this.stateProvider.setUserState(USER_KEY_ENCRYPTED_PIN, null, userId);
|
||||
}
|
||||
|
||||
async createUserKeyEncryptedPin(pin: string, userKey: UserKey): Promise<EncString> {
|
||||
if (!userKey) {
|
||||
throw new Error("No UserKey provided. Cannot create userKeyEncryptedPin.");
|
||||
}
|
||||
|
||||
return await this.encryptService.encryptString(pin, userKey);
|
||||
}
|
||||
|
||||
async makePinKey(pin: string, salt: string, kdfConfig: KdfConfig): Promise<PinKey> {
|
||||
const start = Date.now();
|
||||
const pinKey = await this.keyGenerationService.deriveKeyFromPassword(pin, salt, kdfConfig);
|
||||
this.logService.info(`[Pin Service] deriving pin key took ${Date.now() - start}ms`);
|
||||
|
||||
return (await this.keyGenerationService.stretchKey(pinKey)) as PinKey;
|
||||
}
|
||||
|
||||
async getPinLockType(userId: UserId): Promise<PinLockType> {
|
||||
this.validateUserId(userId, "Cannot get PinLockType.");
|
||||
|
||||
const aUserKeyEncryptedPinIsSet = !!(await this.getUserKeyEncryptedPin(userId));
|
||||
const aPinKeyEncryptedUserKeyPersistentIsSet =
|
||||
!!(await this.getPinKeyEncryptedUserKeyPersistent(userId));
|
||||
|
||||
if (aPinKeyEncryptedUserKeyPersistentIsSet) {
|
||||
return "PERSISTENT";
|
||||
} else if (aUserKeyEncryptedPinIsSet && !aPinKeyEncryptedUserKeyPersistentIsSet) {
|
||||
return "EPHEMERAL";
|
||||
} else {
|
||||
return "DISABLED";
|
||||
}
|
||||
}
|
||||
|
||||
async isPinSet(userId: UserId): Promise<boolean> {
|
||||
this.validateUserId(userId, "Cannot determine if PIN is set.");
|
||||
|
||||
return (await this.getPinLockType(userId)) !== "DISABLED";
|
||||
}
|
||||
|
||||
async isPinDecryptionAvailable(userId: UserId): Promise<boolean> {
|
||||
this.validateUserId(userId, "Cannot determine if decryption of user key via PIN is available.");
|
||||
|
||||
const pinLockType = await this.getPinLockType(userId);
|
||||
|
||||
switch (pinLockType) {
|
||||
case "DISABLED":
|
||||
return false;
|
||||
case "PERSISTENT":
|
||||
// The above getPinLockType call ensures that we have either a PinKeyEncryptedUserKey set.
|
||||
return true;
|
||||
case "EPHEMERAL": {
|
||||
// The above getPinLockType call ensures that we have a UserKeyEncryptedPin set.
|
||||
// However, we must additively check to ensure that we have a set PinKeyEncryptedUserKeyEphemeral b/c otherwise
|
||||
// we cannot take a PIN, derive a PIN key, and decrypt the ephemeral UserKey.
|
||||
const pinKeyEncryptedUserKeyEphemeral =
|
||||
await this.getPinKeyEncryptedUserKeyEphemeral(userId);
|
||||
return Boolean(pinKeyEncryptedUserKeyEphemeral);
|
||||
}
|
||||
|
||||
default: {
|
||||
// Compile-time check for exhaustive switch
|
||||
const _exhaustiveCheck: never = pinLockType;
|
||||
throw new Error(`Unexpected pinLockType: ${_exhaustiveCheck}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async decryptUserKeyWithPin(pin: string, userId: UserId): Promise<UserKey | null> {
|
||||
this.validateUserId(userId, "Cannot decrypt user key with PIN.");
|
||||
|
||||
try {
|
||||
const pinLockType = await this.getPinLockType(userId);
|
||||
|
||||
const pinKeyEncryptedUserKey = await this.getPinKeyEncryptedKeys(pinLockType, userId);
|
||||
|
||||
const email = await firstValueFrom(
|
||||
this.accountService.accounts$.pipe(map((accounts) => accounts[userId].email)),
|
||||
);
|
||||
const kdfConfig = await this.kdfConfigService.getKdfConfig(userId);
|
||||
|
||||
const userKey: UserKey = await this.decryptUserKey(
|
||||
userId,
|
||||
pin,
|
||||
email,
|
||||
kdfConfig,
|
||||
pinKeyEncryptedUserKey,
|
||||
);
|
||||
if (!userKey) {
|
||||
this.logService.warning(`User key null after pin key decryption.`);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!(await this.validatePin(userKey, pin, userId))) {
|
||||
this.logService.warning(`Pin key decryption successful but pin validation failed.`);
|
||||
return null;
|
||||
}
|
||||
|
||||
return userKey;
|
||||
} catch (error) {
|
||||
this.logService.error(`Error decrypting user key with pin: ${error}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypts the UserKey with the provided PIN.
|
||||
*/
|
||||
private async decryptUserKey(
|
||||
userId: UserId,
|
||||
pin: string,
|
||||
salt: string,
|
||||
kdfConfig: KdfConfig,
|
||||
pinKeyEncryptedUserKey?: EncString,
|
||||
): Promise<UserKey> {
|
||||
this.validateUserId(userId, "Cannot decrypt user key.");
|
||||
|
||||
pinKeyEncryptedUserKey ||= await this.getPinKeyEncryptedUserKeyPersistent(userId);
|
||||
pinKeyEncryptedUserKey ||= await this.getPinKeyEncryptedUserKeyEphemeral(userId);
|
||||
|
||||
if (!pinKeyEncryptedUserKey) {
|
||||
throw new Error("No pinKeyEncryptedUserKey found.");
|
||||
}
|
||||
|
||||
const pinKey = await this.makePinKey(pin, salt, kdfConfig);
|
||||
const userKey = await this.encryptService.unwrapSymmetricKey(pinKeyEncryptedUserKey, pinKey);
|
||||
|
||||
return userKey as UserKey;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the user's `pinKeyEncryptedUserKey` (persistent or ephemeral)
|
||||
* (if one exists) based on the user's PinLockType.
|
||||
*
|
||||
* @throws If PinLockType is 'DISABLED' or if userId is not provided
|
||||
*/
|
||||
private async getPinKeyEncryptedKeys(
|
||||
pinLockType: PinLockType,
|
||||
userId: UserId,
|
||||
): Promise<EncString> {
|
||||
this.validateUserId(userId, "Cannot get PinKey encrypted keys.");
|
||||
|
||||
switch (pinLockType) {
|
||||
case "PERSISTENT": {
|
||||
return await this.getPinKeyEncryptedUserKeyPersistent(userId);
|
||||
}
|
||||
case "EPHEMERAL": {
|
||||
return await this.getPinKeyEncryptedUserKeyEphemeral(userId);
|
||||
}
|
||||
case "DISABLED":
|
||||
throw new Error("Pin is disabled");
|
||||
default: {
|
||||
// Compile-time check for exhaustive switch
|
||||
const _exhaustiveCheck: never = pinLockType;
|
||||
return _exhaustiveCheck;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async validatePin(userKey: UserKey, pin: string, userId: UserId): Promise<boolean> {
|
||||
this.validateUserId(userId, "Cannot validate PIN.");
|
||||
|
||||
const userKeyEncryptedPin = await this.getUserKeyEncryptedPin(userId);
|
||||
const decryptedPin = await this.encryptService.decryptString(userKeyEncryptedPin, userKey);
|
||||
|
||||
const isPinValid = this.cryptoFunctionService.compareFast(decryptedPin, pin);
|
||||
return isPinValid;
|
||||
}
|
||||
|
||||
/**
|
||||
* Throws a custom error message if user ID is not provided.
|
||||
*/
|
||||
private validateUserId(userId: UserId, errorMessage: string = "") {
|
||||
if (!userId) {
|
||||
throw new Error(`User ID is required. ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
518
libs/common/src/key-management/pin/pin.service.spec.ts
Normal file
518
libs/common/src/key-management/pin/pin.service.spec.ts
Normal file
@@ -0,0 +1,518 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { DEFAULT_KDF_CONFIG, KdfConfigService } from "@bitwarden/key-management";
|
||||
|
||||
import { FakeAccountService, FakeStateProvider, mockAccountServiceWith } from "../../../spec";
|
||||
import { KeyGenerationService } from "../../platform/abstractions/key-generation.service";
|
||||
import { LogService } from "../../platform/abstractions/log.service";
|
||||
import { Utils } from "../../platform/misc/utils";
|
||||
import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypto-key";
|
||||
import { UserId } from "../../types/guid";
|
||||
import { PinKey, UserKey } from "../../types/key";
|
||||
import { CryptoFunctionService } from "../crypto/abstractions/crypto-function.service";
|
||||
import { EncryptService } from "../crypto/abstractions/encrypt.service";
|
||||
import { EncString } from "../crypto/models/enc-string";
|
||||
|
||||
import {
|
||||
PinService,
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
|
||||
USER_KEY_ENCRYPTED_PIN,
|
||||
PinLockType,
|
||||
} from "./pin.service.implementation";
|
||||
|
||||
describe("PinService", () => {
|
||||
let sut: PinService;
|
||||
|
||||
let accountService: FakeAccountService;
|
||||
let stateProvider: FakeStateProvider;
|
||||
|
||||
const cryptoFunctionService = mock<CryptoFunctionService>();
|
||||
const encryptService = mock<EncryptService>();
|
||||
const kdfConfigService = mock<KdfConfigService>();
|
||||
const keyGenerationService = mock<KeyGenerationService>();
|
||||
const logService = mock<LogService>();
|
||||
|
||||
const mockUserId = Utils.newGuid() as UserId;
|
||||
const mockUserKey = new SymmetricCryptoKey(randomBytes(64)) as UserKey;
|
||||
const mockPinKey = new SymmetricCryptoKey(randomBytes(32)) as PinKey;
|
||||
const mockUserEmail = "user@example.com";
|
||||
const mockPin = "1234";
|
||||
const mockUserKeyEncryptedPin = new EncString("userKeyEncryptedPin");
|
||||
|
||||
// Note: both pinKeyEncryptedUserKeys use encryptionType: 2 (AesCbc256_HmacSha256_B64)
|
||||
const pinKeyEncryptedUserKeyEphemeral = new EncString(
|
||||
"2.gbauOANURUHqvhLTDnva1A==|nSW+fPumiuTaDB/s12+JO88uemV6rhwRSR+YR1ZzGr5j6Ei3/h+XEli2Unpz652NlZ9NTuRpHxeOqkYYJtp7J+lPMoclgteXuAzUu9kqlRc=|DeUFkhIwgkGdZA08bDnDqMMNmZk21D+H5g8IostPKAY=",
|
||||
);
|
||||
const pinKeyEncryptedUserKeyPersistant = new EncString(
|
||||
"2.fb5kOEZvh9zPABbP8WRmSQ==|Yi6ZAJY+UtqCKMUSqp1ahY9Kf8QuneKXs6BMkpNsakLVOzTYkHHlilyGABMF7GzUO8QHyZi7V/Ovjjg+Naf3Sm8qNhxtDhibITv4k8rDnM0=|TFkq3h2VNTT1z5BFbebm37WYuxyEHXuRo0DZJI7TQnw=",
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
accountService = mockAccountServiceWith(mockUserId, { email: mockUserEmail });
|
||||
stateProvider = new FakeStateProvider(accountService);
|
||||
|
||||
sut = new PinService(
|
||||
accountService,
|
||||
cryptoFunctionService,
|
||||
encryptService,
|
||||
kdfConfigService,
|
||||
keyGenerationService,
|
||||
logService,
|
||||
stateProvider,
|
||||
);
|
||||
});
|
||||
|
||||
it("should instantiate the PinService", () => {
|
||||
expect(sut).not.toBeFalsy();
|
||||
});
|
||||
|
||||
describe("userId validation", () => {
|
||||
it("should throw an error if a userId is not provided", async () => {
|
||||
await expect(sut.getPinKeyEncryptedUserKeyPersistent(undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot get pinKeyEncryptedUserKeyPersistent.",
|
||||
);
|
||||
await expect(sut.getPinKeyEncryptedUserKeyEphemeral(undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot get pinKeyEncryptedUserKeyEphemeral.",
|
||||
);
|
||||
await expect(sut.clearPinKeyEncryptedUserKeyPersistent(undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot clear pinKeyEncryptedUserKeyPersistent.",
|
||||
);
|
||||
await expect(sut.clearPinKeyEncryptedUserKeyEphemeral(undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot clear pinKeyEncryptedUserKeyEphemeral.",
|
||||
);
|
||||
await expect(
|
||||
sut.createPinKeyEncryptedUserKey(mockPin, mockUserKey, undefined),
|
||||
).rejects.toThrow("User ID is required. Cannot create pinKeyEncryptedUserKey.");
|
||||
await expect(sut.getUserKeyEncryptedPin(undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot get userKeyEncryptedPin.",
|
||||
);
|
||||
await expect(sut.setUserKeyEncryptedPin(mockUserKeyEncryptedPin, undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot set userKeyEncryptedPin.",
|
||||
);
|
||||
await expect(sut.clearUserKeyEncryptedPin(undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot clear userKeyEncryptedPin.",
|
||||
);
|
||||
await expect(
|
||||
sut.createPinKeyEncryptedUserKey(mockPin, mockUserKey, undefined),
|
||||
).rejects.toThrow("User ID is required. Cannot create pinKeyEncryptedUserKey.");
|
||||
await expect(sut.getPinLockType(undefined)).rejects.toThrow("Cannot get PinLockType.");
|
||||
await expect(sut.isPinSet(undefined)).rejects.toThrow(
|
||||
"User ID is required. Cannot determine if PIN is set.",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("get/clear/create/store pinKeyEncryptedUserKey methods", () => {
|
||||
describe("getPinKeyEncryptedUserKeyPersistent()", () => {
|
||||
it("should get the pinKeyEncryptedUserKey of the specified userId", async () => {
|
||||
await sut.getPinKeyEncryptedUserKeyPersistent(mockUserId);
|
||||
|
||||
expect(stateProvider.mock.getUserState$).toHaveBeenCalledWith(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearPinKeyEncryptedUserKeyPersistent()", () => {
|
||||
it("should clear the pinKeyEncryptedUserKey of the specified userId", async () => {
|
||||
await sut.clearPinKeyEncryptedUserKeyPersistent(mockUserId);
|
||||
|
||||
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
|
||||
null,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getPinKeyEncryptedUserKeyEphemeral()", () => {
|
||||
it("should get the pinKeyEncrypterUserKeyEphemeral of the specified userId", async () => {
|
||||
await sut.getPinKeyEncryptedUserKeyEphemeral(mockUserId);
|
||||
|
||||
expect(stateProvider.mock.getUserState$).toHaveBeenCalledWith(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearPinKeyEncryptedUserKeyEphemeral()", () => {
|
||||
it("should clear the pinKeyEncryptedUserKey of the specified userId", async () => {
|
||||
await sut.clearPinKeyEncryptedUserKeyEphemeral(mockUserId);
|
||||
|
||||
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
|
||||
null,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("createPinKeyEncryptedUserKey()", () => {
|
||||
it("should throw an error if a userKey is not provided", async () => {
|
||||
await expect(
|
||||
sut.createPinKeyEncryptedUserKey(mockPin, undefined, mockUserId),
|
||||
).rejects.toThrow("No UserKey provided. Cannot create pinKeyEncryptedUserKey.");
|
||||
});
|
||||
|
||||
it("should create a pinKeyEncryptedUserKey", async () => {
|
||||
// Arrange
|
||||
sut.makePinKey = jest.fn().mockResolvedValue(mockPinKey);
|
||||
|
||||
// Act
|
||||
await sut.createPinKeyEncryptedUserKey(mockPin, mockUserKey, mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(encryptService.wrapSymmetricKey).toHaveBeenCalledWith(mockUserKey, mockPinKey);
|
||||
});
|
||||
});
|
||||
|
||||
describe("storePinKeyEncryptedUserKey", () => {
|
||||
it("should store a pinKeyEncryptedUserKey (persistent version) when 'storeAsEphemeral' is false", async () => {
|
||||
// Arrange
|
||||
const storeAsEphemeral = false;
|
||||
|
||||
// Act
|
||||
await sut.storePinKeyEncryptedUserKey(
|
||||
pinKeyEncryptedUserKeyPersistant,
|
||||
storeAsEphemeral,
|
||||
mockUserId,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_PERSISTENT,
|
||||
pinKeyEncryptedUserKeyPersistant.encryptedString,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
|
||||
it("should store a pinKeyEncryptedUserKeyEphemeral when 'storeAsEphemeral' is true", async () => {
|
||||
// Arrange
|
||||
const storeAsEphemeral = true;
|
||||
|
||||
// Act
|
||||
await sut.storePinKeyEncryptedUserKey(
|
||||
pinKeyEncryptedUserKeyEphemeral,
|
||||
storeAsEphemeral,
|
||||
mockUserId,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
|
||||
PIN_KEY_ENCRYPTED_USER_KEY_EPHEMERAL,
|
||||
pinKeyEncryptedUserKeyEphemeral.encryptedString,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("userKeyEncryptedPin methods", () => {
|
||||
describe("getUserKeyEncryptedPin()", () => {
|
||||
it("should get the userKeyEncryptedPin of the specified userId", async () => {
|
||||
await sut.getUserKeyEncryptedPin(mockUserId);
|
||||
|
||||
expect(stateProvider.mock.getUserState$).toHaveBeenCalledWith(
|
||||
USER_KEY_ENCRYPTED_PIN,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setUserKeyEncryptedPin()", () => {
|
||||
it("should set the userKeyEncryptedPin of the specified userId", async () => {
|
||||
await sut.setUserKeyEncryptedPin(mockUserKeyEncryptedPin, mockUserId);
|
||||
|
||||
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
|
||||
USER_KEY_ENCRYPTED_PIN,
|
||||
mockUserKeyEncryptedPin.encryptedString,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearUserKeyEncryptedPin()", () => {
|
||||
it("should clear the pinKeyEncryptedUserKey of the specified userId", async () => {
|
||||
await sut.clearUserKeyEncryptedPin(mockUserId);
|
||||
|
||||
expect(stateProvider.mock.setUserState).toHaveBeenCalledWith(
|
||||
USER_KEY_ENCRYPTED_PIN,
|
||||
null,
|
||||
mockUserId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("createUserKeyEncryptedPin()", () => {
|
||||
it("should throw an error if a userKey is not provided", async () => {
|
||||
await expect(sut.createUserKeyEncryptedPin(mockPin, undefined)).rejects.toThrow(
|
||||
"No UserKey provided. Cannot create userKeyEncryptedPin.",
|
||||
);
|
||||
});
|
||||
|
||||
it("should create a userKeyEncryptedPin from the provided PIN and userKey", async () => {
|
||||
encryptService.encryptString.mockResolvedValue(mockUserKeyEncryptedPin);
|
||||
|
||||
const result = await sut.createUserKeyEncryptedPin(mockPin, mockUserKey);
|
||||
|
||||
expect(encryptService.encryptString).toHaveBeenCalledWith(mockPin, mockUserKey);
|
||||
expect(result).toEqual(mockUserKeyEncryptedPin);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("makePinKey()", () => {
|
||||
it("should make a PinKey", async () => {
|
||||
// Arrange
|
||||
keyGenerationService.deriveKeyFromPassword.mockResolvedValue(mockPinKey);
|
||||
|
||||
// Act
|
||||
await sut.makePinKey(mockPin, mockUserEmail, DEFAULT_KDF_CONFIG);
|
||||
|
||||
// Assert
|
||||
expect(keyGenerationService.deriveKeyFromPassword).toHaveBeenCalledWith(
|
||||
mockPin,
|
||||
mockUserEmail,
|
||||
DEFAULT_KDF_CONFIG,
|
||||
);
|
||||
expect(keyGenerationService.stretchKey).toHaveBeenCalledWith(mockPinKey);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getPinLockType()", () => {
|
||||
it("should return 'PERSISTENT' if a pinKeyEncryptedUserKey (persistent version) is found", async () => {
|
||||
// Arrange
|
||||
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(null);
|
||||
sut.getPinKeyEncryptedUserKeyPersistent = jest
|
||||
.fn()
|
||||
.mockResolvedValue(pinKeyEncryptedUserKeyPersistant);
|
||||
|
||||
// Act
|
||||
const result = await sut.getPinLockType(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe("PERSISTENT");
|
||||
});
|
||||
|
||||
it("should return 'EPHEMERAL' if a pinKeyEncryptedUserKey (persistent version) is not found but a userKeyEncryptedPin is found", async () => {
|
||||
// Arrange
|
||||
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(mockUserKeyEncryptedPin);
|
||||
sut.getPinKeyEncryptedUserKeyPersistent = jest.fn().mockResolvedValue(null);
|
||||
|
||||
// Act
|
||||
const result = await sut.getPinLockType(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe("EPHEMERAL");
|
||||
});
|
||||
|
||||
it("should return 'DISABLED' if both of these are NOT found: userKeyEncryptedPin, pinKeyEncryptedUserKey (persistent version)", async () => {
|
||||
// Arrange
|
||||
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(null);
|
||||
sut.getPinKeyEncryptedUserKeyPersistent = jest.fn().mockResolvedValue(null);
|
||||
|
||||
// Act
|
||||
const result = await sut.getPinLockType(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe("DISABLED");
|
||||
});
|
||||
});
|
||||
|
||||
describe("isPinSet()", () => {
|
||||
it.each(["PERSISTENT", "EPHEMERAL"])(
|
||||
"should return true if the user PinLockType is '%s'",
|
||||
async () => {
|
||||
// Arrange
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue("PERSISTENT");
|
||||
|
||||
// Act
|
||||
const result = await sut.isPinSet(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(true);
|
||||
},
|
||||
);
|
||||
|
||||
it("should return false if the user PinLockType is 'DISABLED'", async () => {
|
||||
// Arrange
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue("DISABLED");
|
||||
|
||||
// Act
|
||||
const result = await sut.isPinSet(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isPinDecryptionAvailable()", () => {
|
||||
it("should return false if pinLockType is DISABLED", async () => {
|
||||
// Arrange
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue("DISABLED");
|
||||
|
||||
// Act
|
||||
const result = await sut.isPinDecryptionAvailable(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it("should return true if pinLockType is PERSISTENT", async () => {
|
||||
// Arrange
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue("PERSISTENT");
|
||||
|
||||
// Act
|
||||
const result = await sut.isPinDecryptionAvailable(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true if pinLockType is EPHEMERAL and we have an ephemeral PIN key encrypted user key", async () => {
|
||||
// Arrange
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue("EPHEMERAL");
|
||||
sut.getPinKeyEncryptedUserKeyEphemeral = jest
|
||||
.fn()
|
||||
.mockResolvedValue(pinKeyEncryptedUserKeyEphemeral);
|
||||
|
||||
// Act
|
||||
const result = await sut.isPinDecryptionAvailable(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it("should return false if pinLockType is EPHEMERAL and we do not have an ephemeral PIN key encrypted user key", async () => {
|
||||
// Arrange
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue("EPHEMERAL");
|
||||
sut.getPinKeyEncryptedUserKeyEphemeral = jest.fn().mockResolvedValue(null);
|
||||
|
||||
// Act
|
||||
const result = await sut.isPinDecryptionAvailable(mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it("should throw an error if an unexpected pinLockType is returned", async () => {
|
||||
// Arrange
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue("UNKNOWN");
|
||||
|
||||
// Act & Assert
|
||||
await expect(sut.isPinDecryptionAvailable(mockUserId)).rejects.toThrow(
|
||||
"Unexpected pinLockType: UNKNOWN",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("decryptUserKeyWithPin()", () => {
|
||||
async function setupDecryptUserKeyWithPinMocks(pinLockType: PinLockType) {
|
||||
sut.getPinLockType = jest.fn().mockResolvedValue(pinLockType);
|
||||
|
||||
mockPinEncryptedKeyDataByPinLockType(pinLockType);
|
||||
|
||||
kdfConfigService.getKdfConfig.mockResolvedValue(DEFAULT_KDF_CONFIG);
|
||||
|
||||
mockDecryptUserKeyFn();
|
||||
|
||||
sut.getUserKeyEncryptedPin = jest.fn().mockResolvedValue(mockUserKeyEncryptedPin);
|
||||
encryptService.decryptString.mockResolvedValue(mockPin);
|
||||
cryptoFunctionService.compareFast.calledWith(mockPin, "1234").mockResolvedValue(true);
|
||||
}
|
||||
|
||||
function mockDecryptUserKeyFn() {
|
||||
sut.getPinKeyEncryptedUserKeyPersistent = jest
|
||||
.fn()
|
||||
.mockResolvedValue(pinKeyEncryptedUserKeyPersistant);
|
||||
sut.makePinKey = jest.fn().mockResolvedValue(mockPinKey);
|
||||
encryptService.unwrapSymmetricKey.mockResolvedValue(mockUserKey);
|
||||
}
|
||||
|
||||
function mockPinEncryptedKeyDataByPinLockType(pinLockType: PinLockType) {
|
||||
switch (pinLockType) {
|
||||
case "PERSISTENT":
|
||||
sut.getPinKeyEncryptedUserKeyPersistent = jest
|
||||
.fn()
|
||||
.mockResolvedValue(pinKeyEncryptedUserKeyPersistant);
|
||||
break;
|
||||
case "EPHEMERAL":
|
||||
sut.getPinKeyEncryptedUserKeyEphemeral = jest
|
||||
.fn()
|
||||
.mockResolvedValue(pinKeyEncryptedUserKeyEphemeral);
|
||||
|
||||
break;
|
||||
case "DISABLED":
|
||||
// no mocking required. Error should be thrown
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const testCases: { pinLockType: PinLockType }[] = [
|
||||
{ pinLockType: "PERSISTENT" },
|
||||
{ pinLockType: "EPHEMERAL" },
|
||||
];
|
||||
|
||||
testCases.forEach(({ pinLockType }) => {
|
||||
describe(`given a ${pinLockType} PIN)`, () => {
|
||||
it(`should successfully decrypt and return user key when using a valid PIN`, async () => {
|
||||
// Arrange
|
||||
await setupDecryptUserKeyWithPinMocks(pinLockType);
|
||||
|
||||
// Act
|
||||
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(mockUserKey);
|
||||
});
|
||||
|
||||
it(`should return null when PIN is incorrect and user key cannot be decrypted`, async () => {
|
||||
// Arrange
|
||||
await setupDecryptUserKeyWithPinMocks(pinLockType);
|
||||
sut.decryptUserKeyWithPin = jest.fn().mockResolvedValue(null);
|
||||
|
||||
// Act
|
||||
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
// not sure if this is a realistic scenario but going to test it anyway
|
||||
it(`should return null when PIN doesn't match after successful user key decryption`, async () => {
|
||||
// Arrange
|
||||
await setupDecryptUserKeyWithPinMocks(pinLockType);
|
||||
encryptService.decryptString.mockResolvedValue("9999"); // non matching PIN
|
||||
|
||||
// Act
|
||||
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it(`should return null when pin is disabled`, async () => {
|
||||
// Arrange
|
||||
await setupDecryptUserKeyWithPinMocks("DISABLED");
|
||||
|
||||
// Act
|
||||
const result = await sut.decryptUserKeyWithPin(mockPin, mockUserId);
|
||||
|
||||
// Assert
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Test helpers
|
||||
function randomBytes(length: number): Uint8Array {
|
||||
return new Uint8Array(Array.from({ length }, (_, k) => k % 255));
|
||||
}
|
||||
@@ -2,9 +2,6 @@
|
||||
// @ts-strict-ignore
|
||||
import { firstValueFrom, map, timeout } from "rxjs";
|
||||
|
||||
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { PinServiceAbstraction } from "@bitwarden/auth/common";
|
||||
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { BiometricStateService } from "@bitwarden/key-management";
|
||||
@@ -20,6 +17,7 @@ import { LogService } from "../../platform/abstractions/log.service";
|
||||
import { MessagingService } from "../../platform/abstractions/messaging.service";
|
||||
import { UserId } from "../../types/guid";
|
||||
import { ProcessReloadServiceAbstraction } from "../abstractions/process-reload.service";
|
||||
import { PinServiceAbstraction } from "../pin/pin.service.abstraction";
|
||||
|
||||
export class DefaultProcessReloadService implements ProcessReloadServiceAbstraction {
|
||||
private reloadInterval: any = null;
|
||||
|
||||
@@ -6,7 +6,6 @@ import { BehaviorSubject, firstValueFrom, map, of } from "rxjs";
|
||||
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import {
|
||||
PinServiceAbstraction,
|
||||
FakeUserDecryptionOptions as UserDecryptionOptions,
|
||||
UserDecryptionOptionsServiceAbstraction,
|
||||
} from "@bitwarden/auth/common";
|
||||
@@ -21,6 +20,7 @@ import { TokenService } from "../../../auth/services/token.service";
|
||||
import { LogService } from "../../../platform/abstractions/log.service";
|
||||
import { Utils } from "../../../platform/misc/utils";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { PinServiceAbstraction } from "../../pin/pin.service.abstraction";
|
||||
import { VaultTimeoutSettingsService as VaultTimeoutSettingsServiceAbstraction } from "../abstractions/vault-timeout-settings.service";
|
||||
import { VaultTimeoutAction } from "../enums/vault-timeout-action.enum";
|
||||
import { VaultTimeout, VaultTimeoutStringType } from "../types/vault-timeout.type";
|
||||
|
||||
@@ -16,10 +16,7 @@ import {
|
||||
|
||||
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import {
|
||||
PinServiceAbstraction,
|
||||
UserDecryptionOptionsServiceAbstraction,
|
||||
} from "@bitwarden/auth/common";
|
||||
import { UserDecryptionOptionsServiceAbstraction } from "@bitwarden/auth/common";
|
||||
// This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop.
|
||||
// eslint-disable-next-line no-restricted-imports
|
||||
import { BiometricStateService, KeyService } from "@bitwarden/key-management";
|
||||
@@ -33,6 +30,7 @@ import { TokenService } from "../../../auth/abstractions/token.service";
|
||||
import { LogService } from "../../../platform/abstractions/log.service";
|
||||
import { StateProvider } from "../../../platform/state";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { PinServiceAbstraction } from "../../pin/pin.service.abstraction";
|
||||
import { VaultTimeoutSettingsService as VaultTimeoutSettingsServiceAbstraction } from "../abstractions/vault-timeout-settings.service";
|
||||
import { VaultTimeoutAction } from "../enums/vault-timeout-action.enum";
|
||||
import { VaultTimeout, VaultTimeoutStringType } from "../types/vault-timeout.type";
|
||||
|
||||
@@ -143,10 +143,6 @@ export class VaultTimeoutService implements VaultTimeoutServiceAbstraction {
|
||||
),
|
||||
);
|
||||
|
||||
if (userId == null || userId === currentUserId) {
|
||||
await this.collectionService.clearActiveUserCache();
|
||||
}
|
||||
|
||||
await this.searchService.clearIndex(lockingUserId);
|
||||
|
||||
await this.folderService.clearDecryptedFolderState(lockingUserId);
|
||||
|
||||
@@ -13,9 +13,9 @@ export const getById = <TId, T extends { id: TId }>(id: TId) =>
|
||||
* @param id The IDs of the objects to return.
|
||||
* @returns An array containing objects with matching IDs, or an empty array if there are no matching objects.
|
||||
*/
|
||||
export const getByIds = <TId, T extends { id: TId }>(ids: TId[]) => {
|
||||
const idSet = new Set(ids);
|
||||
export const getByIds = <TId, T extends { id: TId | undefined }>(ids: TId[]) => {
|
||||
const idSet = new Set(ids.filter((id) => id != null));
|
||||
return map<T[], T[]>((objects) => {
|
||||
return objects.filter((o) => idSet.has(o.id));
|
||||
return objects.filter((o) => o.id && idSet.has(o.id));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -252,6 +252,7 @@ export class Utils {
|
||||
}
|
||||
|
||||
// ref: http://stackoverflow.com/a/2117523/1090359
|
||||
/** @deprecated Use newGuid from @bitwarden/guid instead */
|
||||
static newGuid(): string {
|
||||
return "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".replace(/[xy]/g, (c) => {
|
||||
const r = (Math.random() * 16) | 0;
|
||||
@@ -260,8 +261,10 @@ export class Utils {
|
||||
});
|
||||
}
|
||||
|
||||
/** @deprecated Use guidRegex from @bitwarden/guid instead */
|
||||
static guidRegex = /^[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12}$/;
|
||||
|
||||
/** @deprecated Use isGuid from @bitwarden/guid instead */
|
||||
static isGuid(id: string) {
|
||||
return RegExp(Utils.guidRegex, "i").test(id);
|
||||
}
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
import { mock, MockProxy } from "jest-mock-extended";
|
||||
|
||||
import { makeEncString, makeSymmetricCryptoKey } from "../../../../spec";
|
||||
import { EncryptService } from "../../../key-management/crypto/abstractions/encrypt.service";
|
||||
import { EncString } from "../../../key-management/crypto/models/enc-string";
|
||||
|
||||
import Domain from "./domain-base";
|
||||
|
||||
class TestDomain extends Domain {
|
||||
plainText: string;
|
||||
encToString: EncString;
|
||||
encString2: EncString;
|
||||
}
|
||||
|
||||
describe("DomainBase", () => {
|
||||
let encryptService: MockProxy<EncryptService>;
|
||||
const key = makeSymmetricCryptoKey(64);
|
||||
|
||||
beforeEach(() => {
|
||||
encryptService = mock<EncryptService>();
|
||||
});
|
||||
|
||||
function setUpCryptography() {
|
||||
encryptService.encryptString.mockImplementation((value) =>
|
||||
Promise.resolve(makeEncString(value)),
|
||||
);
|
||||
|
||||
encryptService.decryptString.mockImplementation((value) => {
|
||||
return Promise.resolve(value.data);
|
||||
});
|
||||
}
|
||||
|
||||
describe("decryptWithKey", () => {
|
||||
it("domain property types are decryptable", async () => {
|
||||
const domain = new TestDomain();
|
||||
|
||||
await domain["decryptObjWithKey"](
|
||||
// @ts-expect-error -- clear is not of type EncString
|
||||
["plainText"],
|
||||
makeSymmetricCryptoKey(64),
|
||||
mock<EncryptService>(),
|
||||
);
|
||||
|
||||
await domain["decryptObjWithKey"](
|
||||
// @ts-expect-error -- Clear is not of type EncString
|
||||
["encToString", "encString2", "plainText"],
|
||||
makeSymmetricCryptoKey(64),
|
||||
mock<EncryptService>(),
|
||||
);
|
||||
|
||||
const decrypted = await domain["decryptObjWithKey"](
|
||||
["encToString"],
|
||||
makeSymmetricCryptoKey(64),
|
||||
mock<EncryptService>(),
|
||||
);
|
||||
|
||||
// @ts-expect-error -- encString2 was not decrypted
|
||||
// FIXME: Remove when updating file. Eslint update
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-expressions
|
||||
decrypted as { encToString: string; encString2: string; plainText: string };
|
||||
|
||||
// encString2 was not decrypted, so it's still an EncString
|
||||
// FIXME: Remove when updating file. Eslint update
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-expressions
|
||||
decrypted as { encToString: string; encString2: EncString; plainText: string };
|
||||
});
|
||||
|
||||
it("decrypts the encrypted properties", async () => {
|
||||
setUpCryptography();
|
||||
|
||||
const domain = new TestDomain();
|
||||
|
||||
domain.encToString = await encryptService.encryptString("string", key);
|
||||
|
||||
const decrypted = await domain["decryptObjWithKey"](["encToString"], key, encryptService);
|
||||
|
||||
expect(decrypted).toEqual({
|
||||
encToString: "string",
|
||||
});
|
||||
});
|
||||
|
||||
it("decrypts multiple encrypted properties", async () => {
|
||||
setUpCryptography();
|
||||
|
||||
const domain = new TestDomain();
|
||||
|
||||
domain.encToString = await encryptService.encryptString("string", key);
|
||||
domain.encString2 = await encryptService.encryptString("string2", key);
|
||||
|
||||
const decrypted = await domain["decryptObjWithKey"](
|
||||
["encToString", "encString2"],
|
||||
key,
|
||||
encryptService,
|
||||
);
|
||||
|
||||
expect(decrypted).toEqual({
|
||||
encToString: "string",
|
||||
encString2: "string2",
|
||||
});
|
||||
});
|
||||
|
||||
it("does not decrypt properties that are not encrypted", async () => {
|
||||
const domain = new TestDomain();
|
||||
domain.plainText = "clear";
|
||||
|
||||
const decrypted = await domain["decryptObjWithKey"]([], key, encryptService);
|
||||
|
||||
expect(decrypted).toEqual({
|
||||
plainText: "clear",
|
||||
});
|
||||
});
|
||||
|
||||
it("does not decrypt properties that were not requested to be decrypted", async () => {
|
||||
setUpCryptography();
|
||||
|
||||
const domain = new TestDomain();
|
||||
|
||||
domain.plainText = "clear";
|
||||
domain.encToString = makeEncString("string");
|
||||
domain.encString2 = makeEncString("string2");
|
||||
|
||||
const decrypted = await domain["decryptObjWithKey"]([], key, encryptService);
|
||||
|
||||
expect(decrypted).toEqual({
|
||||
plainText: "clear",
|
||||
encToString: makeEncString("string"),
|
||||
encString2: makeEncString("string2"),
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,5 @@
|
||||
import { ConditionalExcept, ConditionalKeys, Constructor } from "type-fest";
|
||||
import { ConditionalExcept, ConditionalKeys } from "type-fest";
|
||||
|
||||
import { EncryptService } from "../../../key-management/crypto/abstractions/encrypt.service";
|
||||
import { EncString } from "../../../key-management/crypto/models/enc-string";
|
||||
import { View } from "../../../models/view/view";
|
||||
|
||||
@@ -14,7 +13,7 @@ export type DecryptedObject<
|
||||
> = Record<TDecryptedKeys, string> & Omit<TEncryptedObject, TDecryptedKeys>;
|
||||
|
||||
// extracts shared keys from the domain and view types
|
||||
type EncryptableKeys<D extends Domain, V extends View> = (keyof D &
|
||||
export type EncryptableKeys<D extends Domain, V extends View> = (keyof D &
|
||||
ConditionalKeys<D, EncString | null>) &
|
||||
(keyof V & ConditionalKeys<V, string | null>);
|
||||
|
||||
@@ -89,66 +88,4 @@ export default class Domain {
|
||||
|
||||
return viewModel as V;
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypts the requested properties of the domain object with the provided key and encrypt service.
|
||||
*
|
||||
* If a property is null, the result will be null.
|
||||
* @see {@link EncString.decryptWithKey} for more details on decryption behavior.
|
||||
*
|
||||
* @param encryptedProperties The properties to decrypt. Type restricted to EncString properties of the domain object.
|
||||
* @param key The key to use for decryption.
|
||||
* @param encryptService The encryption service to use for decryption.
|
||||
* @param _ The constructor of the domain object. Used for type inference if the domain object is not automatically inferred.
|
||||
* @returns An object with the requested properties decrypted and the rest of the domain object untouched.
|
||||
*/
|
||||
protected async decryptObjWithKey<
|
||||
TThis extends Domain,
|
||||
const TEncryptedKeys extends EncStringKeys<TThis>,
|
||||
>(
|
||||
this: TThis,
|
||||
encryptedProperties: TEncryptedKeys[],
|
||||
key: SymmetricCryptoKey,
|
||||
encryptService: EncryptService,
|
||||
_: Constructor<TThis> = this.constructor as Constructor<TThis>,
|
||||
objectContext: string = "No Domain Context",
|
||||
): Promise<DecryptedObject<TThis, TEncryptedKeys>> {
|
||||
const decryptedObjects = [];
|
||||
|
||||
for (const prop of encryptedProperties) {
|
||||
const value = this[prop] as EncString;
|
||||
const decrypted = await this.decryptProperty(
|
||||
prop,
|
||||
value,
|
||||
key,
|
||||
encryptService,
|
||||
`Property: ${prop.toString()}; ObjectContext: ${objectContext}`,
|
||||
);
|
||||
decryptedObjects.push(decrypted);
|
||||
}
|
||||
|
||||
const decryptedObject = decryptedObjects.reduce(
|
||||
(acc, obj) => {
|
||||
return { ...acc, ...obj };
|
||||
},
|
||||
{ ...this },
|
||||
);
|
||||
return decryptedObject as DecryptedObject<TThis, TEncryptedKeys>;
|
||||
}
|
||||
|
||||
private async decryptProperty<const TEncryptedKeys extends EncStringKeys<this>>(
|
||||
propertyKey: TEncryptedKeys,
|
||||
value: EncString,
|
||||
key: SymmetricCryptoKey,
|
||||
encryptService: EncryptService,
|
||||
decryptTrace: string,
|
||||
) {
|
||||
let decrypted: string | null = null;
|
||||
if (value) {
|
||||
decrypted = await value.decryptWithKey(key, encryptService, decryptTrace);
|
||||
}
|
||||
return {
|
||||
[propertyKey]: decrypted,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
import { MigrationHelper } from "@bitwarden/state";
|
||||
|
||||
import { FakeStorageService } from "../../../spec/fake-storage.service";
|
||||
import { ClientType } from "../../enums";
|
||||
import { MigrationHelper } from "../../state-migrations/migration-helper";
|
||||
|
||||
import { MigrationBuilderService } from "./migration-builder.service";
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { CURRENT_VERSION, currentVersion, MigrationHelper } from "@bitwarden/state";
|
||||
|
||||
import { ClientType } from "../../enums";
|
||||
import { waitForMigrations } from "../../state-migrations";
|
||||
import { CURRENT_VERSION, currentVersion } from "../../state-migrations/migrate";
|
||||
import { MigrationHelper } from "../../state-migrations/migration-helper";
|
||||
import { LogService } from "../abstractions/log.service";
|
||||
import { AbstractStorageService } from "../abstractions/storage.service";
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
ClientSettings,
|
||||
DeviceType as SdkDeviceType,
|
||||
TokenProvider,
|
||||
UnsignedSharedKey,
|
||||
} from "@bitwarden/sdk-internal";
|
||||
|
||||
import { EncryptedOrganizationKeyData } from "../../../admin-console/models/data/encrypted-organization-key.data";
|
||||
@@ -237,7 +238,7 @@ export class DefaultSdkService implements SdkService {
|
||||
organizationKeys: new Map(
|
||||
Object.entries(orgKeys ?? {})
|
||||
.filter(([_, v]) => v.type === "organization")
|
||||
.map(([k, v]) => [k, v.key]),
|
||||
.map(([k, v]) => [k, v.key as UnsignedSharedKey]),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
11
libs/common/src/platform/state/active-user.accessor.ts
Normal file
11
libs/common/src/platform/state/active-user.accessor.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
import { UserId } from "@bitwarden/user-core";
|
||||
|
||||
export abstract class ActiveUserAccessor {
|
||||
/**
|
||||
* Returns a stream of the current active user for the application. The stream either emits the user id for that account
|
||||
* or returns null if there is no current active user.
|
||||
*/
|
||||
abstract activeUserId$: Observable<UserId | null>;
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
import { DeriveDefinition } from "./derive-definition";
|
||||
import { KeyDefinition } from "./key-definition";
|
||||
import { StateDefinition } from "./state-definition";
|
||||
|
||||
const derive: () => any = () => null;
|
||||
const deserializer: any = (obj: any) => obj;
|
||||
|
||||
const STATE_DEFINITION = new StateDefinition("test", "disk");
|
||||
const TEST_KEY = new KeyDefinition(STATE_DEFINITION, "test", {
|
||||
deserializer,
|
||||
});
|
||||
const TEST_DERIVE = new DeriveDefinition(STATE_DEFINITION, "test", {
|
||||
derive,
|
||||
deserializer,
|
||||
});
|
||||
|
||||
describe("DeriveDefinition", () => {
|
||||
describe("from", () => {
|
||||
it("should create a new DeriveDefinition from a KeyDefinition", () => {
|
||||
const result = DeriveDefinition.from(TEST_KEY, {
|
||||
derive,
|
||||
deserializer,
|
||||
});
|
||||
|
||||
expect(result).toEqual(TEST_DERIVE);
|
||||
});
|
||||
|
||||
it("should create a new DeriveDefinition from a DeriveDefinition", () => {
|
||||
const result = DeriveDefinition.from([TEST_DERIVE, "newDerive"], {
|
||||
derive,
|
||||
deserializer,
|
||||
});
|
||||
|
||||
expect(result).toEqual(
|
||||
new DeriveDefinition(STATE_DEFINITION, "newDerive", {
|
||||
derive,
|
||||
deserializer,
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,196 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { UserId } from "../../types/guid";
|
||||
import { DerivedStateDependencies, StorageKey } from "../../types/state";
|
||||
|
||||
import { KeyDefinition } from "./key-definition";
|
||||
import { StateDefinition } from "./state-definition";
|
||||
import { UserKeyDefinition } from "./user-key-definition";
|
||||
|
||||
declare const depShapeMarker: unique symbol;
|
||||
/**
|
||||
* A set of options for customizing the behavior of a {@link DeriveDefinition}
|
||||
*/
|
||||
type DeriveDefinitionOptions<TFrom, TTo, TDeps extends DerivedStateDependencies = never> = {
|
||||
/**
|
||||
* A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable
|
||||
* and the resulting value will be emitted from the derived state observable.
|
||||
*
|
||||
* @param from Populated with the latest emission from the parent state observable.
|
||||
* @param deps Populated with the dependencies passed into the constructor of the derived state.
|
||||
* These are constant for the lifetime of the derived state.
|
||||
* @returns The derived state value or a Promise that resolves to the derived state value.
|
||||
*/
|
||||
derive: (from: TFrom, deps: TDeps) => TTo | Promise<TTo>;
|
||||
/**
|
||||
* A function to use to safely convert your type from json to your expected type.
|
||||
*
|
||||
* **Important:** Your data may be serialized/deserialized at any time and this
|
||||
* callback needs to be able to faithfully re-initialize from the JSON object representation of your type.
|
||||
*
|
||||
* @param jsonValue The JSON object representation of your state.
|
||||
* @returns The fully typed version of your state.
|
||||
*/
|
||||
deserializer: (serialized: Jsonify<TTo>) => TTo;
|
||||
/**
|
||||
* An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies
|
||||
* and the values are the types of the dependencies.
|
||||
*
|
||||
* for example:
|
||||
* ```
|
||||
* {
|
||||
* myService: MyService,
|
||||
* myOtherService: MyOtherService,
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
[depShapeMarker]?: TDeps;
|
||||
/**
|
||||
* The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
|
||||
* Defaults to 1000ms.
|
||||
*/
|
||||
cleanupDelayMs?: number;
|
||||
/**
|
||||
* Whether or not to clear the derived state when cleanup occurs. Defaults to true.
|
||||
*/
|
||||
clearOnCleanup?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* DeriveDefinitions describe state derived from another observable, the value type of which is given by `TFrom`.
|
||||
*
|
||||
* The StateDefinition is used to describe the domain of the state, and the DeriveDefinition
|
||||
* sub-divides that domain into specific keys. These keys are used to cache data in memory and enables derived state to
|
||||
* be calculated once regardless of multiple execution contexts.
|
||||
*/
|
||||
|
||||
export class DeriveDefinition<TFrom, TTo, TDeps extends DerivedStateDependencies> {
|
||||
/**
|
||||
* Creates a new instance of a DeriveDefinition. Derived state is always stored in memory, so the storage location
|
||||
* defined in @link{StateDefinition} is ignored.
|
||||
*
|
||||
* @param stateDefinition The state definition for which this key belongs to.
|
||||
* @param uniqueDerivationName The name of the key, this should be unique per domain.
|
||||
* @param options A set of options to customize the behavior of {@link DeriveDefinition}.
|
||||
* @param options.derive A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable
|
||||
* and the resulting value will be emitted from the derived state observable.
|
||||
* @param options.cleanupDelayMs The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
|
||||
* Defaults to 1000ms.
|
||||
* @param options.dependencyShape An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies
|
||||
* and the values are the types of the dependencies.
|
||||
* for example:
|
||||
* ```
|
||||
* {
|
||||
* myService: MyService,
|
||||
* myOtherService: MyOtherService,
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @param options.deserializer A function to use to safely convert your type from json to your expected type.
|
||||
* Your data may be serialized/deserialized at any time and this needs callback needs to be able to faithfully re-initialize
|
||||
* from the JSON object representation of your type.
|
||||
*/
|
||||
constructor(
|
||||
readonly stateDefinition: StateDefinition,
|
||||
readonly uniqueDerivationName: string,
|
||||
readonly options: DeriveDefinitionOptions<TFrom, TTo, TDeps>,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Factory that produces a {@link DeriveDefinition} from a {@link KeyDefinition} or {@link DeriveDefinition} and new name.
|
||||
*
|
||||
* If a `KeyDefinition` is passed in, the returned definition will have the same key as the given key definition, but
|
||||
* will not collide with it in storage, even if they both reside in memory.
|
||||
*
|
||||
* If a `DeriveDefinition` is passed in, the returned definition will instead use the name given in the second position
|
||||
* of the tuple. It is up to you to ensure this is unique within the domain of derived state.
|
||||
*
|
||||
* @param options A set of options to customize the behavior of {@link DeriveDefinition}.
|
||||
* @param options.derive A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable
|
||||
* and the resulting value will be emitted from the derived state observable.
|
||||
* @param options.cleanupDelayMs The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
|
||||
* Defaults to 1000ms.
|
||||
* @param options.dependencyShape An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies
|
||||
* and the values are the types of the dependencies.
|
||||
* for example:
|
||||
* ```
|
||||
* {
|
||||
* myService: MyService,
|
||||
* myOtherService: MyOtherService,
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @param options.deserializer A function to use to safely convert your type from json to your expected type.
|
||||
* Your data may be serialized/deserialized at any time and this needs callback needs to be able to faithfully re-initialize
|
||||
* from the JSON object representation of your type.
|
||||
* @param definition
|
||||
* @param options
|
||||
* @returns
|
||||
*/
|
||||
static from<TFrom, TTo, TDeps extends DerivedStateDependencies = never>(
|
||||
definition:
|
||||
| KeyDefinition<TFrom>
|
||||
| UserKeyDefinition<TFrom>
|
||||
| [DeriveDefinition<unknown, TFrom, DerivedStateDependencies>, string],
|
||||
options: DeriveDefinitionOptions<TFrom, TTo, TDeps>,
|
||||
) {
|
||||
if (isFromDeriveDefinition(definition)) {
|
||||
return new DeriveDefinition(definition[0].stateDefinition, definition[1], options);
|
||||
} else {
|
||||
return new DeriveDefinition(definition.stateDefinition, definition.key, options);
|
||||
}
|
||||
}
|
||||
|
||||
static fromWithUserId<TKeyDef, TTo, TDeps extends DerivedStateDependencies = never>(
|
||||
definition:
|
||||
| KeyDefinition<TKeyDef>
|
||||
| UserKeyDefinition<TKeyDef>
|
||||
| [DeriveDefinition<unknown, TKeyDef, DerivedStateDependencies>, string],
|
||||
options: DeriveDefinitionOptions<[UserId, TKeyDef], TTo, TDeps>,
|
||||
) {
|
||||
if (isFromDeriveDefinition(definition)) {
|
||||
return new DeriveDefinition(definition[0].stateDefinition, definition[1], options);
|
||||
} else {
|
||||
return new DeriveDefinition(definition.stateDefinition, definition.key, options);
|
||||
}
|
||||
}
|
||||
|
||||
get derive() {
|
||||
return this.options.derive;
|
||||
}
|
||||
|
||||
deserialize(serialized: Jsonify<TTo>): TTo {
|
||||
return this.options.deserializer(serialized);
|
||||
}
|
||||
|
||||
get cleanupDelayMs() {
|
||||
return this.options.cleanupDelayMs < 0 ? 0 : (this.options.cleanupDelayMs ?? 1000);
|
||||
}
|
||||
|
||||
get clearOnCleanup() {
|
||||
return this.options.clearOnCleanup ?? true;
|
||||
}
|
||||
|
||||
buildCacheKey(): string {
|
||||
return `derived_${this.stateDefinition.name}_${this.uniqueDerivationName}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link StorageKey} that points to the data for the given derived definition.
|
||||
* @returns A key that is ready to be used in a storage service to get data.
|
||||
*/
|
||||
get storageKey(): StorageKey {
|
||||
return `derived_${this.stateDefinition.name}_${this.uniqueDerivationName}` as StorageKey;
|
||||
}
|
||||
}
|
||||
|
||||
function isFromDeriveDefinition(
|
||||
definition:
|
||||
| KeyDefinition<unknown>
|
||||
| UserKeyDefinition<unknown>
|
||||
| [DeriveDefinition<unknown, unknown, DerivedStateDependencies>, string],
|
||||
): definition is [DeriveDefinition<unknown, unknown, DerivedStateDependencies>, string] {
|
||||
return Array.isArray(definition);
|
||||
}
|
||||
export { DeriveDefinition } from "@bitwarden/state";
|
||||
|
||||
@@ -1,25 +1 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
import { DerivedStateDependencies } from "../../types/state";
|
||||
|
||||
import { DeriveDefinition } from "./derive-definition";
|
||||
import { DerivedState } from "./derived-state";
|
||||
|
||||
/**
|
||||
* State derived from an observable and a derive function
|
||||
*/
|
||||
export abstract class DerivedStateProvider {
|
||||
/**
|
||||
* Creates a derived state observable from a parent state observable, a deriveDefinition, and the dependencies
|
||||
* required by the deriveDefinition
|
||||
* @param parentState$ The parent state observable
|
||||
* @param deriveDefinition The deriveDefinition that defines conversion from the parent state to the derived state as
|
||||
* well as some memory persistent information.
|
||||
* @param dependencies The dependencies of the derive function
|
||||
*/
|
||||
abstract get<TFrom, TTo, TDeps extends DerivedStateDependencies>(
|
||||
parentState$: Observable<TFrom>,
|
||||
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
|
||||
dependencies: TDeps,
|
||||
): DerivedState<TTo>;
|
||||
}
|
||||
export { DerivedStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,23 +1 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
export type StateConverter<TFrom extends Array<unknown>, TTo> = (...args: TFrom) => TTo;
|
||||
|
||||
/**
|
||||
* State derived from an observable and a converter function
|
||||
*
|
||||
* Derived state is cached and persisted to memory for sychronization across execution contexts.
|
||||
* For clients with multiple execution contexts, the derived state will be executed only once in the background process.
|
||||
*/
|
||||
export interface DerivedState<T> {
|
||||
/**
|
||||
* The derived state observable
|
||||
*/
|
||||
state$: Observable<T>;
|
||||
/**
|
||||
* Forces the derived state to a given value.
|
||||
*
|
||||
* Useful for setting an in-memory value as a side effect of some event, such as emptying state as a result of a lock.
|
||||
* @param value The value to force the derived state to
|
||||
*/
|
||||
forceValue(value: T): Promise<T>;
|
||||
}
|
||||
export { DerivedState } from "@bitwarden/state";
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
import { record } from "./deserialization-helpers";
|
||||
|
||||
describe("deserialization helpers", () => {
|
||||
describe("record", () => {
|
||||
it("deserializes a record when keys are strings", () => {
|
||||
const deserializer = record((value: number) => value);
|
||||
const input = {
|
||||
a: 1,
|
||||
b: 2,
|
||||
};
|
||||
const output = deserializer(input);
|
||||
expect(output).toEqual(input);
|
||||
});
|
||||
|
||||
it("deserializes a record when keys are numbers", () => {
|
||||
const deserializer = record((value: number) => value);
|
||||
const input = {
|
||||
1: 1,
|
||||
2: 2,
|
||||
};
|
||||
const output = deserializer(input);
|
||||
expect(output).toEqual(input);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,40 +0,0 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
/**
|
||||
*
|
||||
* @param elementDeserializer
|
||||
* @returns
|
||||
*/
|
||||
export function array<T>(
|
||||
elementDeserializer: (element: Jsonify<T>) => T,
|
||||
): (array: Jsonify<T[]>) => T[] {
|
||||
return (array) => {
|
||||
if (array == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return array.map((element) => elementDeserializer(element));
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param valueDeserializer
|
||||
*/
|
||||
export function record<T, TKey extends string | number = string>(
|
||||
valueDeserializer: (value: Jsonify<T>) => T,
|
||||
): (record: Jsonify<Record<TKey, T>>) => Record<TKey, T> {
|
||||
return (jsonValue: Jsonify<Record<TKey, T> | null>) => {
|
||||
if (jsonValue == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const output: Record<TKey, T> = {} as any;
|
||||
Object.entries(jsonValue).forEach(([key, value]) => {
|
||||
output[key as TKey] = valueDeserializer(value);
|
||||
});
|
||||
return output;
|
||||
};
|
||||
}
|
||||
@@ -1,13 +1 @@
|
||||
import { GlobalState } from "./global-state";
|
||||
import { KeyDefinition } from "./key-definition";
|
||||
|
||||
/**
|
||||
* A provider for getting an implementation of global state scoped to the given key.
|
||||
*/
|
||||
export abstract class GlobalStateProvider {
|
||||
/**
|
||||
* Gets a {@link GlobalState} scoped to the given {@link KeyDefinition}
|
||||
* @param keyDefinition - The {@link KeyDefinition} for which you want the state for.
|
||||
*/
|
||||
abstract get<T>(keyDefinition: KeyDefinition<T>): GlobalState<T>;
|
||||
}
|
||||
export { GlobalStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,30 +1 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
import { StateUpdateOptions } from "./state-update-options";
|
||||
|
||||
/**
|
||||
* A helper object for interacting with state that is scoped to a specific domain
|
||||
* but is not scoped to a user. This is application wide storage.
|
||||
*/
|
||||
export interface GlobalState<T> {
|
||||
/**
|
||||
* Method for allowing you to manipulate state in an additive way.
|
||||
* @param configureState callback for how you want to manipulate this section of state
|
||||
* @param options Defaults given by @see {module:state-update-options#DEFAULT_OPTIONS}
|
||||
* @param options.shouldUpdate A callback for determining if you want to update state. Defaults to () => true
|
||||
* @param options.combineLatestWith An observable that you want to combine with the current state for callbacks. Defaults to null
|
||||
* @param options.msTimeout A timeout for how long you are willing to wait for a `combineLatestWith` option to complete. Defaults to 1000ms. Only applies if `combineLatestWith` is set.
|
||||
* @returns A promise that must be awaited before your next action to ensure the update has been written to state.
|
||||
* Resolves to the new state. If `shouldUpdate` returns false, the promise will resolve to the current state.
|
||||
*/
|
||||
update: <TCombine>(
|
||||
configureState: (state: T | null, dependency: TCombine) => T | null,
|
||||
options?: StateUpdateOptions<T, TCombine>,
|
||||
) => Promise<T | null>;
|
||||
|
||||
/**
|
||||
* An observable stream of this state, the first emission of this will be the current state on disk
|
||||
* and subsequent updates will be from an update to that state.
|
||||
*/
|
||||
state$: Observable<T | null>;
|
||||
}
|
||||
export { GlobalState } from "@bitwarden/state";
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
import { mockAccountServiceWith, trackEmissions } from "../../../../spec";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { SingleUserStateProvider } from "../user-state.provider";
|
||||
|
||||
import { DefaultActiveUserStateProvider } from "./default-active-user-state.provider";
|
||||
|
||||
describe("DefaultActiveUserStateProvider", () => {
|
||||
const singleUserStateProvider = mock<SingleUserStateProvider>();
|
||||
const userId = "userId" as UserId;
|
||||
const accountInfo = {
|
||||
id: userId,
|
||||
name: "name",
|
||||
email: "email",
|
||||
emailVerified: false,
|
||||
};
|
||||
const accountService = mockAccountServiceWith(userId, accountInfo);
|
||||
let sut: DefaultActiveUserStateProvider;
|
||||
|
||||
beforeEach(() => {
|
||||
sut = new DefaultActiveUserStateProvider(accountService, singleUserStateProvider);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
it("should track the active User id from account service", () => {
|
||||
const emissions = trackEmissions(sut.activeUserId$);
|
||||
|
||||
accountService.activeAccountSubject.next(undefined);
|
||||
accountService.activeAccountSubject.next(accountInfo);
|
||||
|
||||
expect(emissions).toEqual([userId, undefined, userId]);
|
||||
});
|
||||
});
|
||||
@@ -1,9 +1,9 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Observable, distinctUntilChanged, map } from "rxjs";
|
||||
import { Observable, distinctUntilChanged } from "rxjs";
|
||||
|
||||
import { AccountService } from "../../../auth/abstractions/account.service";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { ActiveUserAccessor } from "../active-user.accessor";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
import { ActiveUserState } from "../user-state";
|
||||
import { ActiveUserStateProvider, SingleUserStateProvider } from "../user-state.provider";
|
||||
@@ -14,11 +14,10 @@ export class DefaultActiveUserStateProvider implements ActiveUserStateProvider {
|
||||
activeUserId$: Observable<UserId | undefined>;
|
||||
|
||||
constructor(
|
||||
private readonly accountService: AccountService,
|
||||
private readonly activeAccountAccessor: ActiveUserAccessor,
|
||||
private readonly singleUserStateProvider: SingleUserStateProvider,
|
||||
) {
|
||||
this.activeUserId$ = this.accountService.activeAccount$.pipe(
|
||||
map((account) => account?.id),
|
||||
this.activeUserId$ = this.activeAccountAccessor.activeUserId$.pipe(
|
||||
// To avoid going to storage when we don't need to, only get updates when there is a true change.
|
||||
distinctUntilChanged((a, b) => (a == null || b == null ? a == b : a === b)), // Treat null and undefined as equal
|
||||
);
|
||||
|
||||
@@ -1,772 +0,0 @@
|
||||
/**
|
||||
* need to update test environment so trackEmissions works appropriately
|
||||
* @jest-environment ../shared/test.environment.ts
|
||||
*/
|
||||
import { any, mock } from "jest-mock-extended";
|
||||
import { BehaviorSubject, firstValueFrom, map, of, timeout } from "rxjs";
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { StorageServiceProvider } from "@bitwarden/storage-core";
|
||||
|
||||
import { awaitAsync, trackEmissions } from "../../../../spec";
|
||||
import { FakeStorageService } from "../../../../spec/fake-storage.service";
|
||||
import { Account } from "../../../auth/abstractions/account.service";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { StateDefinition } from "../state-definition";
|
||||
import { StateEventRegistrarService } from "../state-event-registrar.service";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
|
||||
import { DefaultActiveUserState } from "./default-active-user-state";
|
||||
import { DefaultSingleUserStateProvider } from "./default-single-user-state.provider";
|
||||
|
||||
class TestState {
|
||||
date: Date;
|
||||
array: string[];
|
||||
|
||||
static fromJSON(jsonState: Jsonify<TestState>) {
|
||||
if (jsonState == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return Object.assign(new TestState(), jsonState, {
|
||||
date: new Date(jsonState.date),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const testStateDefinition = new StateDefinition("fake", "disk");
|
||||
const cleanupDelayMs = 15;
|
||||
const testKeyDefinition = new UserKeyDefinition<TestState>(testStateDefinition, "fake", {
|
||||
deserializer: TestState.fromJSON,
|
||||
cleanupDelayMs,
|
||||
clearOn: [],
|
||||
});
|
||||
|
||||
describe("DefaultActiveUserState", () => {
|
||||
let diskStorageService: FakeStorageService;
|
||||
const storageServiceProvider = mock<StorageServiceProvider>();
|
||||
const stateEventRegistrarService = mock<StateEventRegistrarService>();
|
||||
const logService = mock<LogService>();
|
||||
let activeAccountSubject: BehaviorSubject<Account | null>;
|
||||
|
||||
let singleUserStateProvider: DefaultSingleUserStateProvider;
|
||||
|
||||
let userState: DefaultActiveUserState<TestState>;
|
||||
|
||||
beforeEach(() => {
|
||||
diskStorageService = new FakeStorageService();
|
||||
storageServiceProvider.get.mockReturnValue(["disk", diskStorageService]);
|
||||
|
||||
singleUserStateProvider = new DefaultSingleUserStateProvider(
|
||||
storageServiceProvider,
|
||||
stateEventRegistrarService,
|
||||
logService,
|
||||
);
|
||||
|
||||
activeAccountSubject = new BehaviorSubject<Account | null>(null);
|
||||
|
||||
userState = new DefaultActiveUserState(
|
||||
testKeyDefinition,
|
||||
activeAccountSubject.asObservable().pipe(map((a) => a?.id)),
|
||||
singleUserStateProvider,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
const makeUserId = (id: string) => {
|
||||
return id != null ? (`00000000-0000-1000-a000-00000000000${id}` as UserId) : undefined;
|
||||
};
|
||||
|
||||
const changeActiveUser = async (id: string) => {
|
||||
const userId = makeUserId(id);
|
||||
activeAccountSubject.next({
|
||||
id: userId,
|
||||
email: `test${id}@example.com`,
|
||||
emailVerified: false,
|
||||
name: `Test User ${id}`,
|
||||
});
|
||||
await awaitAsync();
|
||||
};
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
it("emits updates for each user switch and update", async () => {
|
||||
const user1 = "user_00000000-0000-1000-a000-000000000001_fake_fake";
|
||||
const user2 = "user_00000000-0000-1000-a000-000000000002_fake_fake";
|
||||
const state1 = {
|
||||
date: new Date(2021, 0),
|
||||
array: ["user1"],
|
||||
};
|
||||
const state2 = {
|
||||
date: new Date(2022, 0),
|
||||
array: ["user2"],
|
||||
};
|
||||
const initialState: Record<string, TestState> = {};
|
||||
initialState[user1] = state1;
|
||||
initialState[user2] = state2;
|
||||
diskStorageService.internalUpdateStore(initialState);
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
|
||||
// User signs in
|
||||
await changeActiveUser("1");
|
||||
|
||||
// Service does an update
|
||||
const updatedState = {
|
||||
date: new Date(2023, 0),
|
||||
array: ["user1-update"],
|
||||
};
|
||||
await userState.update(() => updatedState);
|
||||
await awaitAsync();
|
||||
|
||||
// Emulate an account switch
|
||||
await changeActiveUser("2");
|
||||
|
||||
// #1 initial state from user1
|
||||
// #2 updated state for user1
|
||||
// #3 switched state to initial state for user2
|
||||
expect(emissions).toEqual([state1, updatedState, state2]);
|
||||
|
||||
// Should be called 4 time to get state, update state for user, emitting update, and switching users
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(4);
|
||||
// Initial subscribe to state$
|
||||
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake",
|
||||
any(), // options
|
||||
);
|
||||
// The updating of state for user1
|
||||
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake",
|
||||
any(), // options
|
||||
);
|
||||
// The emission from being actively subscribed to user1
|
||||
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake",
|
||||
any(), // options
|
||||
);
|
||||
// Switch to user2
|
||||
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
|
||||
4,
|
||||
"user_00000000-0000-1000-a000-000000000002_fake_fake",
|
||||
any(), // options
|
||||
);
|
||||
// Should only have saved data for the first user
|
||||
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
|
||||
expect(diskStorageService.mock.save).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake",
|
||||
updatedState,
|
||||
any(), // options
|
||||
);
|
||||
});
|
||||
|
||||
it("will not emit any value if there isn't an active user", async () => {
|
||||
let resolvedValue: TestState | undefined = undefined;
|
||||
let rejectedError: Error | undefined = undefined;
|
||||
|
||||
const promise = firstValueFrom(userState.state$.pipe(timeout(20)))
|
||||
.then((value) => {
|
||||
resolvedValue = value;
|
||||
})
|
||||
.catch((err) => {
|
||||
rejectedError = err;
|
||||
});
|
||||
await promise;
|
||||
|
||||
expect(diskStorageService.mock.get).not.toHaveBeenCalled();
|
||||
|
||||
expect(resolvedValue).toBe(undefined);
|
||||
expect(rejectedError).toBeTruthy();
|
||||
expect(rejectedError.message).toBe("Timeout has occurred");
|
||||
});
|
||||
|
||||
it("will emit value for a new active user after subscription started", async () => {
|
||||
let resolvedValue: TestState | undefined = undefined;
|
||||
let rejectedError: Error | undefined = undefined;
|
||||
|
||||
diskStorageService.internalUpdateStore({
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
|
||||
date: new Date(2020, 0),
|
||||
array: ["testValue"],
|
||||
} as TestState,
|
||||
});
|
||||
|
||||
const promise = firstValueFrom(userState.state$.pipe(timeout(20)))
|
||||
.then((value) => {
|
||||
resolvedValue = value;
|
||||
})
|
||||
.catch((err) => {
|
||||
rejectedError = err;
|
||||
});
|
||||
await changeActiveUser("1");
|
||||
await promise;
|
||||
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
|
||||
|
||||
expect(resolvedValue).toBeTruthy();
|
||||
expect(resolvedValue.array).toHaveLength(1);
|
||||
expect(resolvedValue.date.getFullYear()).toBe(2020);
|
||||
expect(rejectedError).toBeFalsy();
|
||||
});
|
||||
|
||||
it("should not emit a previous users value if that user is no longer active", async () => {
|
||||
const user1Data: Jsonify<TestState> = {
|
||||
date: "2020-09-21T13:14:17.648Z",
|
||||
// NOTE: `as any` is here until we migrate to Nx: https://bitwarden.atlassian.net/browse/PM-6493
|
||||
array: ["value"] as any,
|
||||
};
|
||||
const user2Data: Jsonify<TestState> = {
|
||||
date: "2020-09-21T13:14:17.648Z",
|
||||
array: [],
|
||||
};
|
||||
diskStorageService.internalUpdateStore({
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake": user1Data,
|
||||
"user_00000000-0000-1000-a000-000000000002_fake_fake": user2Data,
|
||||
});
|
||||
|
||||
// This starts one subscription on the observable for tracking emissions throughout
|
||||
// the whole test.
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
|
||||
// Change to a user with data
|
||||
await changeActiveUser("1");
|
||||
|
||||
// This should always return a value right await
|
||||
const value = await firstValueFrom(
|
||||
userState.state$.pipe(
|
||||
timeout({
|
||||
first: 20,
|
||||
with: () => {
|
||||
throw new Error("Did not emit data from newly active user.");
|
||||
},
|
||||
}),
|
||||
),
|
||||
);
|
||||
expect(value).toEqual(user1Data);
|
||||
|
||||
// Make it such that there is no active user
|
||||
await changeActiveUser(undefined);
|
||||
|
||||
let resolvedValue: TestState | undefined = undefined;
|
||||
let rejectedError: Error | undefined = undefined;
|
||||
|
||||
// Even if the observable has previously emitted a value it shouldn't have
|
||||
// a value for the user subscribing to it because there isn't an active user
|
||||
// to get data for.
|
||||
await firstValueFrom(userState.state$.pipe(timeout(20)))
|
||||
.then((value) => {
|
||||
resolvedValue = value;
|
||||
})
|
||||
.catch((err) => {
|
||||
rejectedError = err;
|
||||
});
|
||||
|
||||
expect(resolvedValue).toBeUndefined();
|
||||
expect(rejectedError).not.toBeUndefined();
|
||||
expect(rejectedError.message).toBe("Timeout has occurred");
|
||||
|
||||
// We need to figure out if something should be emitted
|
||||
// when there becomes no active user, if we don't want that to emit
|
||||
// this value is correct.
|
||||
expect(emissions).toEqual([user1Data]);
|
||||
});
|
||||
|
||||
it("should not emit twice if there are two listeners", async () => {
|
||||
await changeActiveUser("1");
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
const emissions2 = trackEmissions(userState.state$);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
]);
|
||||
expect(emissions2).toEqual([
|
||||
null, // Initial value
|
||||
]);
|
||||
});
|
||||
|
||||
describe("update", () => {
|
||||
const newData = { date: new Date(), array: ["test"] };
|
||||
beforeEach(async () => {
|
||||
await changeActiveUser("1");
|
||||
});
|
||||
|
||||
it("should save on update", async () => {
|
||||
const [setUserId, result] = await userState.update((state, dependencies) => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
|
||||
expect(result).toEqual(newData);
|
||||
expect(setUserId).toEqual("00000000-0000-1000-a000-000000000001");
|
||||
});
|
||||
|
||||
it("should emit once per update", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // Need to await for the initial value to be emitted
|
||||
|
||||
await userState.update((state, dependencies) => {
|
||||
return newData;
|
||||
});
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should provide combined dependencies", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // Need to await for the initial value to be emitted
|
||||
|
||||
const combinedDependencies = { date: new Date() };
|
||||
|
||||
await userState.update(
|
||||
(state, dependencies) => {
|
||||
expect(dependencies).toEqual(combinedDependencies);
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
combineLatestWith: of(combinedDependencies),
|
||||
},
|
||||
);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should not update if shouldUpdate returns false", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // Need to await for the initial value to be emitted
|
||||
|
||||
const [userIdResult, result] = await userState.update(
|
||||
(state, dependencies) => {
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
shouldUpdate: () => false,
|
||||
},
|
||||
);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(diskStorageService.mock.save).not.toHaveBeenCalled();
|
||||
expect(userIdResult).toEqual("00000000-0000-1000-a000-000000000001");
|
||||
expect(result).toBeNull();
|
||||
expect(emissions).toEqual([null]);
|
||||
});
|
||||
|
||||
it("should provide the current state to the update callback", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // Need to await for the initial value to be emitted
|
||||
|
||||
// Seed with interesting data
|
||||
const initialData = { date: new Date(2020, 0), array: ["value1", "value2"] };
|
||||
await userState.update((state, dependencies) => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
await userState.update((state, dependencies) => {
|
||||
expect(state).toEqual(initialData);
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
initialData,
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should throw on an attempted update when there is no active user", async () => {
|
||||
await changeActiveUser(undefined);
|
||||
|
||||
await expect(async () => await userState.update(() => null)).rejects.toThrow(
|
||||
"No active user at this time.",
|
||||
);
|
||||
});
|
||||
|
||||
it("should throw on an attempted update where there is no active user even if there used to be one", async () => {
|
||||
// Arrange
|
||||
diskStorageService.internalUpdateStore({
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
|
||||
date: new Date(2019, 1),
|
||||
array: [],
|
||||
},
|
||||
});
|
||||
|
||||
const [userId, state] = await firstValueFrom(userState.combinedState$);
|
||||
expect(userId).toBe("00000000-0000-1000-a000-000000000001");
|
||||
expect(state.date.getUTCFullYear()).toBe(2019);
|
||||
|
||||
await changeActiveUser(undefined);
|
||||
// Act
|
||||
|
||||
await expect(async () => await userState.update(() => null)).rejects.toThrow(
|
||||
"No active user at this time.",
|
||||
);
|
||||
});
|
||||
|
||||
it.each([null, undefined])(
|
||||
"should register user key definition when state transitions from null-ish (%s) to non-null",
|
||||
async (startingValue: TestState | null) => {
|
||||
diskStorageService.internalUpdateStore({
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake": startingValue,
|
||||
});
|
||||
|
||||
await userState.update(() => ({ array: ["one"], date: new Date() }));
|
||||
|
||||
expect(stateEventRegistrarService.registerEvents).toHaveBeenCalledWith(testKeyDefinition);
|
||||
},
|
||||
);
|
||||
|
||||
it("should not register user key definition when state has preexisting value", async () => {
|
||||
diskStorageService.internalUpdateStore({
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
|
||||
date: new Date(2019, 1),
|
||||
array: [],
|
||||
},
|
||||
});
|
||||
|
||||
await userState.update(() => ({ array: ["one"], date: new Date() }));
|
||||
|
||||
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it.each([null, undefined])(
|
||||
"should not register user key definition when setting value to null-ish (%s) value",
|
||||
async (updatedValue: TestState | null) => {
|
||||
diskStorageService.internalUpdateStore({
|
||||
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
|
||||
date: new Date(2019, 1),
|
||||
array: [],
|
||||
},
|
||||
});
|
||||
|
||||
await userState.update(() => updatedValue);
|
||||
|
||||
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe("update races", () => {
|
||||
const newData = { date: new Date(), array: ["test"] };
|
||||
const userId = makeUserId("1");
|
||||
|
||||
beforeEach(async () => {
|
||||
await changeActiveUser("1");
|
||||
await awaitAsync();
|
||||
});
|
||||
|
||||
test("subscriptions during an update should receive the current and latest", async () => {
|
||||
const oldData = { date: new Date(2019, 1, 1), array: ["oldValue1"] };
|
||||
await userState.update(() => {
|
||||
return oldData;
|
||||
});
|
||||
const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] };
|
||||
await userState.update(() => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync();
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
let emissions2: TestState[];
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
|
||||
emissions2 = trackEmissions(userState.state$);
|
||||
await originalSave(key, obj);
|
||||
});
|
||||
|
||||
const [userIdResult, val] = await userState.update(() => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync(10);
|
||||
|
||||
expect(userIdResult).toEqual(userId);
|
||||
expect(val).toEqual(newData);
|
||||
expect(emissions).toEqual([initialData, newData]);
|
||||
expect(emissions2).toEqual([initialData, newData]);
|
||||
});
|
||||
|
||||
test("subscription during an aborted update should receive the last value", async () => {
|
||||
// Seed with interesting data
|
||||
const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] };
|
||||
await userState.update(() => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync();
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
let emissions2: TestState[];
|
||||
const [userIdResult, val] = await userState.update(
|
||||
(state) => {
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
shouldUpdate: () => {
|
||||
emissions2 = trackEmissions(userState.state$);
|
||||
return false;
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(userIdResult).toEqual(userId);
|
||||
expect(val).toEqual(initialData);
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
expect(emissions2).toEqual([initialData]);
|
||||
});
|
||||
|
||||
test("updates should wait until previous update is complete", async () => {
|
||||
trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest
|
||||
.fn()
|
||||
.mockImplementationOnce(async (key: string, obj: any) => {
|
||||
let resolved = false;
|
||||
await Promise.race([
|
||||
userState.update(() => {
|
||||
// deadlocks
|
||||
resolved = true;
|
||||
return newData;
|
||||
}),
|
||||
awaitAsync(100), // limit test to 100ms
|
||||
]);
|
||||
expect(resolved).toBe(false);
|
||||
})
|
||||
.mockImplementation((...args) => {
|
||||
return originalSave(...args);
|
||||
});
|
||||
|
||||
await userState.update(() => {
|
||||
return newData;
|
||||
});
|
||||
});
|
||||
|
||||
test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => {
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
|
||||
const [userIdResult, val] = await userState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
expect(userIdResult).toEqual(userId);
|
||||
expect(val).toEqual(newData);
|
||||
const call = diskStorageService.mock.save.mock.calls[0];
|
||||
expect(call[0]).toEqual(`user_${userId}_fake_fake`);
|
||||
expect(call[1]).toEqual(newData);
|
||||
});
|
||||
|
||||
it("does not await updates if the active user changes", async () => {
|
||||
const initialUserId = (await firstValueFrom(activeAccountSubject)).id;
|
||||
expect(initialUserId).toBe(userId);
|
||||
trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest
|
||||
.fn()
|
||||
.mockImplementationOnce(async (key: string, obj: any) => {
|
||||
let resolved = false;
|
||||
await changeActiveUser("2");
|
||||
await Promise.race([
|
||||
userState.update(() => {
|
||||
// should not deadlock because we updated the user
|
||||
resolved = true;
|
||||
return newData;
|
||||
}),
|
||||
awaitAsync(100), // limit test to 100ms
|
||||
]);
|
||||
expect(resolved).toBe(true);
|
||||
})
|
||||
.mockImplementation((...args) => {
|
||||
return originalSave(...args);
|
||||
});
|
||||
|
||||
await userState.update(() => {
|
||||
return newData;
|
||||
});
|
||||
});
|
||||
|
||||
it("stores updates for users in the correct place when active user changes mid-update", async () => {
|
||||
trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const user2Data = { date: new Date(), array: ["user 2 data"] };
|
||||
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest
|
||||
.fn()
|
||||
.mockImplementationOnce(async (key: string, obj: any) => {
|
||||
let resolved = false;
|
||||
await changeActiveUser("2");
|
||||
await Promise.race([
|
||||
userState.update(() => {
|
||||
// should not deadlock because we updated the user
|
||||
resolved = true;
|
||||
return user2Data;
|
||||
}),
|
||||
awaitAsync(100), // limit test to 100ms
|
||||
]);
|
||||
expect(resolved).toBe(true);
|
||||
await originalSave(key, obj);
|
||||
})
|
||||
.mockImplementation((...args) => {
|
||||
return originalSave(...args);
|
||||
});
|
||||
|
||||
await userState.update(() => {
|
||||
return newData;
|
||||
});
|
||||
await awaitAsync();
|
||||
|
||||
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(2);
|
||||
const innerCall = diskStorageService.mock.save.mock.calls[0];
|
||||
expect(innerCall[0]).toEqual(`user_${makeUserId("2")}_fake_fake`);
|
||||
expect(innerCall[1]).toEqual(user2Data);
|
||||
const outerCall = diskStorageService.mock.save.mock.calls[1];
|
||||
expect(outerCall[0]).toEqual(`user_${makeUserId("1")}_fake_fake`);
|
||||
expect(outerCall[1]).toEqual(newData);
|
||||
});
|
||||
});
|
||||
|
||||
describe("cleanup", () => {
|
||||
const newData = { date: new Date(), array: ["test"] };
|
||||
const userId = makeUserId("1");
|
||||
let userKey: string;
|
||||
|
||||
beforeEach(async () => {
|
||||
await changeActiveUser("1");
|
||||
userKey = testKeyDefinition.buildKey(userId);
|
||||
});
|
||||
|
||||
function assertClean() {
|
||||
expect(activeAccountSubject["observers"]).toHaveLength(0);
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
|
||||
}
|
||||
|
||||
it("should cleanup after last subscriber", async () => {
|
||||
const subscription = userState.state$.subscribe();
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
subscription.unsubscribe();
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
assertClean();
|
||||
});
|
||||
|
||||
it("should not cleanup if there are still subscribers", async () => {
|
||||
const subscription1 = userState.state$.subscribe();
|
||||
const sub2Emissions: TestState[] = [];
|
||||
const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v));
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
subscription1.unsubscribe();
|
||||
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
|
||||
|
||||
// Still be listening to storage updates
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
diskStorageService.save(userKey, newData);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
expect(sub2Emissions).toEqual([null, newData]);
|
||||
|
||||
subscription2.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
assertClean();
|
||||
});
|
||||
|
||||
it("can re-initialize after cleanup", async () => {
|
||||
const subscription = userState.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync();
|
||||
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([null, newData]);
|
||||
});
|
||||
|
||||
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
|
||||
const subscription = userState.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Do not wait long enough for cleanup
|
||||
await awaitAsync(cleanupDelayMs / 2);
|
||||
|
||||
const state = await firstValueFrom(userState.state$);
|
||||
|
||||
expect(state).toEqual(newData); // digging in to check that it hasn't been cleared
|
||||
|
||||
// Should be called once for the initial subscription and once from the save
|
||||
// but should NOT be called for the second subscription from the `firstValueFrom`
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("state$ observables are durable to cleanup", async () => {
|
||||
const observable = userState.state$;
|
||||
let subscription = observable.subscribe();
|
||||
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
subscription = observable.subscribe();
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(await firstValueFrom(observable)).toEqual(newData);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,64 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Observable, map, switchMap, firstValueFrom, timeout, throwError, NEVER } from "rxjs";
|
||||
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { StateUpdateOptions } from "../state-update-options";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
import { ActiveUserState, CombinedState, activeMarker } from "../user-state";
|
||||
import { SingleUserStateProvider } from "../user-state.provider";
|
||||
|
||||
export class DefaultActiveUserState<T> implements ActiveUserState<T> {
|
||||
[activeMarker]: true;
|
||||
combinedState$: Observable<CombinedState<T>>;
|
||||
state$: Observable<T>;
|
||||
|
||||
constructor(
|
||||
protected keyDefinition: UserKeyDefinition<T>,
|
||||
private activeUserId$: Observable<UserId | null>,
|
||||
private singleUserStateProvider: SingleUserStateProvider,
|
||||
) {
|
||||
this.combinedState$ = this.activeUserId$.pipe(
|
||||
switchMap((userId) =>
|
||||
userId != null
|
||||
? this.singleUserStateProvider.get(userId, this.keyDefinition).combinedState$
|
||||
: NEVER,
|
||||
),
|
||||
);
|
||||
|
||||
// State should just be combined state without the user id
|
||||
this.state$ = this.combinedState$.pipe(map(([_userId, state]) => state));
|
||||
}
|
||||
|
||||
async update<TCombine>(
|
||||
configureState: (state: T, dependency: TCombine) => T,
|
||||
options: StateUpdateOptions<T, TCombine> = {},
|
||||
): Promise<[UserId, T]> {
|
||||
const userId = await firstValueFrom(
|
||||
this.activeUserId$.pipe(
|
||||
timeout({
|
||||
first: 1000,
|
||||
with: () =>
|
||||
throwError(
|
||||
() =>
|
||||
new Error(
|
||||
`Timeout while retrieving active user for key ${this.keyDefinition.fullName}.`,
|
||||
),
|
||||
),
|
||||
}),
|
||||
),
|
||||
);
|
||||
if (userId == null) {
|
||||
throw new Error(
|
||||
`Error storing ${this.keyDefinition.fullName} for the active user: No active user at this time.`,
|
||||
);
|
||||
}
|
||||
|
||||
return [
|
||||
userId,
|
||||
await this.singleUserStateProvider
|
||||
.get(userId, this.keyDefinition)
|
||||
.update(configureState, options),
|
||||
];
|
||||
}
|
||||
}
|
||||
export { DefaultActiveUserState } from "@bitwarden/state";
|
||||
|
||||
@@ -1,53 +1 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
import { DerivedStateDependencies } from "../../../types/state";
|
||||
import { DeriveDefinition } from "../derive-definition";
|
||||
import { DerivedState } from "../derived-state";
|
||||
import { DerivedStateProvider } from "../derived-state.provider";
|
||||
|
||||
import { DefaultDerivedState } from "./default-derived-state";
|
||||
|
||||
export class DefaultDerivedStateProvider implements DerivedStateProvider {
|
||||
/**
|
||||
* The cache uses a WeakMap to maintain separate derived states per user.
|
||||
* Each user's state Observable acts as a unique key, without needing to
|
||||
* pass around `userId`. Also, when a user's state Observable is cleaned up
|
||||
* (like during an account swap) their cache is automatically garbage
|
||||
* collected.
|
||||
*/
|
||||
private cache = new WeakMap<Observable<unknown>, Record<string, DerivedState<unknown>>>();
|
||||
|
||||
constructor() {}
|
||||
|
||||
get<TFrom, TTo, TDeps extends DerivedStateDependencies>(
|
||||
parentState$: Observable<TFrom>,
|
||||
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
|
||||
dependencies: TDeps,
|
||||
): DerivedState<TTo> {
|
||||
let stateCache = this.cache.get(parentState$);
|
||||
if (!stateCache) {
|
||||
stateCache = {};
|
||||
this.cache.set(parentState$, stateCache);
|
||||
}
|
||||
|
||||
const cacheKey = deriveDefinition.buildCacheKey();
|
||||
const existingDerivedState = stateCache[cacheKey];
|
||||
if (existingDerivedState != null) {
|
||||
// I have to cast out of the unknown generic but this should be safe if rules
|
||||
// around domain token are made
|
||||
return existingDerivedState as DefaultDerivedState<TFrom, TTo, TDeps>;
|
||||
}
|
||||
|
||||
const newDerivedState = this.buildDerivedState(parentState$, deriveDefinition, dependencies);
|
||||
stateCache[cacheKey] = newDerivedState;
|
||||
return newDerivedState;
|
||||
}
|
||||
|
||||
protected buildDerivedState<TFrom, TTo, TDeps extends DerivedStateDependencies>(
|
||||
parentState$: Observable<TFrom>,
|
||||
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
|
||||
dependencies: TDeps,
|
||||
): DerivedState<TTo> {
|
||||
return new DefaultDerivedState<TFrom, TTo, TDeps>(parentState$, deriveDefinition, dependencies);
|
||||
}
|
||||
}
|
||||
export { DefaultDerivedStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,211 +0,0 @@
|
||||
/**
|
||||
* need to update test environment so trackEmissions works appropriately
|
||||
* @jest-environment ../shared/test.environment.ts
|
||||
*/
|
||||
import { Subject, firstValueFrom } from "rxjs";
|
||||
|
||||
import { awaitAsync, trackEmissions } from "../../../../spec";
|
||||
import { DeriveDefinition } from "../derive-definition";
|
||||
import { StateDefinition } from "../state-definition";
|
||||
|
||||
import { DefaultDerivedState } from "./default-derived-state";
|
||||
import { DefaultDerivedStateProvider } from "./default-derived-state.provider";
|
||||
|
||||
let callCount = 0;
|
||||
const cleanupDelayMs = 10;
|
||||
const stateDefinition = new StateDefinition("test", "memory");
|
||||
const deriveDefinition = new DeriveDefinition<string, Date, { date: Date }>(
|
||||
stateDefinition,
|
||||
"test",
|
||||
{
|
||||
derive: (dateString: string) => {
|
||||
callCount++;
|
||||
return new Date(dateString);
|
||||
},
|
||||
deserializer: (dateString: string) => new Date(dateString),
|
||||
cleanupDelayMs,
|
||||
},
|
||||
);
|
||||
|
||||
describe("DefaultDerivedState", () => {
|
||||
let parentState$: Subject<string>;
|
||||
let sut: DefaultDerivedState<string, Date, { date: Date }>;
|
||||
const deps = {
|
||||
date: new Date(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
callCount = 0;
|
||||
parentState$ = new Subject();
|
||||
sut = new DefaultDerivedState(parentState$, deriveDefinition, deps);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
parentState$.complete();
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
it("should derive the state", async () => {
|
||||
const dateString = "2020-01-01";
|
||||
const emissions = trackEmissions(sut.state$);
|
||||
|
||||
parentState$.next(dateString);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([new Date(dateString)]);
|
||||
});
|
||||
|
||||
it("should derive the state once", async () => {
|
||||
const dateString = "2020-01-01";
|
||||
trackEmissions(sut.state$);
|
||||
|
||||
parentState$.next(dateString);
|
||||
|
||||
expect(callCount).toBe(1);
|
||||
});
|
||||
|
||||
describe("forceValue", () => {
|
||||
const initialParentValue = "2020-01-01";
|
||||
const forced = new Date("2020-02-02");
|
||||
let emissions: Date[];
|
||||
|
||||
beforeEach(async () => {
|
||||
emissions = trackEmissions(sut.state$);
|
||||
parentState$.next(initialParentValue);
|
||||
await awaitAsync();
|
||||
});
|
||||
|
||||
it("should force the value", async () => {
|
||||
await sut.forceValue(forced);
|
||||
expect(emissions).toEqual([new Date(initialParentValue), forced]);
|
||||
});
|
||||
|
||||
it("should only force the value once", async () => {
|
||||
await sut.forceValue(forced);
|
||||
|
||||
parentState$.next(initialParentValue);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
new Date(initialParentValue),
|
||||
forced,
|
||||
new Date(initialParentValue),
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("cleanup", () => {
|
||||
const newDate = "2020-02-02";
|
||||
|
||||
it("should cleanup after last subscriber", async () => {
|
||||
const subscription = sut.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
expect(parentState$.observed).toBe(false);
|
||||
});
|
||||
|
||||
it("should not cleanup if there are still subscribers", async () => {
|
||||
const subscription1 = sut.state$.subscribe();
|
||||
const sub2Emissions: Date[] = [];
|
||||
const subscription2 = sut.state$.subscribe((v) => sub2Emissions.push(v));
|
||||
await awaitAsync();
|
||||
|
||||
subscription1.unsubscribe();
|
||||
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
// Still be listening to parent updates
|
||||
parentState$.next(newDate);
|
||||
await awaitAsync();
|
||||
expect(sub2Emissions).toEqual([new Date(newDate)]);
|
||||
|
||||
subscription2.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
expect(parentState$.observed).toBe(false);
|
||||
});
|
||||
|
||||
it("can re-initialize after cleanup", async () => {
|
||||
const subscription = sut.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
const emissions = trackEmissions(sut.state$);
|
||||
await awaitAsync();
|
||||
|
||||
parentState$.next(newDate);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([new Date(newDate)]);
|
||||
});
|
||||
|
||||
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
|
||||
const subscription = sut.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
await parentState$.next(newDate);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Do not wait long enough for cleanup
|
||||
await awaitAsync(cleanupDelayMs / 2);
|
||||
|
||||
expect(parentState$.observed).toBe(true); // still listening to parent
|
||||
|
||||
const emissions = trackEmissions(sut.state$);
|
||||
expect(emissions).toEqual([new Date(newDate)]); // we didn't lose our buffered value
|
||||
});
|
||||
|
||||
it("state$ observables are durable to cleanup", async () => {
|
||||
const observable = sut.state$;
|
||||
let subscription = observable.subscribe();
|
||||
|
||||
await parentState$.next(newDate);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
subscription = observable.subscribe();
|
||||
await parentState$.next(newDate);
|
||||
await awaitAsync();
|
||||
|
||||
expect(await firstValueFrom(observable)).toEqual(new Date(newDate));
|
||||
});
|
||||
});
|
||||
|
||||
describe("account switching", () => {
|
||||
let provider: DefaultDerivedStateProvider;
|
||||
|
||||
beforeEach(() => {
|
||||
provider = new DefaultDerivedStateProvider();
|
||||
});
|
||||
|
||||
it("should provide a dedicated cache for each account", async () => {
|
||||
const user1State$ = new Subject<string>();
|
||||
const user1Derived = provider.get(user1State$, deriveDefinition, deps);
|
||||
const user1Emissions = trackEmissions(user1Derived.state$);
|
||||
|
||||
const user2State$ = new Subject<string>();
|
||||
const user2Derived = provider.get(user2State$, deriveDefinition, deps);
|
||||
const user2Emissions = trackEmissions(user2Derived.state$);
|
||||
|
||||
user1State$.next("2015-12-30");
|
||||
user2State$.next("2020-12-29");
|
||||
await awaitAsync();
|
||||
|
||||
expect(user1Emissions).toEqual([new Date("2015-12-30")]);
|
||||
expect(user2Emissions).toEqual([new Date("2020-12-29")]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,50 +1 @@
|
||||
import { Observable, ReplaySubject, Subject, concatMap, merge, share, timer } from "rxjs";
|
||||
|
||||
import { DerivedStateDependencies } from "../../../types/state";
|
||||
import { DeriveDefinition } from "../derive-definition";
|
||||
import { DerivedState } from "../derived-state";
|
||||
|
||||
/**
|
||||
* Default derived state
|
||||
*/
|
||||
export class DefaultDerivedState<TFrom, TTo, TDeps extends DerivedStateDependencies>
|
||||
implements DerivedState<TTo>
|
||||
{
|
||||
private readonly storageKey: string;
|
||||
private forcedValueSubject = new Subject<TTo>();
|
||||
|
||||
state$: Observable<TTo>;
|
||||
|
||||
constructor(
|
||||
private parentState$: Observable<TFrom>,
|
||||
protected deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
|
||||
private dependencies: TDeps,
|
||||
) {
|
||||
this.storageKey = deriveDefinition.storageKey;
|
||||
|
||||
const derivedState$ = this.parentState$.pipe(
|
||||
concatMap(async (state) => {
|
||||
let derivedStateOrPromise = this.deriveDefinition.derive(state, this.dependencies);
|
||||
if (derivedStateOrPromise instanceof Promise) {
|
||||
derivedStateOrPromise = await derivedStateOrPromise;
|
||||
}
|
||||
const derivedState = derivedStateOrPromise;
|
||||
return derivedState;
|
||||
}),
|
||||
);
|
||||
|
||||
this.state$ = merge(this.forcedValueSubject, derivedState$).pipe(
|
||||
share({
|
||||
connector: () => {
|
||||
return new ReplaySubject<TTo>(1);
|
||||
},
|
||||
resetOnRefCountZero: () => timer(this.deriveDefinition.cleanupDelayMs),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
async forceValue(value: TTo) {
|
||||
this.forcedValueSubject.next(value);
|
||||
return value;
|
||||
}
|
||||
}
|
||||
export { DefaultDerivedState } from "@bitwarden/state";
|
||||
|
||||
@@ -1,46 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { StorageServiceProvider } from "@bitwarden/storage-core";
|
||||
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { GlobalState } from "../global-state";
|
||||
import { GlobalStateProvider } from "../global-state.provider";
|
||||
import { KeyDefinition } from "../key-definition";
|
||||
|
||||
import { DefaultGlobalState } from "./default-global-state";
|
||||
|
||||
export class DefaultGlobalStateProvider implements GlobalStateProvider {
|
||||
private globalStateCache: Record<string, GlobalState<unknown>> = {};
|
||||
|
||||
constructor(
|
||||
private storageServiceProvider: StorageServiceProvider,
|
||||
private readonly logService: LogService,
|
||||
) {}
|
||||
|
||||
get<T>(keyDefinition: KeyDefinition<T>): GlobalState<T> {
|
||||
const [location, storageService] = this.storageServiceProvider.get(
|
||||
keyDefinition.stateDefinition.defaultStorageLocation,
|
||||
keyDefinition.stateDefinition.storageLocationOverrides,
|
||||
);
|
||||
const cacheKey = this.buildCacheKey(location, keyDefinition);
|
||||
const existingGlobalState = this.globalStateCache[cacheKey];
|
||||
if (existingGlobalState != null) {
|
||||
// The cast into the actual generic is safe because of rules around key definitions
|
||||
// being unique.
|
||||
return existingGlobalState as DefaultGlobalState<T>;
|
||||
}
|
||||
|
||||
const newGlobalState = new DefaultGlobalState<T>(
|
||||
keyDefinition,
|
||||
storageService,
|
||||
this.logService,
|
||||
);
|
||||
|
||||
this.globalStateCache[cacheKey] = newGlobalState;
|
||||
return newGlobalState;
|
||||
}
|
||||
|
||||
private buildCacheKey(location: string, keyDefinition: KeyDefinition<unknown>) {
|
||||
return `${location}_${keyDefinition.fullName}`;
|
||||
}
|
||||
}
|
||||
export { DefaultGlobalStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,411 +0,0 @@
|
||||
/**
|
||||
* need to update test environment so trackEmissions works appropriately
|
||||
* @jest-environment ../shared/test.environment.ts
|
||||
*/
|
||||
|
||||
import { mock } from "jest-mock-extended";
|
||||
import { firstValueFrom, of } from "rxjs";
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { trackEmissions, awaitAsync } from "../../../../spec";
|
||||
import { FakeStorageService } from "../../../../spec/fake-storage.service";
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { KeyDefinition, globalKeyBuilder } from "../key-definition";
|
||||
import { StateDefinition } from "../state-definition";
|
||||
|
||||
import { DefaultGlobalState } from "./default-global-state";
|
||||
|
||||
class TestState {
|
||||
date: Date;
|
||||
|
||||
static fromJSON(jsonState: Jsonify<TestState>) {
|
||||
if (jsonState == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return Object.assign(new TestState(), jsonState, {
|
||||
date: new Date(jsonState.date),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const testStateDefinition = new StateDefinition("fake", "disk");
|
||||
const cleanupDelayMs = 10;
|
||||
const testKeyDefinition = new KeyDefinition<TestState>(testStateDefinition, "fake", {
|
||||
deserializer: TestState.fromJSON,
|
||||
cleanupDelayMs,
|
||||
});
|
||||
const globalKey = globalKeyBuilder(testKeyDefinition);
|
||||
|
||||
describe("DefaultGlobalState", () => {
|
||||
let diskStorageService: FakeStorageService;
|
||||
let globalState: DefaultGlobalState<TestState>;
|
||||
const logService = mock<LogService>();
|
||||
const newData = { date: new Date() };
|
||||
|
||||
beforeEach(() => {
|
||||
diskStorageService = new FakeStorageService();
|
||||
globalState = new DefaultGlobalState(testKeyDefinition, diskStorageService, logService);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
describe("state$", () => {
|
||||
it("should emit when storage updates", async () => {
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await diskStorageService.save(globalKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should not emit when update key does not match", async () => {
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await diskStorageService.save("wrong_key", newData);
|
||||
|
||||
expect(emissions).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should emit initial storage value on first subscribe", async () => {
|
||||
const initialStorage: Record<string, TestState> = {};
|
||||
initialStorage[globalKey] = TestState.fromJSON({
|
||||
date: "2022-09-21T13:14:17.648Z",
|
||||
});
|
||||
diskStorageService.internalUpdateStore(initialStorage);
|
||||
|
||||
const state = await firstValueFrom(globalState.state$);
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledWith("global_fake_fake", undefined);
|
||||
expect(state).toBeTruthy();
|
||||
});
|
||||
|
||||
it("should not emit twice if there are two listeners", async () => {
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
const emissions2 = trackEmissions(globalState.state$);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
]);
|
||||
expect(emissions2).toEqual([
|
||||
null, // Initial value
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("update", () => {
|
||||
it("should save on update", async () => {
|
||||
const result = await globalState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
|
||||
expect(result).toEqual(newData);
|
||||
});
|
||||
|
||||
it("should emit once per update", async () => {
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
await globalState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should provided combined dependencies", async () => {
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const combinedDependencies = { date: new Date() };
|
||||
|
||||
await globalState.update(
|
||||
(state, dependencies) => {
|
||||
expect(dependencies).toEqual(combinedDependencies);
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
combineLatestWith: of(combinedDependencies),
|
||||
},
|
||||
);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should not update if shouldUpdate returns false", async () => {
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const result = await globalState.update(
|
||||
(state) => {
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
shouldUpdate: () => false,
|
||||
},
|
||||
);
|
||||
|
||||
expect(diskStorageService.mock.save).not.toHaveBeenCalled();
|
||||
expect(emissions).toEqual([null]); // Initial value
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should provide the update callback with the current State", async () => {
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
// Seed with interesting data
|
||||
const initialData = { date: new Date(2020, 1, 1) };
|
||||
await globalState.update((state, dependencies) => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
await globalState.update((state) => {
|
||||
expect(state).toEqual(initialData);
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
initialData,
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should give initial state for update call", async () => {
|
||||
const initialStorage: Record<string, TestState> = {};
|
||||
const initialState = TestState.fromJSON({
|
||||
date: "2022-09-21T13:14:17.648Z",
|
||||
});
|
||||
initialStorage[globalKey] = initialState;
|
||||
diskStorageService.internalUpdateStore(initialStorage);
|
||||
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const newState = {
|
||||
...initialState,
|
||||
date: new Date(initialState.date.getFullYear(), initialState.date.getMonth() + 1),
|
||||
};
|
||||
const actual = await globalState.update((existingState) => newState);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(actual).toEqual(newState);
|
||||
expect(emissions).toHaveLength(2);
|
||||
expect(emissions).toEqual(expect.arrayContaining([initialState, newState]));
|
||||
});
|
||||
});
|
||||
|
||||
describe("update races", () => {
|
||||
test("subscriptions during an update should receive the current and latest data", async () => {
|
||||
const oldData = { date: new Date(2019, 1, 1) };
|
||||
await globalState.update(() => {
|
||||
return oldData;
|
||||
});
|
||||
const initialData = { date: new Date(2020, 1, 1) };
|
||||
await globalState.update(() => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync();
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
let emissions2: TestState[];
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
|
||||
emissions2 = trackEmissions(globalState.state$);
|
||||
await originalSave(key, obj);
|
||||
});
|
||||
|
||||
const val = await globalState.update(() => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync(10);
|
||||
|
||||
expect(val).toEqual(newData);
|
||||
expect(emissions).toEqual([initialData, newData]);
|
||||
expect(emissions2).toEqual([initialData, newData]);
|
||||
});
|
||||
|
||||
test("subscription during an aborted update should receive the last value", async () => {
|
||||
// Seed with interesting data
|
||||
const initialData = { date: new Date(2020, 1, 1) };
|
||||
await globalState.update(() => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync();
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
let emissions2: TestState[];
|
||||
const val = await globalState.update(
|
||||
() => {
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
shouldUpdate: () => {
|
||||
emissions2 = trackEmissions(globalState.state$);
|
||||
return false;
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(val).toEqual(initialData);
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
expect(emissions2).toEqual([initialData]);
|
||||
});
|
||||
|
||||
test("updates should wait until previous update is complete", async () => {
|
||||
trackEmissions(globalState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest
|
||||
.fn()
|
||||
.mockImplementationOnce(async () => {
|
||||
let resolved = false;
|
||||
await Promise.race([
|
||||
globalState.update(() => {
|
||||
// deadlocks
|
||||
resolved = true;
|
||||
return newData;
|
||||
}),
|
||||
awaitAsync(100), // limit test to 100ms
|
||||
]);
|
||||
expect(resolved).toBe(false);
|
||||
})
|
||||
.mockImplementation(originalSave);
|
||||
|
||||
await globalState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("cleanup", () => {
|
||||
function assertClean() {
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
|
||||
}
|
||||
|
||||
it("should cleanup after last subscriber", async () => {
|
||||
const subscription = globalState.state$.subscribe();
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
assertClean();
|
||||
});
|
||||
|
||||
it("should not cleanup if there are still subscribers", async () => {
|
||||
const subscription1 = globalState.state$.subscribe();
|
||||
const sub2Emissions: TestState[] = [];
|
||||
const subscription2 = globalState.state$.subscribe((v) => sub2Emissions.push(v));
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
subscription1.unsubscribe();
|
||||
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
|
||||
|
||||
// Still be listening to storage updates
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
diskStorageService.save(globalKey, newData);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
expect(sub2Emissions).toEqual([null, newData]);
|
||||
|
||||
subscription2.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
assertClean();
|
||||
});
|
||||
|
||||
it("can re-initialize after cleanup", async () => {
|
||||
const subscription = globalState.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
const emissions = trackEmissions(globalState.state$);
|
||||
await awaitAsync();
|
||||
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
diskStorageService.save(globalKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([null, newData]);
|
||||
});
|
||||
|
||||
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
|
||||
const subscription = globalState.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
await diskStorageService.save(globalKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
|
||||
// Do not wait long enough for cleanup
|
||||
await awaitAsync(cleanupDelayMs / 2);
|
||||
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("state$ observables are durable to cleanup", async () => {
|
||||
const observable = globalState.state$;
|
||||
let subscription = observable.subscribe();
|
||||
|
||||
await diskStorageService.save(globalKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
subscription = observable.subscribe();
|
||||
await diskStorageService.save(globalKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(await firstValueFrom(observable)).toEqual(newData);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,20 +1 @@
|
||||
import { AbstractStorageService, ObservableStorageService } from "@bitwarden/storage-core";
|
||||
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { GlobalState } from "../global-state";
|
||||
import { KeyDefinition, globalKeyBuilder } from "../key-definition";
|
||||
|
||||
import { StateBase } from "./state-base";
|
||||
|
||||
export class DefaultGlobalState<T>
|
||||
extends StateBase<T, KeyDefinition<T>>
|
||||
implements GlobalState<T>
|
||||
{
|
||||
constructor(
|
||||
keyDefinition: KeyDefinition<T>,
|
||||
chosenLocation: AbstractStorageService & ObservableStorageService,
|
||||
logService: LogService,
|
||||
) {
|
||||
super(globalKeyBuilder(keyDefinition), chosenLocation, keyDefinition, logService);
|
||||
}
|
||||
}
|
||||
export { DefaultGlobalState } from "@bitwarden/state";
|
||||
|
||||
@@ -1,54 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { StorageServiceProvider } from "@bitwarden/storage-core";
|
||||
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { StateEventRegistrarService } from "../state-event-registrar.service";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
import { SingleUserState } from "../user-state";
|
||||
import { SingleUserStateProvider } from "../user-state.provider";
|
||||
|
||||
import { DefaultSingleUserState } from "./default-single-user-state";
|
||||
|
||||
export class DefaultSingleUserStateProvider implements SingleUserStateProvider {
|
||||
private cache: Record<string, SingleUserState<unknown>> = {};
|
||||
|
||||
constructor(
|
||||
private readonly storageServiceProvider: StorageServiceProvider,
|
||||
private readonly stateEventRegistrarService: StateEventRegistrarService,
|
||||
private readonly logService: LogService,
|
||||
) {}
|
||||
|
||||
get<T>(userId: UserId, keyDefinition: UserKeyDefinition<T>): SingleUserState<T> {
|
||||
const [location, storageService] = this.storageServiceProvider.get(
|
||||
keyDefinition.stateDefinition.defaultStorageLocation,
|
||||
keyDefinition.stateDefinition.storageLocationOverrides,
|
||||
);
|
||||
const cacheKey = this.buildCacheKey(location, userId, keyDefinition);
|
||||
const existingUserState = this.cache[cacheKey];
|
||||
if (existingUserState != null) {
|
||||
// I have to cast out of the unknown generic but this should be safe if rules
|
||||
// around domain token are made
|
||||
return existingUserState as SingleUserState<T>;
|
||||
}
|
||||
|
||||
const newUserState = new DefaultSingleUserState<T>(
|
||||
userId,
|
||||
keyDefinition,
|
||||
storageService,
|
||||
this.stateEventRegistrarService,
|
||||
this.logService,
|
||||
);
|
||||
this.cache[cacheKey] = newUserState;
|
||||
return newUserState;
|
||||
}
|
||||
|
||||
private buildCacheKey(
|
||||
location: string,
|
||||
userId: UserId,
|
||||
keyDefinition: UserKeyDefinition<unknown>,
|
||||
) {
|
||||
return `${location}_${keyDefinition.fullName}_${userId}`;
|
||||
}
|
||||
}
|
||||
export { DefaultSingleUserStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,596 +0,0 @@
|
||||
/**
|
||||
* need to update test environment so trackEmissions works appropriately
|
||||
* @jest-environment ../shared/test.environment.ts
|
||||
*/
|
||||
|
||||
import { mock } from "jest-mock-extended";
|
||||
import { firstValueFrom, of } from "rxjs";
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { trackEmissions, awaitAsync } from "../../../../spec";
|
||||
import { FakeStorageService } from "../../../../spec/fake-storage.service";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { Utils } from "../../misc/utils";
|
||||
import { StateDefinition } from "../state-definition";
|
||||
import { StateEventRegistrarService } from "../state-event-registrar.service";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
|
||||
import { DefaultSingleUserState } from "./default-single-user-state";
|
||||
|
||||
class TestState {
|
||||
date: Date;
|
||||
|
||||
static fromJSON(jsonState: Jsonify<TestState>) {
|
||||
if (jsonState == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return Object.assign(new TestState(), jsonState, {
|
||||
date: new Date(jsonState.date),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const testStateDefinition = new StateDefinition("fake", "disk");
|
||||
const cleanupDelayMs = 10;
|
||||
const testKeyDefinition = new UserKeyDefinition<TestState>(testStateDefinition, "fake", {
|
||||
deserializer: TestState.fromJSON,
|
||||
cleanupDelayMs,
|
||||
clearOn: [],
|
||||
});
|
||||
const userId = Utils.newGuid() as UserId;
|
||||
const userKey = testKeyDefinition.buildKey(userId);
|
||||
|
||||
describe("DefaultSingleUserState", () => {
|
||||
let diskStorageService: FakeStorageService;
|
||||
let userState: DefaultSingleUserState<TestState>;
|
||||
const stateEventRegistrarService = mock<StateEventRegistrarService>();
|
||||
const logService = mock<LogService>();
|
||||
const newData = { date: new Date() };
|
||||
|
||||
beforeEach(() => {
|
||||
diskStorageService = new FakeStorageService();
|
||||
userState = new DefaultSingleUserState(
|
||||
userId,
|
||||
testKeyDefinition,
|
||||
diskStorageService,
|
||||
stateEventRegistrarService,
|
||||
logService,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
describe("state$", () => {
|
||||
it("should emit when storage updates", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should not emit when update key does not match", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await diskStorageService.save("wrong_key", newData);
|
||||
|
||||
// Give userState a chance to emit it's initial value
|
||||
// as well as wrongly emit the different key.
|
||||
await awaitAsync();
|
||||
|
||||
// Just the initial value
|
||||
expect(emissions).toEqual([null]);
|
||||
});
|
||||
|
||||
it("should emit initial storage value on first subscribe", async () => {
|
||||
const initialStorage: Record<string, TestState> = {};
|
||||
initialStorage[userKey] = TestState.fromJSON({
|
||||
date: "2022-09-21T13:14:17.648Z",
|
||||
});
|
||||
diskStorageService.internalUpdateStore(initialStorage);
|
||||
|
||||
const state = await firstValueFrom(userState.state$);
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledWith(
|
||||
`user_${userId}_fake_fake`,
|
||||
undefined,
|
||||
);
|
||||
expect(state).toBeTruthy();
|
||||
});
|
||||
|
||||
it("should go to disk each subscription if a cleanupDelayMs of 0 is given", async () => {
|
||||
const state = new DefaultSingleUserState(
|
||||
userId,
|
||||
new UserKeyDefinition(testStateDefinition, "test", {
|
||||
cleanupDelayMs: 0,
|
||||
deserializer: TestState.fromJSON,
|
||||
clearOn: [],
|
||||
debug: {
|
||||
enableRetrievalLogging: true,
|
||||
},
|
||||
}),
|
||||
diskStorageService,
|
||||
stateEventRegistrarService,
|
||||
logService,
|
||||
);
|
||||
|
||||
await firstValueFrom(state.state$);
|
||||
await firstValueFrom(state.state$);
|
||||
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(2);
|
||||
expect(logService.info).toHaveBeenCalledTimes(2);
|
||||
expect(logService.info).toHaveBeenCalledWith(
|
||||
`Retrieving 'user_${userId}_fake_test' from storage, value is null`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("combinedState$", () => {
|
||||
it("should emit when storage updates", async () => {
|
||||
const emissions = trackEmissions(userState.combinedState$);
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
[userId, null], // Initial value
|
||||
[userId, newData],
|
||||
]);
|
||||
});
|
||||
|
||||
it("should not emit when update key does not match", async () => {
|
||||
const emissions = trackEmissions(userState.combinedState$);
|
||||
await diskStorageService.save("wrong_key", newData);
|
||||
|
||||
// Give userState a chance to emit it's initial value
|
||||
// as well as wrongly emit the different key.
|
||||
await awaitAsync();
|
||||
|
||||
// Just the initial value
|
||||
expect(emissions).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("should emit initial storage value on first subscribe", async () => {
|
||||
const initialStorage: Record<string, TestState> = {};
|
||||
initialStorage[userKey] = TestState.fromJSON({
|
||||
date: "2022-09-21T13:14:17.648Z",
|
||||
});
|
||||
diskStorageService.internalUpdateStore(initialStorage);
|
||||
|
||||
const combinedState = await firstValueFrom(userState.combinedState$);
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledWith(
|
||||
`user_${userId}_fake_fake`,
|
||||
undefined,
|
||||
);
|
||||
expect(combinedState).toBeTruthy();
|
||||
const [stateUserId, state] = combinedState;
|
||||
expect(stateUserId).toBe(userId);
|
||||
expect(state).toBe(initialStorage[userKey]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("update", () => {
|
||||
it("should save on update", async () => {
|
||||
const result = await userState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(1);
|
||||
expect(result).toEqual(newData);
|
||||
});
|
||||
|
||||
it("should emit once per update", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
await userState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should provided combined dependencies", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const combinedDependencies = { date: new Date() };
|
||||
|
||||
await userState.update(
|
||||
(state, dependencies) => {
|
||||
expect(dependencies).toEqual(combinedDependencies);
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
combineLatestWith: of(combinedDependencies),
|
||||
},
|
||||
);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should not update if shouldUpdate returns false", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const result = await userState.update(
|
||||
(state) => {
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
shouldUpdate: () => false,
|
||||
},
|
||||
);
|
||||
|
||||
expect(diskStorageService.mock.save).not.toHaveBeenCalled();
|
||||
expect(emissions).toEqual([null]); // Initial value
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should provide the update callback with the current State", async () => {
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
// Seed with interesting data
|
||||
const initialData = { date: new Date(2020, 1, 1) };
|
||||
await userState.update((state, dependencies) => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
await userState.update((state) => {
|
||||
expect(state).toEqual(initialData);
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([
|
||||
null, // Initial value
|
||||
initialData,
|
||||
newData,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should give initial state for update call", async () => {
|
||||
const initialStorage: Record<string, TestState> = {};
|
||||
const initialState = TestState.fromJSON({
|
||||
date: "2022-09-21T13:14:17.648Z",
|
||||
});
|
||||
initialStorage[userKey] = initialState;
|
||||
diskStorageService.internalUpdateStore(initialStorage);
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const newState = {
|
||||
...initialState,
|
||||
date: new Date(initialState.date.getFullYear(), initialState.date.getMonth() + 1),
|
||||
};
|
||||
const actual = await userState.update((existingState) => newState);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(actual).toEqual(newState);
|
||||
expect(emissions).toHaveLength(2);
|
||||
expect(emissions).toEqual(expect.arrayContaining([initialState, newState]));
|
||||
});
|
||||
|
||||
it.each([null, undefined])(
|
||||
"should register user key definition when state transitions from null-ish (%s) to non-null",
|
||||
async (startingValue: TestState | null) => {
|
||||
const initialState: Record<string, TestState> = {};
|
||||
initialState[userKey] = startingValue;
|
||||
|
||||
diskStorageService.internalUpdateStore(initialState);
|
||||
|
||||
await userState.update(() => ({ array: ["one"], date: new Date() }));
|
||||
|
||||
expect(stateEventRegistrarService.registerEvents).toHaveBeenCalledWith(testKeyDefinition);
|
||||
},
|
||||
);
|
||||
|
||||
it("should not register user key definition when state has preexisting value", async () => {
|
||||
const initialState: Record<string, TestState> = {};
|
||||
initialState[userKey] = {
|
||||
date: new Date(2019, 1),
|
||||
};
|
||||
|
||||
diskStorageService.internalUpdateStore(initialState);
|
||||
|
||||
await userState.update(() => ({ array: ["one"], date: new Date() }));
|
||||
|
||||
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it.each([null, undefined])(
|
||||
"should not register user key definition when setting value to null-ish (%s) value",
|
||||
async (updatedValue: TestState | null) => {
|
||||
const initialState: Record<string, TestState> = {};
|
||||
initialState[userKey] = {
|
||||
date: new Date(2019, 1),
|
||||
};
|
||||
|
||||
diskStorageService.internalUpdateStore(initialState);
|
||||
|
||||
await userState.update(() => updatedValue);
|
||||
|
||||
expect(stateEventRegistrarService.registerEvents).not.toHaveBeenCalled();
|
||||
},
|
||||
);
|
||||
|
||||
const logCases: { startingValue: TestState; updateValue: TestState; phrase: string }[] = [
|
||||
{
|
||||
startingValue: null,
|
||||
updateValue: null,
|
||||
phrase: "null to null",
|
||||
},
|
||||
{
|
||||
startingValue: null,
|
||||
updateValue: new TestState(),
|
||||
phrase: "null to non-null",
|
||||
},
|
||||
{
|
||||
startingValue: new TestState(),
|
||||
updateValue: null,
|
||||
phrase: "non-null to null",
|
||||
},
|
||||
{
|
||||
startingValue: new TestState(),
|
||||
updateValue: new TestState(),
|
||||
phrase: "non-null to non-null",
|
||||
},
|
||||
];
|
||||
|
||||
it.each(logCases)(
|
||||
"should log meta info about the update",
|
||||
async ({ startingValue, updateValue, phrase }) => {
|
||||
diskStorageService.internalUpdateStore({
|
||||
[`user_${userId}_fake_fake`]: startingValue,
|
||||
});
|
||||
const state = new DefaultSingleUserState(
|
||||
userId,
|
||||
new UserKeyDefinition<TestState>(testStateDefinition, "fake", {
|
||||
deserializer: TestState.fromJSON,
|
||||
clearOn: [],
|
||||
debug: {
|
||||
enableUpdateLogging: true,
|
||||
},
|
||||
}),
|
||||
diskStorageService,
|
||||
stateEventRegistrarService,
|
||||
logService,
|
||||
);
|
||||
|
||||
await state.update(() => updateValue);
|
||||
|
||||
expect(logService.info).toHaveBeenCalledWith(
|
||||
`Updating 'user_${userId}_fake_fake' from ${phrase}`,
|
||||
);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe("update races", () => {
|
||||
test("subscriptions during an update should receive the current and latest data", async () => {
|
||||
const oldData = { date: new Date(2019, 1, 1) };
|
||||
await userState.update(() => {
|
||||
return oldData;
|
||||
});
|
||||
const initialData = { date: new Date(2020, 1, 1) };
|
||||
await userState.update(() => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync();
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
let emissions2: TestState[];
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
|
||||
emissions2 = trackEmissions(userState.state$);
|
||||
await originalSave(key, obj);
|
||||
});
|
||||
|
||||
const val = await userState.update(() => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
await awaitAsync(10);
|
||||
|
||||
expect(val).toEqual(newData);
|
||||
expect(emissions).toEqual([initialData, newData]);
|
||||
expect(emissions2).toEqual([initialData, newData]);
|
||||
});
|
||||
|
||||
test("subscription during an aborted update should receive the last value", async () => {
|
||||
// Seed with interesting data
|
||||
const initialData = { date: new Date(2020, 1, 1) };
|
||||
await userState.update(() => {
|
||||
return initialData;
|
||||
});
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync();
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
let emissions2: TestState[];
|
||||
const val = await userState.update(
|
||||
(state) => {
|
||||
return newData;
|
||||
},
|
||||
{
|
||||
shouldUpdate: () => {
|
||||
emissions2 = trackEmissions(userState.state$);
|
||||
return false;
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(val).toEqual(initialData);
|
||||
expect(emissions).toEqual([initialData]);
|
||||
|
||||
expect(emissions2).toEqual([initialData]);
|
||||
});
|
||||
|
||||
test("updates should wait until previous update is complete", async () => {
|
||||
trackEmissions(userState.state$);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
const originalSave = diskStorageService.save.bind(diskStorageService);
|
||||
diskStorageService.save = jest
|
||||
.fn()
|
||||
.mockImplementationOnce(async () => {
|
||||
let resolved = false;
|
||||
await Promise.race([
|
||||
userState.update(() => {
|
||||
// deadlocks
|
||||
resolved = true;
|
||||
return newData;
|
||||
}),
|
||||
awaitAsync(100), // limit test to 100ms
|
||||
]);
|
||||
expect(resolved).toBe(false);
|
||||
})
|
||||
.mockImplementation(originalSave);
|
||||
|
||||
await userState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
});
|
||||
|
||||
test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => {
|
||||
const val = await userState.update((state) => {
|
||||
return newData;
|
||||
});
|
||||
|
||||
expect(val).toEqual(newData);
|
||||
const call = diskStorageService.mock.save.mock.calls[0];
|
||||
expect(call[0]).toEqual(`user_${userId}_fake_fake`);
|
||||
expect(call[1]).toEqual(newData);
|
||||
});
|
||||
});
|
||||
|
||||
describe("cleanup", () => {
|
||||
function assertClean() {
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(0);
|
||||
}
|
||||
|
||||
it("should cleanup after last subscriber", async () => {
|
||||
const subscription = userState.state$.subscribe();
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
assertClean();
|
||||
});
|
||||
|
||||
it("should not cleanup if there are still subscribers", async () => {
|
||||
const subscription1 = userState.state$.subscribe();
|
||||
const sub2Emissions: TestState[] = [];
|
||||
const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v));
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
|
||||
subscription1.unsubscribe();
|
||||
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
expect(diskStorageService["updatesSubject"]["observers"]).toHaveLength(1);
|
||||
|
||||
// Still be listening to storage updates
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
diskStorageService.save(userKey, newData);
|
||||
await awaitAsync(); // storage updates are behind a promise
|
||||
expect(sub2Emissions).toEqual([null, newData]);
|
||||
|
||||
subscription2.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
assertClean();
|
||||
});
|
||||
|
||||
it("can re-initialize after cleanup", async () => {
|
||||
const subscription = userState.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
const emissions = trackEmissions(userState.state$);
|
||||
await awaitAsync();
|
||||
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual([null, newData]);
|
||||
});
|
||||
|
||||
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
|
||||
const subscription = userState.state$.subscribe();
|
||||
await awaitAsync();
|
||||
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Do not wait long enough for cleanup
|
||||
await awaitAsync(cleanupDelayMs / 2);
|
||||
|
||||
const value = await firstValueFrom(userState.state$);
|
||||
expect(value).toEqual(newData);
|
||||
|
||||
// Should be called once for the initial subscription and a second time during the save
|
||||
// but should not be called for a second subscription if the cleanup hasn't happened yet.
|
||||
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("state$ observables are durable to cleanup", async () => {
|
||||
const observable = userState.state$;
|
||||
let subscription = observable.subscribe();
|
||||
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
subscription.unsubscribe();
|
||||
// Wait for cleanup
|
||||
await awaitAsync(cleanupDelayMs * 2);
|
||||
|
||||
subscription = observable.subscribe();
|
||||
await diskStorageService.save(userKey, newData);
|
||||
await awaitAsync();
|
||||
|
||||
expect(await firstValueFrom(observable)).toEqual(newData);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,36 +1 @@
|
||||
import { Observable, combineLatest, of } from "rxjs";
|
||||
|
||||
import { AbstractStorageService, ObservableStorageService } from "@bitwarden/storage-core";
|
||||
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { StateEventRegistrarService } from "../state-event-registrar.service";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
import { CombinedState, SingleUserState } from "../user-state";
|
||||
|
||||
import { StateBase } from "./state-base";
|
||||
|
||||
export class DefaultSingleUserState<T>
|
||||
extends StateBase<T, UserKeyDefinition<T>>
|
||||
implements SingleUserState<T>
|
||||
{
|
||||
readonly combinedState$: Observable<CombinedState<T | null>>;
|
||||
|
||||
constructor(
|
||||
readonly userId: UserId,
|
||||
keyDefinition: UserKeyDefinition<T>,
|
||||
chosenLocation: AbstractStorageService & ObservableStorageService,
|
||||
private stateEventRegistrarService: StateEventRegistrarService,
|
||||
logService: LogService,
|
||||
) {
|
||||
super(keyDefinition.buildKey(userId), chosenLocation, keyDefinition, logService);
|
||||
this.combinedState$ = combineLatest([of(userId), this.state$]);
|
||||
}
|
||||
|
||||
protected override async doStorageSave(newState: T, oldState: T): Promise<void> {
|
||||
await super.doStorageSave(newState, oldState);
|
||||
if (newState != null && oldState == null) {
|
||||
await this.stateEventRegistrarService.registerEvents(this.keyDefinition);
|
||||
}
|
||||
}
|
||||
}
|
||||
export { DefaultSingleUserState } from "@bitwarden/state";
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
/**
|
||||
* need to update test environment so structuredClone works appropriately
|
||||
* @jest-environment ../shared/test.environment.ts
|
||||
*/
|
||||
import { Observable, of } from "rxjs";
|
||||
|
||||
import { awaitAsync, trackEmissions } from "../../../../spec";
|
||||
import { FakeAccountService, mockAccountServiceWith } from "../../../../spec/fake-account-service";
|
||||
import {
|
||||
FakeActiveUserStateProvider,
|
||||
FakeDerivedStateProvider,
|
||||
FakeGlobalStateProvider,
|
||||
FakeSingleUserStateProvider,
|
||||
} from "../../../../spec/fake-state-provider";
|
||||
import { AuthenticationStatus } from "../../../auth/enums/authentication-status";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { DeriveDefinition } from "../derive-definition";
|
||||
import { KeyDefinition } from "../key-definition";
|
||||
import { StateDefinition } from "../state-definition";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
|
||||
import { DefaultStateProvider } from "./default-state.provider";
|
||||
|
||||
describe("DefaultStateProvider", () => {
|
||||
let sut: DefaultStateProvider;
|
||||
let activeUserStateProvider: FakeActiveUserStateProvider;
|
||||
let singleUserStateProvider: FakeSingleUserStateProvider;
|
||||
let globalStateProvider: FakeGlobalStateProvider;
|
||||
let derivedStateProvider: FakeDerivedStateProvider;
|
||||
let accountService: FakeAccountService;
|
||||
const userId = "fakeUserId" as UserId;
|
||||
|
||||
beforeEach(() => {
|
||||
accountService = mockAccountServiceWith(userId);
|
||||
activeUserStateProvider = new FakeActiveUserStateProvider(accountService);
|
||||
singleUserStateProvider = new FakeSingleUserStateProvider();
|
||||
globalStateProvider = new FakeGlobalStateProvider();
|
||||
derivedStateProvider = new FakeDerivedStateProvider();
|
||||
sut = new DefaultStateProvider(
|
||||
activeUserStateProvider,
|
||||
singleUserStateProvider,
|
||||
globalStateProvider,
|
||||
derivedStateProvider,
|
||||
);
|
||||
});
|
||||
|
||||
describe("activeUserId$", () => {
|
||||
it("should track the active User id from active user state provider", () => {
|
||||
expect(sut.activeUserId$).toBe(activeUserStateProvider.activeUserId$);
|
||||
});
|
||||
});
|
||||
|
||||
describe.each([
|
||||
[
|
||||
"getUserState$",
|
||||
(keyDefinition: UserKeyDefinition<string>, userId?: UserId) =>
|
||||
sut.getUserState$(keyDefinition, userId),
|
||||
],
|
||||
[
|
||||
"getUserStateOrDefault$",
|
||||
(keyDefinition: UserKeyDefinition<string>, userId?: UserId) =>
|
||||
sut.getUserStateOrDefault$(keyDefinition, { userId: userId }),
|
||||
],
|
||||
])(
|
||||
"Shared behavior for %s",
|
||||
(
|
||||
_testName: string,
|
||||
methodUnderTest: (
|
||||
keyDefinition: UserKeyDefinition<string>,
|
||||
userId?: UserId,
|
||||
) => Observable<string>,
|
||||
) => {
|
||||
const accountInfo = {
|
||||
email: "email",
|
||||
emailVerified: false,
|
||||
name: "name",
|
||||
status: AuthenticationStatus.LoggedOut,
|
||||
};
|
||||
const keyDefinition = new UserKeyDefinition<string>(
|
||||
new StateDefinition("test", "disk"),
|
||||
"test",
|
||||
{
|
||||
deserializer: (s) => s,
|
||||
clearOn: [],
|
||||
},
|
||||
);
|
||||
|
||||
it("should follow the specified user if userId is provided", async () => {
|
||||
const state = singleUserStateProvider.getFake(userId, keyDefinition);
|
||||
state.nextState("value");
|
||||
const emissions = trackEmissions(methodUnderTest(keyDefinition, userId));
|
||||
|
||||
state.nextState("value2");
|
||||
state.nextState("value3");
|
||||
|
||||
expect(emissions).toEqual(["value", "value2", "value3"]);
|
||||
});
|
||||
|
||||
it("should follow the current active user if no userId is provided", async () => {
|
||||
accountService.activeAccountSubject.next({ id: userId, ...accountInfo });
|
||||
const state = singleUserStateProvider.getFake(userId, keyDefinition);
|
||||
state.nextState("value");
|
||||
const emissions = trackEmissions(methodUnderTest(keyDefinition));
|
||||
|
||||
state.nextState("value2");
|
||||
state.nextState("value3");
|
||||
|
||||
expect(emissions).toEqual(["value", "value2", "value3"]);
|
||||
});
|
||||
|
||||
it("should continue to follow the state of the user that was active when called, even if active user changes", async () => {
|
||||
const state = singleUserStateProvider.getFake(userId, keyDefinition);
|
||||
state.nextState("value");
|
||||
const emissions = trackEmissions(methodUnderTest(keyDefinition));
|
||||
|
||||
accountService.activeAccountSubject.next({ id: "newUserId" as UserId, ...accountInfo });
|
||||
const newUserEmissions = trackEmissions(sut.getUserState$(keyDefinition));
|
||||
state.nextState("value2");
|
||||
state.nextState("value3");
|
||||
|
||||
expect(emissions).toEqual(["value", "value2", "value3"]);
|
||||
expect(newUserEmissions).toEqual([null]);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
describe("getUserState$", () => {
|
||||
const accountInfo = {
|
||||
email: "email",
|
||||
emailVerified: false,
|
||||
name: "name",
|
||||
status: AuthenticationStatus.LoggedOut,
|
||||
};
|
||||
const keyDefinition = new UserKeyDefinition<string>(
|
||||
new StateDefinition("test", "disk"),
|
||||
"test",
|
||||
{
|
||||
deserializer: (s) => s,
|
||||
clearOn: [],
|
||||
},
|
||||
);
|
||||
|
||||
it("should not emit any values until a truthy user id is supplied", async () => {
|
||||
accountService.activeAccountSubject.next(null);
|
||||
const state = singleUserStateProvider.getFake(userId, keyDefinition);
|
||||
state.nextState("value");
|
||||
|
||||
const emissions = trackEmissions(sut.getUserState$(keyDefinition));
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toHaveLength(0);
|
||||
|
||||
accountService.activeAccountSubject.next({ id: userId, ...accountInfo });
|
||||
|
||||
await awaitAsync();
|
||||
|
||||
expect(emissions).toEqual(["value"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getUserStateOrDefault$", () => {
|
||||
const keyDefinition = new UserKeyDefinition<string>(
|
||||
new StateDefinition("test", "disk"),
|
||||
"test",
|
||||
{
|
||||
deserializer: (s) => s,
|
||||
clearOn: [],
|
||||
},
|
||||
);
|
||||
|
||||
it("should emit default value if no userId supplied and first active user id emission in falsy", async () => {
|
||||
accountService.activeAccountSubject.next(null);
|
||||
|
||||
const emissions = trackEmissions(
|
||||
sut.getUserStateOrDefault$(keyDefinition, {
|
||||
userId: undefined,
|
||||
defaultValue: "I'm default!",
|
||||
}),
|
||||
);
|
||||
|
||||
expect(emissions).toEqual(["I'm default!"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setUserState", () => {
|
||||
const keyDefinition = new UserKeyDefinition<string>(
|
||||
new StateDefinition("test", "disk"),
|
||||
"test",
|
||||
{
|
||||
deserializer: (s) => s,
|
||||
clearOn: [],
|
||||
},
|
||||
);
|
||||
|
||||
it("should set the state for the active user if no userId is provided", async () => {
|
||||
const value = "value";
|
||||
await sut.setUserState(keyDefinition, value);
|
||||
const state = activeUserStateProvider.getFake(keyDefinition);
|
||||
expect(state.nextMock).toHaveBeenCalledWith([expect.any(String), value]);
|
||||
});
|
||||
|
||||
it("should not set state for a single user if no userId is provided", async () => {
|
||||
const value = "value";
|
||||
await sut.setUserState(keyDefinition, value);
|
||||
const state = singleUserStateProvider.getFake(userId, keyDefinition);
|
||||
expect(state.nextMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should set the state for the provided userId", async () => {
|
||||
const value = "value";
|
||||
await sut.setUserState(keyDefinition, value, userId);
|
||||
const state = singleUserStateProvider.getFake(userId, keyDefinition);
|
||||
expect(state.nextMock).toHaveBeenCalledWith(value);
|
||||
});
|
||||
|
||||
it("should not set the active user state if userId is provided", async () => {
|
||||
const value = "value";
|
||||
await sut.setUserState(keyDefinition, value, userId);
|
||||
const state = activeUserStateProvider.getFake(keyDefinition);
|
||||
expect(state.nextMock).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it("should bind the activeUserStateProvider", () => {
|
||||
const keyDefinition = new UserKeyDefinition(new StateDefinition("test", "disk"), "test", {
|
||||
deserializer: () => null,
|
||||
clearOn: [],
|
||||
});
|
||||
const existing = activeUserStateProvider.get(keyDefinition);
|
||||
const actual = sut.getActive(keyDefinition);
|
||||
expect(actual).toBe(existing);
|
||||
});
|
||||
|
||||
it("should bind the singleUserStateProvider", () => {
|
||||
const userId = "user" as UserId;
|
||||
const keyDefinition = new UserKeyDefinition(new StateDefinition("test", "disk"), "test", {
|
||||
deserializer: () => null,
|
||||
clearOn: [],
|
||||
});
|
||||
const existing = singleUserStateProvider.get(userId, keyDefinition);
|
||||
const actual = sut.getUser(userId, keyDefinition);
|
||||
expect(actual).toBe(existing);
|
||||
});
|
||||
|
||||
it("should bind the globalStateProvider", () => {
|
||||
const keyDefinition = new KeyDefinition(new StateDefinition("test", "disk"), "test", {
|
||||
deserializer: () => null,
|
||||
});
|
||||
const existing = globalStateProvider.get(keyDefinition);
|
||||
const actual = sut.getGlobal(keyDefinition);
|
||||
expect(actual).toBe(existing);
|
||||
});
|
||||
|
||||
it("should bind the derivedStateProvider", () => {
|
||||
const derivedDefinition = new DeriveDefinition(new StateDefinition("test", "disk"), "test", {
|
||||
derive: () => null,
|
||||
deserializer: () => null,
|
||||
});
|
||||
const parentState$ = of(null);
|
||||
const existing = derivedStateProvider.get(parentState$, derivedDefinition, {});
|
||||
const actual = sut.getDerived(parentState$, derivedDefinition, {});
|
||||
expect(actual).toBe(existing);
|
||||
});
|
||||
});
|
||||
@@ -1,79 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Observable, filter, of, switchMap, take } from "rxjs";
|
||||
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { DerivedStateDependencies } from "../../../types/state";
|
||||
import { DeriveDefinition } from "../derive-definition";
|
||||
import { DerivedState } from "../derived-state";
|
||||
import { DerivedStateProvider } from "../derived-state.provider";
|
||||
import { GlobalStateProvider } from "../global-state.provider";
|
||||
import { StateProvider } from "../state.provider";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
import { ActiveUserStateProvider, SingleUserStateProvider } from "../user-state.provider";
|
||||
|
||||
export class DefaultStateProvider implements StateProvider {
|
||||
activeUserId$: Observable<UserId>;
|
||||
constructor(
|
||||
private readonly activeUserStateProvider: ActiveUserStateProvider,
|
||||
private readonly singleUserStateProvider: SingleUserStateProvider,
|
||||
private readonly globalStateProvider: GlobalStateProvider,
|
||||
private readonly derivedStateProvider: DerivedStateProvider,
|
||||
) {
|
||||
this.activeUserId$ = this.activeUserStateProvider.activeUserId$;
|
||||
}
|
||||
|
||||
getUserState$<T>(userKeyDefinition: UserKeyDefinition<T>, userId?: UserId): Observable<T> {
|
||||
if (userId) {
|
||||
return this.getUser<T>(userId, userKeyDefinition).state$;
|
||||
} else {
|
||||
return this.activeUserId$.pipe(
|
||||
filter((userId) => userId != null), // Filter out null-ish user ids since we can't get state for a null user id
|
||||
take(1),
|
||||
switchMap((userId) => this.getUser<T>(userId, userKeyDefinition).state$),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
getUserStateOrDefault$<T>(
|
||||
userKeyDefinition: UserKeyDefinition<T>,
|
||||
config: { userId: UserId | undefined; defaultValue?: T },
|
||||
): Observable<T> {
|
||||
const { userId, defaultValue = null } = config;
|
||||
if (userId) {
|
||||
return this.getUser<T>(userId, userKeyDefinition).state$;
|
||||
} else {
|
||||
return this.activeUserId$.pipe(
|
||||
take(1),
|
||||
switchMap((userId) =>
|
||||
userId != null ? this.getUser<T>(userId, userKeyDefinition).state$ : of(defaultValue),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async setUserState<T>(
|
||||
userKeyDefinition: UserKeyDefinition<T>,
|
||||
value: T | null,
|
||||
userId?: UserId,
|
||||
): Promise<[UserId, T | null]> {
|
||||
if (userId) {
|
||||
return [userId, await this.getUser<T>(userId, userKeyDefinition).update(() => value)];
|
||||
} else {
|
||||
return await this.getActive<T>(userKeyDefinition).update(() => value);
|
||||
}
|
||||
}
|
||||
|
||||
getActive: InstanceType<typeof ActiveUserStateProvider>["get"] =
|
||||
this.activeUserStateProvider.get.bind(this.activeUserStateProvider);
|
||||
getUser: InstanceType<typeof SingleUserStateProvider>["get"] =
|
||||
this.singleUserStateProvider.get.bind(this.singleUserStateProvider);
|
||||
getGlobal: InstanceType<typeof GlobalStateProvider>["get"] = this.globalStateProvider.get.bind(
|
||||
this.globalStateProvider,
|
||||
);
|
||||
getDerived: <TFrom, TTo, TDeps extends DerivedStateDependencies>(
|
||||
parentState$: Observable<TFrom>,
|
||||
deriveDefinition: DeriveDefinition<unknown, TTo, TDeps>,
|
||||
dependencies: TDeps,
|
||||
) => DerivedState<TTo> = this.derivedStateProvider.get.bind(this.derivedStateProvider);
|
||||
}
|
||||
export { DefaultStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
import { Subject, firstValueFrom } from "rxjs";
|
||||
|
||||
import { DeriveDefinition } from "../derive-definition";
|
||||
import { StateDefinition } from "../state-definition";
|
||||
|
||||
import { InlineDerivedState } from "./inline-derived-state";
|
||||
|
||||
describe("InlineDerivedState", () => {
|
||||
const syncDeriveDefinition = new DeriveDefinition<boolean, boolean, Record<string, unknown>>(
|
||||
new StateDefinition("test", "disk"),
|
||||
"test",
|
||||
{
|
||||
derive: (value, deps) => !value,
|
||||
deserializer: (value) => value,
|
||||
},
|
||||
);
|
||||
|
||||
const asyncDeriveDefinition = new DeriveDefinition<boolean, boolean, Record<string, unknown>>(
|
||||
new StateDefinition("test", "disk"),
|
||||
"test",
|
||||
{
|
||||
derive: async (value, deps) => Promise.resolve(!value),
|
||||
deserializer: (value) => value,
|
||||
},
|
||||
);
|
||||
|
||||
const parentState = new Subject<boolean>();
|
||||
|
||||
describe("state", () => {
|
||||
const cases = [
|
||||
{
|
||||
it: "works when derive function is sync",
|
||||
definition: syncDeriveDefinition,
|
||||
},
|
||||
{
|
||||
it: "works when derive function is async",
|
||||
definition: asyncDeriveDefinition,
|
||||
},
|
||||
];
|
||||
|
||||
it.each(cases)("$it", async ({ definition }) => {
|
||||
const sut = new InlineDerivedState(parentState.asObservable(), definition, {});
|
||||
|
||||
const valuePromise = firstValueFrom(sut.state$);
|
||||
parentState.next(true);
|
||||
|
||||
const value = await valuePromise;
|
||||
|
||||
expect(value).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("forceValue", () => {
|
||||
it("returns the force value back to the caller", async () => {
|
||||
const sut = new InlineDerivedState(parentState.asObservable(), syncDeriveDefinition, {});
|
||||
|
||||
const value = await sut.forceValue(true);
|
||||
|
||||
expect(value).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,37 +1 @@
|
||||
import { Observable, concatMap } from "rxjs";
|
||||
|
||||
import { DerivedStateDependencies } from "../../../types/state";
|
||||
import { DeriveDefinition } from "../derive-definition";
|
||||
import { DerivedState } from "../derived-state";
|
||||
import { DerivedStateProvider } from "../derived-state.provider";
|
||||
|
||||
export class InlineDerivedStateProvider implements DerivedStateProvider {
|
||||
get<TFrom, TTo, TDeps extends DerivedStateDependencies>(
|
||||
parentState$: Observable<TFrom>,
|
||||
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
|
||||
dependencies: TDeps,
|
||||
): DerivedState<TTo> {
|
||||
return new InlineDerivedState(parentState$, deriveDefinition, dependencies);
|
||||
}
|
||||
}
|
||||
|
||||
export class InlineDerivedState<TFrom, TTo, TDeps extends DerivedStateDependencies>
|
||||
implements DerivedState<TTo>
|
||||
{
|
||||
constructor(
|
||||
parentState$: Observable<TFrom>,
|
||||
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
|
||||
dependencies: TDeps,
|
||||
) {
|
||||
this.state$ = parentState$.pipe(
|
||||
concatMap(async (value) => await deriveDefinition.derive(value, dependencies)),
|
||||
);
|
||||
}
|
||||
|
||||
state$: Observable<TTo>;
|
||||
|
||||
forceValue(value: TTo): Promise<TTo> {
|
||||
// No need to force anything, we don't keep a cache
|
||||
return Promise.resolve(value);
|
||||
}
|
||||
}
|
||||
export { InlineDerivedState, InlineDerivedStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
import { StorageServiceProvider } from "@bitwarden/storage-core";
|
||||
|
||||
import { mockAccountServiceWith } from "../../../../spec/fake-account-service";
|
||||
import { FakeStorageService } from "../../../../spec/fake-storage.service";
|
||||
import { UserId } from "../../../types/guid";
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { KeyDefinition } from "../key-definition";
|
||||
import { StateDefinition } from "../state-definition";
|
||||
import { StateEventRegistrarService } from "../state-event-registrar.service";
|
||||
import { UserKeyDefinition } from "../user-key-definition";
|
||||
|
||||
import { DefaultActiveUserState } from "./default-active-user-state";
|
||||
import { DefaultActiveUserStateProvider } from "./default-active-user-state.provider";
|
||||
import { DefaultGlobalState } from "./default-global-state";
|
||||
import { DefaultGlobalStateProvider } from "./default-global-state.provider";
|
||||
import { DefaultSingleUserState } from "./default-single-user-state";
|
||||
import { DefaultSingleUserStateProvider } from "./default-single-user-state.provider";
|
||||
|
||||
describe("Specific State Providers", () => {
|
||||
const storageServiceProvider = mock<StorageServiceProvider>();
|
||||
const stateEventRegistrarService = mock<StateEventRegistrarService>();
|
||||
const logService = mock<LogService>();
|
||||
|
||||
let singleSut: DefaultSingleUserStateProvider;
|
||||
let activeSut: DefaultActiveUserStateProvider;
|
||||
let globalSut: DefaultGlobalStateProvider;
|
||||
|
||||
const fakeUser1 = "00000000-0000-1000-a000-000000000001" as UserId;
|
||||
|
||||
beforeEach(() => {
|
||||
storageServiceProvider.get.mockImplementation((location) => {
|
||||
return [location, new FakeStorageService()];
|
||||
});
|
||||
|
||||
singleSut = new DefaultSingleUserStateProvider(
|
||||
storageServiceProvider,
|
||||
stateEventRegistrarService,
|
||||
logService,
|
||||
);
|
||||
activeSut = new DefaultActiveUserStateProvider(mockAccountServiceWith(null), singleSut);
|
||||
globalSut = new DefaultGlobalStateProvider(storageServiceProvider, logService);
|
||||
});
|
||||
|
||||
const fakeDiskStateDefinition = new StateDefinition("fake", "disk");
|
||||
const fakeAlternateDiskStateDefinition = new StateDefinition("fakeAlternate", "disk");
|
||||
const fakeMemoryStateDefinition = new StateDefinition("fake", "memory");
|
||||
const makeKeyDefinition = (stateDefinition: StateDefinition, key: string) =>
|
||||
new KeyDefinition<boolean>(stateDefinition, key, {
|
||||
deserializer: (b) => b,
|
||||
});
|
||||
const makeUserKeyDefinition = (stateDefinition: StateDefinition, key: string) =>
|
||||
new UserKeyDefinition<boolean>(stateDefinition, key, {
|
||||
deserializer: (b) => b,
|
||||
clearOn: [],
|
||||
});
|
||||
const keyDefinitions = {
|
||||
disk: {
|
||||
keyDefinition: makeKeyDefinition(fakeDiskStateDefinition, "fake"),
|
||||
userKeyDefinition: makeUserKeyDefinition(fakeDiskStateDefinition, "fake"),
|
||||
altKeyDefinition: makeKeyDefinition(fakeDiskStateDefinition, "fakeAlternate"),
|
||||
altUserKeyDefinition: makeUserKeyDefinition(fakeDiskStateDefinition, "fakeAlternate"),
|
||||
},
|
||||
memory: {
|
||||
keyDefinition: makeKeyDefinition(fakeMemoryStateDefinition, "fake"),
|
||||
userKeyDefinition: makeUserKeyDefinition(fakeMemoryStateDefinition, "fake"),
|
||||
},
|
||||
alternateDisk: {
|
||||
keyDefinition: makeKeyDefinition(fakeAlternateDiskStateDefinition, "fake"),
|
||||
userKeyDefinition: makeUserKeyDefinition(fakeAlternateDiskStateDefinition, "fake"),
|
||||
},
|
||||
};
|
||||
|
||||
describe("active provider", () => {
|
||||
it("returns a DefaultActiveUserState", () => {
|
||||
const state = activeSut.get(keyDefinitions.disk.userKeyDefinition);
|
||||
|
||||
expect(state).toBeInstanceOf(DefaultActiveUserState);
|
||||
});
|
||||
|
||||
it("returns different instances when the storage location differs", () => {
|
||||
const stateDisk = activeSut.get(keyDefinitions.disk.userKeyDefinition);
|
||||
const stateMemory = activeSut.get(keyDefinitions.memory.userKeyDefinition);
|
||||
expect(stateDisk).not.toStrictEqual(stateMemory);
|
||||
});
|
||||
|
||||
it("returns different instances when the state name differs", () => {
|
||||
const state = activeSut.get(keyDefinitions.disk.userKeyDefinition);
|
||||
const stateAlt = activeSut.get(keyDefinitions.alternateDisk.userKeyDefinition);
|
||||
expect(state).not.toStrictEqual(stateAlt);
|
||||
});
|
||||
|
||||
it("returns different instances when the key differs", () => {
|
||||
const state = activeSut.get(keyDefinitions.disk.userKeyDefinition);
|
||||
const stateAlt = activeSut.get(keyDefinitions.disk.altUserKeyDefinition);
|
||||
expect(state).not.toStrictEqual(stateAlt);
|
||||
});
|
||||
});
|
||||
|
||||
describe("single provider", () => {
|
||||
it("returns a DefaultSingleUserState", () => {
|
||||
const state = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
|
||||
expect(state).toBeInstanceOf(DefaultSingleUserState);
|
||||
});
|
||||
|
||||
it("returns different instances when the storage location differs", () => {
|
||||
const stateDisk = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
const stateMemory = singleSut.get(fakeUser1, keyDefinitions.memory.userKeyDefinition);
|
||||
expect(stateDisk).not.toStrictEqual(stateMemory);
|
||||
});
|
||||
|
||||
it("returns different instances when the state name differs", () => {
|
||||
const state = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
const stateAlt = singleSut.get(fakeUser1, keyDefinitions.alternateDisk.userKeyDefinition);
|
||||
expect(state).not.toStrictEqual(stateAlt);
|
||||
});
|
||||
|
||||
it("returns different instances when the key differs", () => {
|
||||
const state = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
const stateAlt = singleSut.get(fakeUser1, keyDefinitions.disk.altUserKeyDefinition);
|
||||
expect(state).not.toStrictEqual(stateAlt);
|
||||
});
|
||||
|
||||
const fakeUser2 = "00000000-0000-1000-a000-000000000002" as UserId;
|
||||
|
||||
it("returns different instances when the user id differs", () => {
|
||||
const user1State = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
const user2State = singleSut.get(fakeUser2, keyDefinitions.disk.userKeyDefinition);
|
||||
expect(user1State).not.toStrictEqual(user2State);
|
||||
});
|
||||
|
||||
it("returns an instance with the userId property corresponding to the user id passed in", () => {
|
||||
const userState = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
expect(userState.userId).toBe(fakeUser1);
|
||||
});
|
||||
|
||||
it("returns cached instance on repeated request", () => {
|
||||
const stateFirst = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
const stateCached = singleSut.get(fakeUser1, keyDefinitions.disk.userKeyDefinition);
|
||||
expect(stateFirst).toStrictEqual(stateCached);
|
||||
});
|
||||
});
|
||||
|
||||
describe("global provider", () => {
|
||||
it("returns a DefaultGlobalState", () => {
|
||||
const state = globalSut.get(keyDefinitions.disk.keyDefinition);
|
||||
|
||||
expect(state).toBeInstanceOf(DefaultGlobalState);
|
||||
});
|
||||
|
||||
it("returns different instances when the storage location differs", () => {
|
||||
const stateDisk = globalSut.get(keyDefinitions.disk.keyDefinition);
|
||||
const stateMemory = globalSut.get(keyDefinitions.memory.keyDefinition);
|
||||
expect(stateDisk).not.toStrictEqual(stateMemory);
|
||||
});
|
||||
|
||||
it("returns different instances when the state name differs", () => {
|
||||
const state = globalSut.get(keyDefinitions.disk.keyDefinition);
|
||||
const stateAlt = globalSut.get(keyDefinitions.alternateDisk.keyDefinition);
|
||||
expect(state).not.toStrictEqual(stateAlt);
|
||||
});
|
||||
|
||||
it("returns different instances when the key differs", () => {
|
||||
const state = globalSut.get(keyDefinitions.disk.keyDefinition);
|
||||
const stateAlt = globalSut.get(keyDefinitions.disk.altKeyDefinition);
|
||||
expect(state).not.toStrictEqual(stateAlt);
|
||||
});
|
||||
|
||||
it("returns cached instance on repeated request", () => {
|
||||
const stateFirst = globalSut.get(keyDefinitions.disk.keyDefinition);
|
||||
const stateCached = globalSut.get(keyDefinitions.disk.keyDefinition);
|
||||
expect(stateFirst).toStrictEqual(stateCached);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,137 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import {
|
||||
defer,
|
||||
filter,
|
||||
firstValueFrom,
|
||||
merge,
|
||||
Observable,
|
||||
ReplaySubject,
|
||||
share,
|
||||
switchMap,
|
||||
tap,
|
||||
timeout,
|
||||
timer,
|
||||
} from "rxjs";
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { AbstractStorageService, ObservableStorageService } from "@bitwarden/storage-core";
|
||||
|
||||
import { StorageKey } from "../../../types/state";
|
||||
import { LogService } from "../../abstractions/log.service";
|
||||
import { DebugOptions } from "../key-definition";
|
||||
import { populateOptionsWithDefault, StateUpdateOptions } from "../state-update-options";
|
||||
|
||||
import { getStoredValue } from "./util";
|
||||
|
||||
// The parts of a KeyDefinition this class cares about to make it work
|
||||
type KeyDefinitionRequirements<T> = {
|
||||
deserializer: (jsonState: Jsonify<T>) => T | null;
|
||||
cleanupDelayMs: number;
|
||||
debug: Required<DebugOptions>;
|
||||
};
|
||||
|
||||
export abstract class StateBase<T, KeyDef extends KeyDefinitionRequirements<T>> {
|
||||
private updatePromise: Promise<T>;
|
||||
|
||||
readonly state$: Observable<T | null>;
|
||||
|
||||
constructor(
|
||||
protected readonly key: StorageKey,
|
||||
protected readonly storageService: AbstractStorageService & ObservableStorageService,
|
||||
protected readonly keyDefinition: KeyDef,
|
||||
protected readonly logService: LogService,
|
||||
) {
|
||||
const storageUpdate$ = storageService.updates$.pipe(
|
||||
filter((storageUpdate) => storageUpdate.key === key),
|
||||
switchMap(async (storageUpdate) => {
|
||||
if (storageUpdate.updateType === "remove") {
|
||||
return null;
|
||||
}
|
||||
|
||||
return await getStoredValue(key, storageService, keyDefinition.deserializer);
|
||||
}),
|
||||
);
|
||||
|
||||
let state$ = merge(
|
||||
defer(() => getStoredValue(key, storageService, keyDefinition.deserializer)),
|
||||
storageUpdate$,
|
||||
);
|
||||
|
||||
if (keyDefinition.debug.enableRetrievalLogging) {
|
||||
state$ = state$.pipe(
|
||||
tap({
|
||||
next: (v) => {
|
||||
this.logService.info(
|
||||
`Retrieving '${key}' from storage, value is ${v == null ? "null" : "non-null"}`,
|
||||
);
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// If 0 cleanup is chosen, treat this as absolutely no cache
|
||||
if (keyDefinition.cleanupDelayMs !== 0) {
|
||||
state$ = state$.pipe(
|
||||
share({
|
||||
connector: () => new ReplaySubject(1),
|
||||
resetOnRefCountZero: () => timer(keyDefinition.cleanupDelayMs),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
this.state$ = state$;
|
||||
}
|
||||
|
||||
async update<TCombine>(
|
||||
configureState: (state: T | null, dependency: TCombine) => T | null,
|
||||
options: StateUpdateOptions<T, TCombine> = {},
|
||||
): Promise<T | null> {
|
||||
options = populateOptionsWithDefault(options);
|
||||
if (this.updatePromise != null) {
|
||||
await this.updatePromise;
|
||||
}
|
||||
|
||||
try {
|
||||
this.updatePromise = this.internalUpdate(configureState, options);
|
||||
return await this.updatePromise;
|
||||
} finally {
|
||||
this.updatePromise = null;
|
||||
}
|
||||
}
|
||||
|
||||
private async internalUpdate<TCombine>(
|
||||
configureState: (state: T | null, dependency: TCombine) => T | null,
|
||||
options: StateUpdateOptions<T, TCombine>,
|
||||
): Promise<T | null> {
|
||||
const currentState = await this.getStateForUpdate();
|
||||
const combinedDependencies =
|
||||
options.combineLatestWith != null
|
||||
? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout)))
|
||||
: null;
|
||||
|
||||
if (!options.shouldUpdate(currentState, combinedDependencies)) {
|
||||
return currentState;
|
||||
}
|
||||
|
||||
const newState = configureState(currentState, combinedDependencies);
|
||||
await this.doStorageSave(newState, currentState);
|
||||
return newState;
|
||||
}
|
||||
|
||||
protected async doStorageSave(newState: T | null, oldState: T) {
|
||||
if (this.keyDefinition.debug.enableUpdateLogging) {
|
||||
this.logService.info(
|
||||
`Updating '${this.key}' from ${oldState == null ? "null" : "non-null"} to ${newState == null ? "null" : "non-null"}`,
|
||||
);
|
||||
}
|
||||
await this.storageService.save(this.key, newState);
|
||||
}
|
||||
|
||||
/** For use in update methods, does not wait for update to complete before yielding state.
|
||||
* The expectation is that that await is already done
|
||||
*/
|
||||
private async getStateForUpdate() {
|
||||
return await getStoredValue(this.key, this.storageService, this.keyDefinition.deserializer);
|
||||
}
|
||||
}
|
||||
export { StateBase } from "@bitwarden/state";
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
import { FakeStorageService } from "../../../../spec/fake-storage.service";
|
||||
|
||||
import { getStoredValue } from "./util";
|
||||
|
||||
describe("getStoredValue", () => {
|
||||
const key = "key";
|
||||
const deserializedValue = { value: 1 };
|
||||
const value = JSON.stringify(deserializedValue);
|
||||
const deserializer = (v: string) => JSON.parse(v);
|
||||
let storageService: FakeStorageService;
|
||||
|
||||
beforeEach(() => {
|
||||
storageService = new FakeStorageService();
|
||||
});
|
||||
|
||||
describe("when the storage service requires deserialization", () => {
|
||||
beforeEach(() => {
|
||||
storageService.internalUpdateValuesRequireDeserialization(true);
|
||||
});
|
||||
|
||||
it("should deserialize", async () => {
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
storageService.save(key, value);
|
||||
|
||||
const result = await getStoredValue(key, storageService, deserializer);
|
||||
|
||||
expect(result).toEqual(deserializedValue);
|
||||
});
|
||||
});
|
||||
describe("when the storage service does not require deserialization", () => {
|
||||
beforeEach(() => {
|
||||
storageService.internalUpdateValuesRequireDeserialization(false);
|
||||
});
|
||||
|
||||
it("should not deserialize", async () => {
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
storageService.save(key, value);
|
||||
|
||||
const result = await getStoredValue(key, storageService, deserializer);
|
||||
|
||||
expect(result).toEqual(value);
|
||||
});
|
||||
|
||||
it("should convert undefined to null", async () => {
|
||||
// FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling.
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
storageService.save(key, undefined);
|
||||
|
||||
const result = await getStoredValue(key, storageService, deserializer);
|
||||
|
||||
expect(result).toEqual(null);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,17 +0,0 @@
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { AbstractStorageService } from "@bitwarden/storage-core";
|
||||
|
||||
export async function getStoredValue<T>(
|
||||
key: string,
|
||||
storage: AbstractStorageService,
|
||||
deserializer: (jsonValue: Jsonify<T>) => T | null,
|
||||
) {
|
||||
if (storage.valuesRequireDeserialization) {
|
||||
const jsonValue = await storage.get<Jsonify<T>>(key);
|
||||
return deserializer(jsonValue);
|
||||
} else {
|
||||
const value = await storage.get<T>(key);
|
||||
return value ?? null;
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1 @@
|
||||
export { DeriveDefinition } from "./derive-definition";
|
||||
export { DerivedStateProvider } from "./derived-state.provider";
|
||||
export { DerivedState } from "./derived-state";
|
||||
export { GlobalState } from "./global-state";
|
||||
export { StateProvider } from "./state.provider";
|
||||
export { GlobalStateProvider } from "./global-state.provider";
|
||||
export { ActiveUserState, SingleUserState, CombinedState } from "./user-state";
|
||||
export { ActiveUserStateProvider, SingleUserStateProvider } from "./user-state.provider";
|
||||
export { KeyDefinition, KeyDefinitionOptions } from "./key-definition";
|
||||
export { StateUpdateOptions } from "./state-update-options";
|
||||
export { UserKeyDefinitionOptions, UserKeyDefinition } from "./user-key-definition";
|
||||
export { StateEventRunnerService } from "./state-event-runner.service";
|
||||
|
||||
export * from "./state-definitions";
|
||||
export * from "@bitwarden/state";
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
import { Opaque } from "type-fest";
|
||||
|
||||
import { DebugOptions, KeyDefinition } from "./key-definition";
|
||||
import { StateDefinition } from "./state-definition";
|
||||
|
||||
const fakeStateDefinition = new StateDefinition("fake", "disk");
|
||||
|
||||
type FancyString = Opaque<string, "FancyString">;
|
||||
|
||||
describe("KeyDefinition", () => {
|
||||
describe("constructor", () => {
|
||||
it("throws on undefined deserializer", () => {
|
||||
expect(() => {
|
||||
new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
|
||||
deserializer: undefined,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("normalizes debug options set to undefined", () => {
|
||||
const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", {
|
||||
deserializer: (v) => v,
|
||||
debug: undefined,
|
||||
});
|
||||
|
||||
expect(keyDefinition.debug.enableUpdateLogging).toBe(false);
|
||||
});
|
||||
|
||||
it("normalizes no debug options", () => {
|
||||
const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", {
|
||||
deserializer: (v) => v,
|
||||
});
|
||||
|
||||
expect(keyDefinition.debug.enableUpdateLogging).toBe(false);
|
||||
});
|
||||
|
||||
const cases: {
|
||||
debug: DebugOptions | undefined;
|
||||
expectedEnableUpdateLogging: boolean;
|
||||
expectedEnableRetrievalLogging: boolean;
|
||||
}[] = [
|
||||
{
|
||||
debug: undefined,
|
||||
expectedEnableUpdateLogging: false,
|
||||
expectedEnableRetrievalLogging: false,
|
||||
},
|
||||
{
|
||||
debug: {},
|
||||
expectedEnableUpdateLogging: false,
|
||||
expectedEnableRetrievalLogging: false,
|
||||
},
|
||||
{
|
||||
debug: {
|
||||
enableUpdateLogging: false,
|
||||
},
|
||||
expectedEnableUpdateLogging: false,
|
||||
expectedEnableRetrievalLogging: false,
|
||||
},
|
||||
{
|
||||
debug: {
|
||||
enableRetrievalLogging: false,
|
||||
},
|
||||
expectedEnableUpdateLogging: false,
|
||||
expectedEnableRetrievalLogging: false,
|
||||
},
|
||||
{
|
||||
debug: {
|
||||
enableUpdateLogging: true,
|
||||
},
|
||||
expectedEnableUpdateLogging: true,
|
||||
expectedEnableRetrievalLogging: false,
|
||||
},
|
||||
{
|
||||
debug: {
|
||||
enableRetrievalLogging: true,
|
||||
},
|
||||
expectedEnableUpdateLogging: false,
|
||||
expectedEnableRetrievalLogging: true,
|
||||
},
|
||||
{
|
||||
debug: {
|
||||
enableRetrievalLogging: false,
|
||||
enableUpdateLogging: false,
|
||||
},
|
||||
expectedEnableUpdateLogging: false,
|
||||
expectedEnableRetrievalLogging: false,
|
||||
},
|
||||
{
|
||||
debug: {
|
||||
enableRetrievalLogging: true,
|
||||
enableUpdateLogging: true,
|
||||
},
|
||||
expectedEnableUpdateLogging: true,
|
||||
expectedEnableRetrievalLogging: true,
|
||||
},
|
||||
];
|
||||
|
||||
it.each(cases)(
|
||||
"normalizes debug options to correct values when given $debug",
|
||||
({ debug, expectedEnableUpdateLogging, expectedEnableRetrievalLogging }) => {
|
||||
const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", {
|
||||
deserializer: (v) => v,
|
||||
debug: debug,
|
||||
});
|
||||
|
||||
expect(keyDefinition.debug.enableUpdateLogging).toBe(expectedEnableUpdateLogging);
|
||||
expect(keyDefinition.debug.enableRetrievalLogging).toBe(expectedEnableRetrievalLogging);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe("cleanupDelayMs", () => {
|
||||
it("defaults to 1000ms", () => {
|
||||
const keyDefinition = new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
|
||||
deserializer: (value) => value,
|
||||
});
|
||||
|
||||
expect(keyDefinition).toBeTruthy();
|
||||
expect(keyDefinition.cleanupDelayMs).toBe(1000);
|
||||
});
|
||||
|
||||
it("can be overridden", () => {
|
||||
const keyDefinition = new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
|
||||
deserializer: (value) => value,
|
||||
cleanupDelayMs: 500,
|
||||
});
|
||||
|
||||
expect(keyDefinition).toBeTruthy();
|
||||
expect(keyDefinition.cleanupDelayMs).toBe(500);
|
||||
});
|
||||
|
||||
it("throws on negative", () => {
|
||||
expect(
|
||||
() =>
|
||||
new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
|
||||
deserializer: (value) => value,
|
||||
cleanupDelayMs: -1,
|
||||
}),
|
||||
).toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("record", () => {
|
||||
it("runs custom deserializer for each record value", () => {
|
||||
const recordDefinition = KeyDefinition.record<boolean>(fakeStateDefinition, "fake", {
|
||||
// Intentionally negate the value for testing
|
||||
deserializer: (value) => !value,
|
||||
});
|
||||
|
||||
expect(recordDefinition).toBeTruthy();
|
||||
expect(recordDefinition.deserializer).toBeTruthy();
|
||||
|
||||
const deserializedValue = recordDefinition.deserializer({
|
||||
test1: false,
|
||||
test2: true,
|
||||
});
|
||||
|
||||
expect(Object.keys(deserializedValue)).toHaveLength(2);
|
||||
|
||||
// Values should have swapped from their initial value
|
||||
expect(deserializedValue["test1"]).toBeTruthy();
|
||||
expect(deserializedValue["test2"]).toBeFalsy();
|
||||
});
|
||||
|
||||
it("can handle fancy string type", () => {
|
||||
// This test is more of a test that I got the typescript typing correctly than actually testing any business logic
|
||||
const recordDefinition = KeyDefinition.record<boolean, FancyString>(
|
||||
fakeStateDefinition,
|
||||
"fake",
|
||||
{
|
||||
deserializer: (value) => !value,
|
||||
},
|
||||
);
|
||||
|
||||
const fancyRecord = recordDefinition.deserializer(
|
||||
JSON.parse(`{ "myKey": false, "mySecondKey": true }`),
|
||||
);
|
||||
|
||||
expect(fancyRecord).toBeTruthy();
|
||||
expect(Object.keys(fancyRecord)).toHaveLength(2);
|
||||
expect(fancyRecord["myKey" as FancyString]).toBeTruthy();
|
||||
expect(fancyRecord["mySecondKey" as FancyString]).toBeFalsy();
|
||||
});
|
||||
});
|
||||
|
||||
describe("array", () => {
|
||||
it("run custom deserializer for each array element", () => {
|
||||
const arrayDefinition = KeyDefinition.array<boolean>(fakeStateDefinition, "fake", {
|
||||
deserializer: (value) => !value,
|
||||
});
|
||||
|
||||
expect(arrayDefinition).toBeTruthy();
|
||||
expect(arrayDefinition.deserializer).toBeTruthy();
|
||||
|
||||
// NOTE: `as any` is here until we migrate to Nx: https://bitwarden.atlassian.net/browse/PM-6493
|
||||
const deserializedValue = arrayDefinition.deserializer([false, true] as any);
|
||||
|
||||
expect(deserializedValue).toBeTruthy();
|
||||
expect(deserializedValue).toHaveLength(2);
|
||||
expect(deserializedValue[0]).toBeTruthy();
|
||||
expect(deserializedValue[1]).toBeFalsy();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,182 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Jsonify } from "type-fest";
|
||||
|
||||
import { StorageKey } from "../../types/state";
|
||||
|
||||
import { array, record } from "./deserialization-helpers";
|
||||
import { StateDefinition } from "./state-definition";
|
||||
|
||||
export type DebugOptions = {
|
||||
/**
|
||||
* When true, logs will be written that look like the following:
|
||||
*
|
||||
* ```
|
||||
* "Updating 'global_myState_myKey' from null to non-null"
|
||||
* "Updating 'user_32265eda-62ff-4797-9ead-22214772f888_myState_myKey' from non-null to null."
|
||||
* ```
|
||||
*
|
||||
* It does not include the value of the data, only whether it is null or non-null.
|
||||
*/
|
||||
enableUpdateLogging?: boolean;
|
||||
|
||||
/**
|
||||
* When true, logs will be written that look like the following everytime a value is retrieved from storage.
|
||||
*
|
||||
* "Retrieving 'global_myState_myKey' from storage, value is null."
|
||||
* "Retrieving 'user_32265eda-62ff-4797-9ead-22214772f888_myState_myKey' from storage, value is non-null."
|
||||
*/
|
||||
enableRetrievalLogging?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* A set of options for customizing the behavior of a {@link KeyDefinition}
|
||||
*/
|
||||
export type KeyDefinitionOptions<T> = {
|
||||
/**
|
||||
* A function to use to safely convert your type from json to your expected type.
|
||||
*
|
||||
* **Important:** Your data may be serialized/deserialized at any time and this
|
||||
* callback needs to be able to faithfully re-initialize from the JSON object representation of your type.
|
||||
*
|
||||
* @param jsonValue The JSON object representation of your state.
|
||||
* @returns The fully typed version of your state.
|
||||
*/
|
||||
readonly deserializer: (jsonValue: Jsonify<T>) => T | null;
|
||||
/**
|
||||
* The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
|
||||
* Defaults to 1000ms.
|
||||
*/
|
||||
readonly cleanupDelayMs?: number;
|
||||
|
||||
/**
|
||||
* Options for configuring the debugging behavior, see individual options for more info.
|
||||
*/
|
||||
readonly debug?: DebugOptions;
|
||||
};
|
||||
|
||||
/**
|
||||
* KeyDefinitions describe the precise location to store data for a given piece of state.
|
||||
* The StateDefinition is used to describe the domain of the state, and the KeyDefinition
|
||||
* sub-divides that domain into specific keys.
|
||||
*/
|
||||
export class KeyDefinition<T> {
|
||||
readonly debug: Required<DebugOptions>;
|
||||
|
||||
/**
|
||||
* Creates a new instance of a KeyDefinition
|
||||
* @param stateDefinition The state definition for which this key belongs to.
|
||||
* @param key The name of the key, this should be unique per domain.
|
||||
* @param options A set of options to customize the behavior of {@link KeyDefinition}. All options are required.
|
||||
* @param options.deserializer A function to use to safely convert your type from json to your expected type.
|
||||
* Your data may be serialized/deserialized at any time and this needs callback needs to be able to faithfully re-initialize
|
||||
* from the JSON object representation of your type.
|
||||
*/
|
||||
constructor(
|
||||
readonly stateDefinition: StateDefinition,
|
||||
readonly key: string,
|
||||
private readonly options: KeyDefinitionOptions<T>,
|
||||
) {
|
||||
if (options.deserializer == null) {
|
||||
throw new Error(`'deserializer' is a required property on key ${this.errorKeyName}`);
|
||||
}
|
||||
|
||||
if (options.cleanupDelayMs < 0) {
|
||||
throw new Error(
|
||||
`'cleanupDelayMs' must be greater than or equal to 0. Value of ${options.cleanupDelayMs} passed to key ${this.errorKeyName} `,
|
||||
);
|
||||
}
|
||||
|
||||
// Normalize optional debug options
|
||||
const { enableUpdateLogging = false, enableRetrievalLogging = false } = options.debug ?? {};
|
||||
this.debug = {
|
||||
enableUpdateLogging,
|
||||
enableRetrievalLogging,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the deserializer configured for this {@link KeyDefinition}
|
||||
*/
|
||||
get deserializer() {
|
||||
return this.options.deserializer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
|
||||
*/
|
||||
get cleanupDelayMs() {
|
||||
return this.options.cleanupDelayMs < 0 ? 0 : (this.options.cleanupDelayMs ?? 1000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link KeyDefinition} for state that is an array.
|
||||
* @param stateDefinition The state definition to be added to the KeyDefinition
|
||||
* @param key The key to be added to the KeyDefinition
|
||||
* @param options The options to customize the final {@link KeyDefinition}.
|
||||
* @returns A {@link KeyDefinition} initialized for arrays, the options run
|
||||
* the deserializer on the provided options for each element of an array.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const MY_KEY = KeyDefinition.array<MyArrayElement>(MY_STATE, "key", {
|
||||
* deserializer: (myJsonElement) => convertToElement(myJsonElement),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
static array<T>(
|
||||
stateDefinition: StateDefinition,
|
||||
key: string,
|
||||
// We have them provide options for the element of the array, depending on future options we add, this could get a little weird.
|
||||
options: KeyDefinitionOptions<T>, // The array helper forces an initialValue of an empty array
|
||||
) {
|
||||
return new KeyDefinition<T[]>(stateDefinition, key, {
|
||||
...options,
|
||||
deserializer: array((e) => options.deserializer(e)),
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link KeyDefinition} for state that is a record.
|
||||
* @param stateDefinition The state definition to be added to the KeyDefinition
|
||||
* @param key The key to be added to the KeyDefinition
|
||||
* @param options The options to customize the final {@link KeyDefinition}.
|
||||
* @returns A {@link KeyDefinition} that contains a serializer that will run the provided deserializer for each
|
||||
* value in a record and returns every key as a string.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const MY_KEY = KeyDefinition.record<MyRecordValue>(MY_STATE, "key", {
|
||||
* deserializer: (myJsonValue) => convertToValue(myJsonValue),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
static record<T, TKey extends string | number = string>(
|
||||
stateDefinition: StateDefinition,
|
||||
key: string,
|
||||
// We have them provide options for the value of the record, depending on future options we add, this could get a little weird.
|
||||
options: KeyDefinitionOptions<T>, // The array helper forces an initialValue of an empty record
|
||||
) {
|
||||
return new KeyDefinition<Record<TKey, T>>(stateDefinition, key, {
|
||||
...options,
|
||||
deserializer: record((v) => options.deserializer(v)),
|
||||
});
|
||||
}
|
||||
|
||||
get fullName() {
|
||||
return `${this.stateDefinition.name}_${this.key}`;
|
||||
}
|
||||
|
||||
protected get errorKeyName() {
|
||||
return `${this.stateDefinition.name} > ${this.key}`;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link StorageKey}
|
||||
* @param keyDefinition The key definition of which data the key should point to.
|
||||
* @returns A key that is ready to be used in a storage service to get data.
|
||||
*/
|
||||
export function globalKeyBuilder(keyDefinition: KeyDefinition<unknown>): StorageKey {
|
||||
return `global_${keyDefinition.stateDefinition.name}_${keyDefinition.key}` as StorageKey;
|
||||
}
|
||||
export { KeyDefinition, KeyDefinitionOptions } from "@bitwarden/state";
|
||||
|
||||
@@ -1,24 +1,4 @@
|
||||
import { StorageLocation, ClientLocations } from "@bitwarden/storage-core";
|
||||
export { StateDefinition } from "@bitwarden/state";
|
||||
|
||||
// To be removed once references are updated to point to @bitwarden/storage-core
|
||||
export { StorageLocation, ClientLocations };
|
||||
|
||||
/**
|
||||
* Defines the base location and instruction of where this state is expected to be located.
|
||||
*/
|
||||
export class StateDefinition {
|
||||
readonly storageLocationOverrides: Partial<ClientLocations>;
|
||||
|
||||
/**
|
||||
* Creates a new instance of {@link StateDefinition}, the creation of which is owned by the platform team.
|
||||
* @param name The name of the state, this needs to be unique from all other {@link StateDefinition}'s.
|
||||
* @param defaultStorageLocation The location of where this state should be stored.
|
||||
*/
|
||||
constructor(
|
||||
readonly name: string,
|
||||
readonly defaultStorageLocation: StorageLocation,
|
||||
storageLocationOverrides?: Partial<ClientLocations>,
|
||||
) {
|
||||
this.storageLocationOverrides = storageLocationOverrides ?? {};
|
||||
}
|
||||
}
|
||||
export { StorageLocation, ClientLocations } from "@bitwarden/storage-core";
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import { ClientLocations, StateDefinition } from "./state-definition";
|
||||
import * as stateDefinitionsRecord from "./state-definitions";
|
||||
|
||||
describe.each(["web", "cli", "desktop", "browser"])(
|
||||
"state definitions follow rules for client %s",
|
||||
(clientType: keyof ClientLocations) => {
|
||||
const trackedNames: [string, string][] = [];
|
||||
|
||||
test.each(Object.entries(stateDefinitionsRecord))(
|
||||
"that export %s follows all rules",
|
||||
(exportName, stateDefinition) => {
|
||||
// All exports from state-definitions are expected to be StateDefinition's
|
||||
if (!(stateDefinition instanceof StateDefinition)) {
|
||||
throw new Error(`export ${exportName} is expected to be a StateDefinition`);
|
||||
}
|
||||
|
||||
const storageLocation =
|
||||
stateDefinition.storageLocationOverrides[clientType] ??
|
||||
stateDefinition.defaultStorageLocation;
|
||||
|
||||
const fullName = `${stateDefinition.name}_${storageLocation}`;
|
||||
|
||||
const exactConflictingExport = trackedNames.find(
|
||||
([_, trackedName]) => trackedName === fullName,
|
||||
);
|
||||
if (exactConflictingExport !== undefined) {
|
||||
const [conflictingExportName] = exactConflictingExport;
|
||||
throw new Error(
|
||||
`The export '${exportName}' has a conflicting state name and storage location with export ` +
|
||||
`'${conflictingExportName}' please ensure that you choose a unique name and location for all clients.`,
|
||||
);
|
||||
}
|
||||
|
||||
const roughConflictingExport = trackedNames.find(
|
||||
([_, trackedName]) => trackedName.toLowerCase() === fullName.toLowerCase(),
|
||||
);
|
||||
if (roughConflictingExport !== undefined) {
|
||||
const [conflictingExportName] = roughConflictingExport;
|
||||
throw new Error(
|
||||
`The export '${exportName}' differs its state name and storage location ` +
|
||||
`only by casing with export '${conflictingExportName}' please ensure it differs by more than casing.`,
|
||||
);
|
||||
}
|
||||
|
||||
const name = stateDefinition.name;
|
||||
|
||||
expect(name).not.toBeUndefined(); // undefined in an invalid name
|
||||
expect(name).not.toBeNull(); // null is in invalid name
|
||||
expect(name.length).toBeGreaterThan(3); // A 3 characters or less name is not descriptive enough
|
||||
expect(name[0]).toEqual(name[0].toLowerCase()); // First character should be lower case since camelCase is required
|
||||
expect(name).not.toContain(" "); // There should be no spaces in a state name
|
||||
expect(name).not.toContain("_"); // We should not be doing snake_case for state name
|
||||
|
||||
// NOTE: We could expect some details about the export name as well
|
||||
|
||||
trackedNames.push([exportName, fullName]);
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
@@ -1,215 +0,0 @@
|
||||
import { StateDefinition } from "./state-definition";
|
||||
|
||||
/**
|
||||
* `StateDefinition`s comes with some rules, to facilitate a quick review from
|
||||
* platform of this file, ensure you follow these rules, the ones marked with (tested)
|
||||
* have unit tests that you can run locally.
|
||||
*
|
||||
* 1. (tested) Names should not be null or undefined
|
||||
* 2. (tested) Name and storage location should be unique
|
||||
* 3. (tested) Name and storage location can't differ from another export by only casing
|
||||
* 4. (tested) Name should be longer than 3 characters. It should be descriptive, but brief.
|
||||
* 5. (tested) Name should not contain spaces or underscores
|
||||
* 6. Name should be human readable
|
||||
* 7. Name should be in camelCase format (unit tests ensure the first character is lowercase)
|
||||
* 8. Teams should only use state definitions they have created
|
||||
* 9. StateDefinitions should only be used for keys relating to the state name they chose
|
||||
*
|
||||
*/
|
||||
|
||||
// Admin Console
|
||||
|
||||
export const ORGANIZATIONS_DISK = new StateDefinition("organizations", "disk");
|
||||
export const POLICIES_DISK = new StateDefinition("policies", "disk");
|
||||
export const PROVIDERS_DISK = new StateDefinition("providers", "disk");
|
||||
export const ORGANIZATION_MANAGEMENT_PREFERENCES_DISK = new StateDefinition(
|
||||
"organizationManagementPreferences",
|
||||
"disk",
|
||||
{
|
||||
web: "disk-local",
|
||||
},
|
||||
);
|
||||
export const DELETE_MANAGED_USER_WARNING = new StateDefinition(
|
||||
"showDeleteManagedUserWarning",
|
||||
"disk",
|
||||
{
|
||||
web: "disk-local",
|
||||
},
|
||||
);
|
||||
|
||||
// Billing
|
||||
export const BILLING_DISK = new StateDefinition("billing", "disk");
|
||||
|
||||
// Auth
|
||||
|
||||
export const ACCOUNT_DISK = new StateDefinition("account", "disk");
|
||||
export const ACCOUNT_MEMORY = new StateDefinition("account", "memory");
|
||||
export const AUTH_REQUEST_DISK_LOCAL = new StateDefinition("authRequestLocal", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const AVATAR_DISK = new StateDefinition("avatar", "disk", { web: "disk-local" });
|
||||
export const DEVICE_TRUST_DISK_LOCAL = new StateDefinition("deviceTrust", "disk", {
|
||||
web: "disk-local",
|
||||
browser: "disk-backup-local-storage",
|
||||
});
|
||||
export const KDF_CONFIG_DISK = new StateDefinition("kdfConfig", "disk");
|
||||
export const KEY_CONNECTOR_DISK = new StateDefinition("keyConnector", "disk");
|
||||
export const LOGIN_EMAIL_DISK = new StateDefinition("loginEmail", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const LOGIN_EMAIL_MEMORY = new StateDefinition("loginEmail", "memory");
|
||||
export const LOGIN_STRATEGY_MEMORY = new StateDefinition("loginStrategy", "memory");
|
||||
export const MASTER_PASSWORD_DISK = new StateDefinition("masterPassword", "disk");
|
||||
export const MASTER_PASSWORD_MEMORY = new StateDefinition("masterPassword", "memory");
|
||||
export const PIN_DISK = new StateDefinition("pinUnlock", "disk");
|
||||
export const PIN_MEMORY = new StateDefinition("pinUnlock", "memory");
|
||||
export const ROUTER_DISK = new StateDefinition("router", "disk");
|
||||
export const SSO_DISK = new StateDefinition("ssoLogin", "disk");
|
||||
export const TOKEN_DISK = new StateDefinition("token", "disk");
|
||||
export const TOKEN_DISK_LOCAL = new StateDefinition("tokenDiskLocal", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const TOKEN_MEMORY = new StateDefinition("token", "memory");
|
||||
export const TWO_FACTOR_MEMORY = new StateDefinition("twoFactor", "memory");
|
||||
export const USER_DECRYPTION_OPTIONS_DISK = new StateDefinition("userDecryptionOptions", "disk");
|
||||
export const ORGANIZATION_INVITE_DISK = new StateDefinition("organizationInvite", "disk");
|
||||
export const VAULT_TIMEOUT_SETTINGS_DISK_LOCAL = new StateDefinition(
|
||||
"vaultTimeoutSettings",
|
||||
"disk",
|
||||
{
|
||||
web: "disk-local",
|
||||
},
|
||||
);
|
||||
|
||||
// Autofill
|
||||
|
||||
export const BADGE_SETTINGS_DISK = new StateDefinition("badgeSettings", "disk");
|
||||
export const USER_NOTIFICATION_SETTINGS_DISK = new StateDefinition(
|
||||
"userNotificationSettings",
|
||||
"disk",
|
||||
);
|
||||
|
||||
export const DOMAIN_SETTINGS_DISK = new StateDefinition("domainSettings", "disk");
|
||||
export const AUTOFILL_SETTINGS_DISK = new StateDefinition("autofillSettings", "disk");
|
||||
export const AUTOFILL_SETTINGS_DISK_LOCAL = new StateDefinition("autofillSettingsLocal", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
|
||||
export const AUTOTYPE_SETTINGS_DISK = new StateDefinition("autotypeSettings", "disk");
|
||||
|
||||
// Components
|
||||
|
||||
export const NEW_WEB_LAYOUT_BANNER_DISK = new StateDefinition("newWebLayoutBanner", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
|
||||
// Platform
|
||||
|
||||
export const APPLICATION_ID_DISK = new StateDefinition("applicationId", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const BADGE_MEMORY = new StateDefinition("badge", "memory", {
|
||||
browser: "memory-large-object",
|
||||
});
|
||||
export const BIOMETRIC_SETTINGS_DISK = new StateDefinition("biometricSettings", "disk");
|
||||
export const CLEAR_EVENT_DISK = new StateDefinition("clearEvent", "disk");
|
||||
export const CONFIG_DISK = new StateDefinition("config", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const CRYPTO_DISK = new StateDefinition("crypto", "disk");
|
||||
export const CRYPTO_MEMORY = new StateDefinition("crypto", "memory");
|
||||
export const DESKTOP_SETTINGS_DISK = new StateDefinition("desktopSettings", "disk");
|
||||
export const ENVIRONMENT_DISK = new StateDefinition("environment", "disk");
|
||||
export const ENVIRONMENT_MEMORY = new StateDefinition("environment", "memory");
|
||||
export const POPUP_VIEW_MEMORY = new StateDefinition("popupView", "memory", {
|
||||
browser: "memory-large-object",
|
||||
});
|
||||
export const SYNC_DISK = new StateDefinition("sync", "disk", { web: "memory" });
|
||||
export const THEMING_DISK = new StateDefinition("theming", "disk", { web: "disk-local" });
|
||||
export const TRANSLATION_DISK = new StateDefinition("translation", "disk", { web: "disk-local" });
|
||||
export const ANIMATION_DISK = new StateDefinition("animation", "disk");
|
||||
export const TASK_SCHEDULER_DISK = new StateDefinition("taskScheduler", "disk");
|
||||
export const EXTENSION_INITIAL_INSTALL_DISK = new StateDefinition(
|
||||
"extensionInitialInstall",
|
||||
"disk",
|
||||
);
|
||||
export const WEB_PUSH_SUBSCRIPTION = new StateDefinition("webPushSubscription", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
|
||||
// Design System
|
||||
|
||||
export const POPUP_STYLE_DISK = new StateDefinition("popupStyle", "disk");
|
||||
|
||||
// Secrets Manager
|
||||
|
||||
export const SM_ONBOARDING_DISK = new StateDefinition("smOnboarding", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
|
||||
// Tools
|
||||
|
||||
export const EXTENSION_DISK = new StateDefinition("extension", "disk");
|
||||
export const GENERATOR_DISK = new StateDefinition("generator", "disk");
|
||||
export const GENERATOR_MEMORY = new StateDefinition("generator", "memory");
|
||||
export const BROWSER_SEND_MEMORY = new StateDefinition("sendBrowser", "memory");
|
||||
export const EVENT_COLLECTION_DISK = new StateDefinition("eventCollection", "disk");
|
||||
export const SEND_DISK = new StateDefinition("encryptedSend", "disk", {
|
||||
web: "memory",
|
||||
});
|
||||
export const SEND_MEMORY = new StateDefinition("decryptedSend", "memory", {
|
||||
browser: "memory-large-object",
|
||||
});
|
||||
export const SEND_ACCESS_AUTH_MEMORY = new StateDefinition("sendAccessAuth", "memory");
|
||||
|
||||
// Vault
|
||||
|
||||
export const COLLECTION_DATA = new StateDefinition("collection", "disk", {
|
||||
web: "memory",
|
||||
});
|
||||
export const FOLDER_DISK = new StateDefinition("folder", "disk", { web: "memory" });
|
||||
export const FOLDER_MEMORY = new StateDefinition("decryptedFolders", "memory", {
|
||||
browser: "memory-large-object",
|
||||
});
|
||||
export const VAULT_FILTER_DISK = new StateDefinition("vaultFilter", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const VAULT_ONBOARDING = new StateDefinition("vaultOnboarding", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const VAULT_SETTINGS_DISK = new StateDefinition("vaultSettings", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const VAULT_BROWSER_MEMORY = new StateDefinition("vaultBrowser", "memory", {
|
||||
browser: "memory-large-object",
|
||||
});
|
||||
export const VAULT_SEARCH_MEMORY = new StateDefinition("vaultSearch", "memory", {
|
||||
browser: "memory-large-object",
|
||||
});
|
||||
export const CIPHERS_DISK = new StateDefinition("ciphers", "disk", { web: "memory" });
|
||||
export const CIPHERS_DISK_LOCAL = new StateDefinition("ciphersLocal", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const CIPHERS_MEMORY = new StateDefinition("ciphersMemory", "memory", {
|
||||
browser: "memory-large-object",
|
||||
});
|
||||
export const PREMIUM_BANNER_DISK_LOCAL = new StateDefinition("premiumBannerReprompt", "disk", {
|
||||
web: "disk-local",
|
||||
});
|
||||
export const BANNERS_DISMISSED_DISK = new StateDefinition("bannersDismissed", "disk");
|
||||
export const VAULT_APPEARANCE = new StateDefinition("vaultAppearance", "disk");
|
||||
export const SECURITY_TASKS_DISK = new StateDefinition("securityTasks", "disk");
|
||||
export const AT_RISK_PASSWORDS_PAGE_DISK = new StateDefinition("atRiskPasswordsPage", "disk");
|
||||
export const NOTIFICATION_DISK = new StateDefinition("notifications", "disk");
|
||||
export const NUDGES_DISK = new StateDefinition("nudges", "disk", { web: "disk-local" });
|
||||
export const SETUP_EXTENSION_DISMISSED_DISK = new StateDefinition(
|
||||
"setupExtensionDismissed",
|
||||
"disk",
|
||||
{
|
||||
web: "disk-local",
|
||||
},
|
||||
);
|
||||
export const VAULT_BROWSER_INTRO_CAROUSEL = new StateDefinition(
|
||||
"vaultBrowserIntroCarousel",
|
||||
"disk",
|
||||
);
|
||||
@@ -1,89 +0,0 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
import {
|
||||
AbstractStorageService,
|
||||
ObservableStorageService,
|
||||
StorageServiceProvider,
|
||||
} from "@bitwarden/storage-core";
|
||||
|
||||
import { FakeGlobalStateProvider } from "../../../spec";
|
||||
|
||||
import { StateDefinition } from "./state-definition";
|
||||
import { STATE_LOCK_EVENT, StateEventRegistrarService } from "./state-event-registrar.service";
|
||||
import { UserKeyDefinition } from "./user-key-definition";
|
||||
|
||||
describe("StateEventRegistrarService", () => {
|
||||
const globalStateProvider = new FakeGlobalStateProvider();
|
||||
const lockState = globalStateProvider.getFake(STATE_LOCK_EVENT);
|
||||
const storageServiceProvider = mock<StorageServiceProvider>();
|
||||
|
||||
const sut = new StateEventRegistrarService(globalStateProvider, storageServiceProvider);
|
||||
|
||||
describe("registerEvents", () => {
|
||||
const fakeKeyDefinition = new UserKeyDefinition<boolean>(
|
||||
new StateDefinition("fakeState", "disk"),
|
||||
"fakeKey",
|
||||
{
|
||||
deserializer: (s) => s,
|
||||
clearOn: ["lock"],
|
||||
},
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
it("adds event on null storage", async () => {
|
||||
storageServiceProvider.get.mockReturnValue([
|
||||
"disk",
|
||||
mock<AbstractStorageService & ObservableStorageService>(),
|
||||
]);
|
||||
|
||||
await sut.registerEvents(fakeKeyDefinition);
|
||||
|
||||
expect(lockState.nextMock).toHaveBeenCalledWith([
|
||||
{
|
||||
key: "fakeKey",
|
||||
location: "disk",
|
||||
state: "fakeState",
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("adds event on empty array in storage", async () => {
|
||||
lockState.stateSubject.next([]);
|
||||
storageServiceProvider.get.mockReturnValue([
|
||||
"disk",
|
||||
mock<AbstractStorageService & ObservableStorageService>(),
|
||||
]);
|
||||
|
||||
await sut.registerEvents(fakeKeyDefinition);
|
||||
|
||||
expect(lockState.nextMock).toHaveBeenCalledWith([
|
||||
{
|
||||
key: "fakeKey",
|
||||
location: "disk",
|
||||
state: "fakeState",
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("doesn't add a duplicate", async () => {
|
||||
lockState.stateSubject.next([
|
||||
{
|
||||
key: "fakeKey",
|
||||
location: "disk",
|
||||
state: "fakeState",
|
||||
},
|
||||
]);
|
||||
storageServiceProvider.get.mockReturnValue([
|
||||
"disk",
|
||||
mock<AbstractStorageService & ObservableStorageService>(),
|
||||
]);
|
||||
|
||||
await sut.registerEvents(fakeKeyDefinition);
|
||||
|
||||
expect(lockState.nextMock).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,76 +1,6 @@
|
||||
import { PossibleLocation, StorageServiceProvider } from "../services/storage-service.provider";
|
||||
|
||||
import { GlobalState } from "./global-state";
|
||||
import { GlobalStateProvider } from "./global-state.provider";
|
||||
import { KeyDefinition } from "./key-definition";
|
||||
import { CLEAR_EVENT_DISK } from "./state-definitions";
|
||||
import { ClearEvent, UserKeyDefinition } from "./user-key-definition";
|
||||
|
||||
export type StateEventInfo = {
|
||||
state: string;
|
||||
key: string;
|
||||
location: PossibleLocation;
|
||||
};
|
||||
|
||||
export const STATE_LOCK_EVENT = KeyDefinition.array<StateEventInfo>(CLEAR_EVENT_DISK, "lock", {
|
||||
deserializer: (e) => e,
|
||||
});
|
||||
|
||||
export const STATE_LOGOUT_EVENT = KeyDefinition.array<StateEventInfo>(CLEAR_EVENT_DISK, "logout", {
|
||||
deserializer: (e) => e,
|
||||
});
|
||||
|
||||
export class StateEventRegistrarService {
|
||||
private readonly stateEventStateMap: { [Prop in ClearEvent]: GlobalState<StateEventInfo[]> };
|
||||
|
||||
constructor(
|
||||
globalStateProvider: GlobalStateProvider,
|
||||
private storageServiceProvider: StorageServiceProvider,
|
||||
) {
|
||||
this.stateEventStateMap = {
|
||||
lock: globalStateProvider.get(STATE_LOCK_EVENT),
|
||||
logout: globalStateProvider.get(STATE_LOGOUT_EVENT),
|
||||
};
|
||||
}
|
||||
|
||||
async registerEvents(keyDefinition: UserKeyDefinition<unknown>) {
|
||||
for (const clearEvent of keyDefinition.clearOn) {
|
||||
const eventState = this.stateEventStateMap[clearEvent];
|
||||
// Determine the storage location for this
|
||||
const [storageLocation] = this.storageServiceProvider.get(
|
||||
keyDefinition.stateDefinition.defaultStorageLocation,
|
||||
keyDefinition.stateDefinition.storageLocationOverrides,
|
||||
);
|
||||
|
||||
const newEvent: StateEventInfo = {
|
||||
state: keyDefinition.stateDefinition.name,
|
||||
key: keyDefinition.key,
|
||||
location: storageLocation,
|
||||
};
|
||||
|
||||
// Only update the event state if the existing list doesn't have a matching entry
|
||||
await eventState.update(
|
||||
(existingTickets) => {
|
||||
existingTickets ??= [];
|
||||
existingTickets.push(newEvent);
|
||||
return existingTickets;
|
||||
},
|
||||
{
|
||||
shouldUpdate: (currentTickets) => {
|
||||
return (
|
||||
// If the current tickets are null, then it will for sure be added
|
||||
currentTickets == null ||
|
||||
// If an existing match couldn't be found, we also need to add one
|
||||
currentTickets.findIndex(
|
||||
(e) =>
|
||||
e.state === newEvent.state &&
|
||||
e.key === newEvent.key &&
|
||||
e.location === newEvent.location,
|
||||
) === -1
|
||||
);
|
||||
},
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
export {
|
||||
StateEventRegistrarService,
|
||||
StateEventInfo,
|
||||
STATE_LOCK_EVENT,
|
||||
STATE_LOGOUT_EVENT,
|
||||
} from "@bitwarden/state";
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
import {
|
||||
AbstractStorageService,
|
||||
ObservableStorageService,
|
||||
StorageServiceProvider,
|
||||
} from "@bitwarden/storage-core";
|
||||
|
||||
import { FakeGlobalStateProvider } from "../../../spec";
|
||||
import { UserId } from "../../types/guid";
|
||||
|
||||
import { STATE_LOCK_EVENT } from "./state-event-registrar.service";
|
||||
import { StateEventRunnerService } from "./state-event-runner.service";
|
||||
|
||||
describe("EventRunnerService", () => {
|
||||
const fakeGlobalStateProvider = new FakeGlobalStateProvider();
|
||||
const lockState = fakeGlobalStateProvider.getFake(STATE_LOCK_EVENT);
|
||||
|
||||
const storageServiceProvider = mock<StorageServiceProvider>();
|
||||
|
||||
const sut = new StateEventRunnerService(fakeGlobalStateProvider, storageServiceProvider);
|
||||
|
||||
describe("handleEvent", () => {
|
||||
it("does nothing if there are no events in state", async () => {
|
||||
const mockStorageService = mock<AbstractStorageService & ObservableStorageService>();
|
||||
storageServiceProvider.get.mockReturnValue(["disk", mockStorageService]);
|
||||
|
||||
await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId);
|
||||
|
||||
expect(lockState.nextMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("loops through and acts on all events", async () => {
|
||||
const mockDiskStorageService = mock<AbstractStorageService & ObservableStorageService>();
|
||||
const mockMemoryStorageService = mock<AbstractStorageService & ObservableStorageService>();
|
||||
|
||||
lockState.stateSubject.next([
|
||||
{
|
||||
state: "fakeState1",
|
||||
key: "fakeKey1",
|
||||
location: "disk",
|
||||
},
|
||||
{
|
||||
state: "fakeState2",
|
||||
key: "fakeKey2",
|
||||
location: "memory",
|
||||
},
|
||||
]);
|
||||
|
||||
storageServiceProvider.get.mockImplementation((defaultLocation, overrides) => {
|
||||
if (defaultLocation === "disk") {
|
||||
return [defaultLocation, mockDiskStorageService];
|
||||
} else if (defaultLocation === "memory") {
|
||||
return [defaultLocation, mockMemoryStorageService];
|
||||
}
|
||||
});
|
||||
|
||||
mockMemoryStorageService.get.mockResolvedValue("something");
|
||||
|
||||
await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId);
|
||||
|
||||
expect(mockDiskStorageService.get).toHaveBeenCalledTimes(1);
|
||||
expect(mockDiskStorageService.get).toHaveBeenCalledWith(
|
||||
"user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState1_fakeKey1",
|
||||
);
|
||||
expect(mockMemoryStorageService.get).toHaveBeenCalledTimes(1);
|
||||
expect(mockMemoryStorageService.get).toHaveBeenCalledWith(
|
||||
"user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState2_fakeKey2",
|
||||
);
|
||||
expect(mockMemoryStorageService.remove).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,83 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { firstValueFrom } from "rxjs";
|
||||
|
||||
import { StorageServiceProvider } from "@bitwarden/storage-core";
|
||||
|
||||
import { UserId } from "../../types/guid";
|
||||
|
||||
import { GlobalState } from "./global-state";
|
||||
import { GlobalStateProvider } from "./global-state.provider";
|
||||
import { StateDefinition, StorageLocation } from "./state-definition";
|
||||
import {
|
||||
STATE_LOCK_EVENT,
|
||||
STATE_LOGOUT_EVENT,
|
||||
StateEventInfo,
|
||||
} from "./state-event-registrar.service";
|
||||
import { ClearEvent, UserKeyDefinition } from "./user-key-definition";
|
||||
|
||||
export class StateEventRunnerService {
|
||||
private readonly stateEventMap: { [Prop in ClearEvent]: GlobalState<StateEventInfo[]> };
|
||||
|
||||
constructor(
|
||||
globalStateProvider: GlobalStateProvider,
|
||||
private storageServiceProvider: StorageServiceProvider,
|
||||
) {
|
||||
this.stateEventMap = {
|
||||
lock: globalStateProvider.get(STATE_LOCK_EVENT),
|
||||
logout: globalStateProvider.get(STATE_LOGOUT_EVENT),
|
||||
};
|
||||
}
|
||||
|
||||
async handleEvent(event: ClearEvent, userId: UserId) {
|
||||
let tickets = await firstValueFrom(this.stateEventMap[event].state$);
|
||||
tickets ??= [];
|
||||
|
||||
const failures: string[] = [];
|
||||
|
||||
for (const ticket of tickets) {
|
||||
try {
|
||||
const [, service] = this.storageServiceProvider.get(
|
||||
ticket.location,
|
||||
{}, // The storage location is already the computed storage location for this client
|
||||
);
|
||||
|
||||
const ticketStorageKey = this.storageKeyFor(userId, ticket);
|
||||
|
||||
// Evaluate current value so we can avoid writing to state if we don't need to
|
||||
const currentValue = await service.get(ticketStorageKey);
|
||||
if (currentValue != null) {
|
||||
await service.remove(ticketStorageKey);
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
let errorMessage = "Unknown Error";
|
||||
if (typeof err === "object" && "message" in err && typeof err.message === "string") {
|
||||
errorMessage = err.message;
|
||||
}
|
||||
|
||||
failures.push(
|
||||
`${errorMessage} in ${ticket.state} > ${ticket.key} located ${ticket.location}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (failures.length > 0) {
|
||||
// Throw aggregated error
|
||||
throw new Error(
|
||||
`One or more errors occurred while handling event '${event}' for user ${userId}.\n${failures.join("\n")}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private storageKeyFor(userId: UserId, ticket: StateEventInfo) {
|
||||
const userKey = new UserKeyDefinition<unknown>(
|
||||
new StateDefinition(ticket.state, ticket.location as unknown as StorageLocation),
|
||||
ticket.key,
|
||||
{
|
||||
deserializer: (v) => v,
|
||||
clearOn: [],
|
||||
},
|
||||
);
|
||||
return userKey.buildKey(userId);
|
||||
}
|
||||
}
|
||||
export { StateEventRunnerService } from "@bitwarden/state";
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
export const DEFAULT_OPTIONS = {
|
||||
shouldUpdate: () => true,
|
||||
combineLatestWith: null as Observable<unknown>,
|
||||
msTimeout: 1000,
|
||||
};
|
||||
|
||||
type DefinitelyTypedDefault<T, TCombine> = Omit<
|
||||
typeof DEFAULT_OPTIONS,
|
||||
"shouldUpdate" | "combineLatestWith"
|
||||
> & {
|
||||
shouldUpdate: (state: T, dependency: TCombine) => boolean;
|
||||
combineLatestWith?: Observable<TCombine>;
|
||||
};
|
||||
|
||||
export type StateUpdateOptions<T, TCombine> = Partial<DefinitelyTypedDefault<T, TCombine>>;
|
||||
|
||||
export function populateOptionsWithDefault<T, TCombine>(
|
||||
options: StateUpdateOptions<T, TCombine>,
|
||||
): StateUpdateOptions<T, TCombine> {
|
||||
return {
|
||||
...(DEFAULT_OPTIONS as StateUpdateOptions<T, TCombine>),
|
||||
...options,
|
||||
};
|
||||
}
|
||||
@@ -1,80 +1 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
import { UserId } from "../../types/guid";
|
||||
import { DerivedStateDependencies } from "../../types/state";
|
||||
|
||||
import { DeriveDefinition } from "./derive-definition";
|
||||
import { DerivedState } from "./derived-state";
|
||||
import { GlobalState } from "./global-state";
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- used in docs
|
||||
import { GlobalStateProvider } from "./global-state.provider";
|
||||
import { KeyDefinition } from "./key-definition";
|
||||
import { UserKeyDefinition } from "./user-key-definition";
|
||||
import { ActiveUserState, SingleUserState } from "./user-state";
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars -- used in docs
|
||||
import { ActiveUserStateProvider, SingleUserStateProvider } from "./user-state.provider";
|
||||
|
||||
/** Convenience wrapper class for {@link ActiveUserStateProvider}, {@link SingleUserStateProvider},
|
||||
* and {@link GlobalStateProvider}.
|
||||
*/
|
||||
export abstract class StateProvider {
|
||||
/** @see{@link ActiveUserStateProvider.activeUserId$} */
|
||||
abstract activeUserId$: Observable<UserId | undefined>;
|
||||
|
||||
/**
|
||||
* Gets a state observable for a given key and userId.
|
||||
*
|
||||
* @remarks If userId is falsy the observable returned will attempt to point to the currently active user _and not update if the active user changes_.
|
||||
* This is different to how `getActive` works and more similar to `getUser` for whatever user happens to be active at the time of the call.
|
||||
* If no user happens to be active at the time this method is called with a falsy userId then this observable will not emit a value until
|
||||
* a user becomes active. If you are not confident a user is active at the time this method is called, you may want to pipe a call to `timeout`
|
||||
* or instead call {@link getUserStateOrDefault$} and supply a value you would rather have given in the case of no passed in userId and no active user.
|
||||
*
|
||||
* @param keyDefinition - The key definition for the state you want to get.
|
||||
* @param userId - The userId for which you want the state for. If not provided, the state for the currently active user will be returned.
|
||||
*/
|
||||
abstract getUserState$<T>(keyDefinition: UserKeyDefinition<T>, userId?: UserId): Observable<T>;
|
||||
|
||||
/**
|
||||
* Gets a state observable for a given key and userId
|
||||
*
|
||||
* @remarks If userId is falsy the observable return will first attempt to point to the currently active user but will not follow subsequent active user changes,
|
||||
* if there is no immediately available active user, then it will fallback to returning a default value in an observable that immediately completes.
|
||||
*
|
||||
* @param keyDefinition - The key definition for the state you want to get.
|
||||
* @param config.userId - The userId for which you want the state for. If not provided, the state for the currently active user will be returned.
|
||||
* @param config.defaultValue - The default value that should be wrapped in an observable if no active user is immediately available and no truthy userId is passed in.
|
||||
*/
|
||||
abstract getUserStateOrDefault$<T>(
|
||||
keyDefinition: UserKeyDefinition<T>,
|
||||
config: { userId: UserId | undefined; defaultValue?: T },
|
||||
): Observable<T>;
|
||||
|
||||
/**
|
||||
* Sets the state for a given key and userId.
|
||||
*
|
||||
* @overload
|
||||
* @param keyDefinition - The key definition for the state you want to set.
|
||||
* @param value - The value to set the state to.
|
||||
* @param userId - The userId for which you want to set the state for. If not provided, the state for the currently active user will be set.
|
||||
*/
|
||||
abstract setUserState<T>(
|
||||
keyDefinition: UserKeyDefinition<T>,
|
||||
value: T | null,
|
||||
userId?: UserId,
|
||||
): Promise<[UserId, T | null]>;
|
||||
|
||||
/** @see{@link ActiveUserStateProvider.get} */
|
||||
abstract getActive<T>(userKeyDefinition: UserKeyDefinition<T>): ActiveUserState<T>;
|
||||
|
||||
/** @see{@link SingleUserStateProvider.get} */
|
||||
abstract getUser<T>(userId: UserId, userKeyDefinition: UserKeyDefinition<T>): SingleUserState<T>;
|
||||
|
||||
/** @see{@link GlobalStateProvider.get} */
|
||||
abstract getGlobal<T>(keyDefinition: KeyDefinition<T>): GlobalState<T>;
|
||||
abstract getDerived<TFrom, TTo, TDeps extends DerivedStateDependencies>(
|
||||
parentState$: Observable<TFrom>,
|
||||
deriveDefinition: DeriveDefinition<TFrom, TTo, TDeps>,
|
||||
dependencies: TDeps,
|
||||
): DerivedState<TTo>;
|
||||
}
|
||||
export { StateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,142 +1 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { UserId } from "../../types/guid";
|
||||
import { StorageKey } from "../../types/state";
|
||||
import { Utils } from "../misc/utils";
|
||||
|
||||
import { array, record } from "./deserialization-helpers";
|
||||
import { DebugOptions, KeyDefinitionOptions } from "./key-definition";
|
||||
import { StateDefinition } from "./state-definition";
|
||||
|
||||
export type ClearEvent = "lock" | "logout";
|
||||
|
||||
export type UserKeyDefinitionOptions<T> = KeyDefinitionOptions<T> & {
|
||||
clearOn: ClearEvent[];
|
||||
};
|
||||
|
||||
const USER_KEY_DEFINITION_MARKER: unique symbol = Symbol("UserKeyDefinition");
|
||||
|
||||
export class UserKeyDefinition<T> {
|
||||
readonly [USER_KEY_DEFINITION_MARKER] = true;
|
||||
/**
|
||||
* A unique array of events that the state stored at this key should be cleared on.
|
||||
*/
|
||||
readonly clearOn: ClearEvent[];
|
||||
|
||||
/**
|
||||
* Normalized options used for debugging purposes.
|
||||
*/
|
||||
readonly debug: Required<DebugOptions>;
|
||||
|
||||
constructor(
|
||||
readonly stateDefinition: StateDefinition,
|
||||
readonly key: string,
|
||||
private readonly options: UserKeyDefinitionOptions<T>,
|
||||
) {
|
||||
if (options.deserializer == null) {
|
||||
throw new Error(`'deserializer' is a required property on key ${this.errorKeyName}`);
|
||||
}
|
||||
|
||||
if (options.cleanupDelayMs < 0) {
|
||||
throw new Error(
|
||||
`'cleanupDelayMs' must be greater than or equal to 0. Value of ${options.cleanupDelayMs} passed to key ${this.errorKeyName} `,
|
||||
);
|
||||
}
|
||||
|
||||
// Filter out repeat values
|
||||
this.clearOn = Array.from(new Set(options.clearOn));
|
||||
|
||||
// Normalize optional debug options
|
||||
const { enableUpdateLogging = false, enableRetrievalLogging = false } = options.debug ?? {};
|
||||
this.debug = {
|
||||
enableUpdateLogging,
|
||||
enableRetrievalLogging,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the deserializer configured for this {@link KeyDefinition}
|
||||
*/
|
||||
get deserializer() {
|
||||
return this.options.deserializer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
|
||||
*/
|
||||
get cleanupDelayMs() {
|
||||
return this.options.cleanupDelayMs < 0 ? 0 : (this.options.cleanupDelayMs ?? 1000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link UserKeyDefinition} for state that is an array.
|
||||
* @param stateDefinition The state definition to be added to the UserKeyDefinition
|
||||
* @param key The key to be added to the KeyDefinition
|
||||
* @param options The options to customize the final {@link UserKeyDefinition}.
|
||||
* @returns A {@link UserKeyDefinition} initialized for arrays, the options run
|
||||
* the deserializer on the provided options for each element of an array
|
||||
* **unless that array is null, in which case it will return an empty list.**
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const MY_KEY = UserKeyDefinition.array<MyArrayElement>(MY_STATE, "key", {
|
||||
* deserializer: (myJsonElement) => convertToElement(myJsonElement),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
static array<T>(
|
||||
stateDefinition: StateDefinition,
|
||||
key: string,
|
||||
// We have them provide options for the element of the array, depending on future options we add, this could get a little weird.
|
||||
options: UserKeyDefinitionOptions<T>,
|
||||
) {
|
||||
return new UserKeyDefinition<T[]>(stateDefinition, key, {
|
||||
...options,
|
||||
deserializer: array((e) => options.deserializer(e)),
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link UserKeyDefinition} for state that is a record.
|
||||
* @param stateDefinition The state definition to be added to the UserKeyDefinition
|
||||
* @param key The key to be added to the KeyDefinition
|
||||
* @param options The options to customize the final {@link UserKeyDefinition}.
|
||||
* @returns A {@link UserKeyDefinition} that contains a serializer that will run the provided deserializer for each
|
||||
* value in a record and returns every key as a string **unless that record is null, in which case it will return an record.**
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const MY_KEY = UserKeyDefinition.record<MyRecordValue>(MY_STATE, "key", {
|
||||
* deserializer: (myJsonValue) => convertToValue(myJsonValue),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
static record<T, TKey extends string | number = string>(
|
||||
stateDefinition: StateDefinition,
|
||||
key: string,
|
||||
// We have them provide options for the value of the record, depending on future options we add, this could get a little weird.
|
||||
options: UserKeyDefinitionOptions<T>, // The array helper forces an initialValue of an empty record
|
||||
) {
|
||||
return new UserKeyDefinition<Record<TKey, T>>(stateDefinition, key, {
|
||||
...options,
|
||||
deserializer: record((v) => options.deserializer(v)),
|
||||
});
|
||||
}
|
||||
|
||||
get fullName() {
|
||||
return `${this.stateDefinition.name}_${this.key}`;
|
||||
}
|
||||
|
||||
buildKey(userId: UserId) {
|
||||
if (!Utils.isGuid(userId)) {
|
||||
throw new Error(
|
||||
`You cannot build a user key without a valid UserId, building for key ${this.fullName}`,
|
||||
);
|
||||
}
|
||||
return `user_${userId}_${this.stateDefinition.name}_${this.key}` as StorageKey;
|
||||
}
|
||||
|
||||
private get errorKeyName() {
|
||||
return `${this.stateDefinition.name} > ${this.key}`;
|
||||
}
|
||||
}
|
||||
export { UserKeyDefinition, UserKeyDefinitionOptions } from "@bitwarden/state";
|
||||
|
||||
@@ -1,35 +1 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
import { UserId } from "../../types/guid";
|
||||
|
||||
import { UserKeyDefinition } from "./user-key-definition";
|
||||
import { ActiveUserState, SingleUserState } from "./user-state";
|
||||
|
||||
/** A provider for getting an implementation of state scoped to a given key and userId */
|
||||
export abstract class SingleUserStateProvider {
|
||||
/**
|
||||
* Gets a {@link SingleUserState} scoped to the given {@link UserKeyDefinition} and {@link UserId}
|
||||
*
|
||||
* @param userId - The {@link UserId} for which you want the user state for.
|
||||
* @param userKeyDefinition - The {@link UserKeyDefinition} for which you want the user state for.
|
||||
*/
|
||||
abstract get<T>(userId: UserId, userKeyDefinition: UserKeyDefinition<T>): SingleUserState<T>;
|
||||
}
|
||||
|
||||
/** A provider for getting an implementation of state scoped to a given key, but always pointing
|
||||
* to the currently active user
|
||||
*/
|
||||
export abstract class ActiveUserStateProvider {
|
||||
/**
|
||||
* Convenience re-emission of active user ID from {@link AccountService.activeAccount$}
|
||||
*/
|
||||
abstract activeUserId$: Observable<UserId | undefined>;
|
||||
|
||||
/**
|
||||
* Gets a {@link ActiveUserState} scoped to the given {@link KeyDefinition}, but updates when active user changes such
|
||||
* that the emitted values always represents the state for the currently active user.
|
||||
*
|
||||
* @param keyDefinition - The {@link UserKeyDefinition} for which you want the user state for.
|
||||
*/
|
||||
abstract get<T>(userKeyDefinition: UserKeyDefinition<T>): ActiveUserState<T>;
|
||||
}
|
||||
export { ActiveUserStateProvider, SingleUserStateProvider } from "@bitwarden/state";
|
||||
|
||||
@@ -1,64 +1 @@
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
import { UserId } from "../../types/guid";
|
||||
|
||||
import { StateUpdateOptions } from "./state-update-options";
|
||||
|
||||
export type CombinedState<T> = readonly [userId: UserId, state: T];
|
||||
|
||||
/** A helper object for interacting with state that is scoped to a specific user. */
|
||||
export interface UserState<T> {
|
||||
/** Emits a stream of data. Emits null if the user does not have specified state. */
|
||||
readonly state$: Observable<T | null>;
|
||||
|
||||
/** Emits a stream of tuples, with the first element being a user id and the second element being the data for that user. */
|
||||
readonly combinedState$: Observable<CombinedState<T | null>>;
|
||||
}
|
||||
|
||||
export const activeMarker: unique symbol = Symbol("active");
|
||||
|
||||
export interface ActiveUserState<T> extends UserState<T> {
|
||||
readonly [activeMarker]: true;
|
||||
|
||||
/**
|
||||
* Emits a stream of data. Emits null if the user does not have specified state.
|
||||
* Note: Will not emit if there is no active user.
|
||||
*/
|
||||
readonly state$: Observable<T | null>;
|
||||
|
||||
/**
|
||||
* Updates backing stores for the active user.
|
||||
* @param configureState function that takes the current state and returns the new state
|
||||
* @param options Defaults to @see {module:state-update-options#DEFAULT_OPTIONS}
|
||||
* @param options.shouldUpdate A callback for determining if you want to update state. Defaults to () => true
|
||||
* @param options.combineLatestWith An observable that you want to combine with the current state for callbacks. Defaults to null
|
||||
* @param options.msTimeout A timeout for how long you are willing to wait for a `combineLatestWith` option to complete. Defaults to 1000ms. Only applies if `combineLatestWith` is set.
|
||||
*
|
||||
* @returns A promise that must be awaited before your next action to ensure the update has been written to state.
|
||||
* Resolves to the new state. If `shouldUpdate` returns false, the promise will resolve to the current state.
|
||||
*/
|
||||
readonly update: <TCombine>(
|
||||
configureState: (state: T | null, dependencies: TCombine) => T | null,
|
||||
options?: StateUpdateOptions<T, TCombine>,
|
||||
) => Promise<[UserId, T | null]>;
|
||||
}
|
||||
|
||||
export interface SingleUserState<T> extends UserState<T> {
|
||||
readonly userId: UserId;
|
||||
|
||||
/**
|
||||
* Updates backing stores for the active user.
|
||||
* @param configureState function that takes the current state and returns the new state
|
||||
* @param options Defaults to @see {module:state-update-options#DEFAULT_OPTIONS}
|
||||
* @param options.shouldUpdate A callback for determining if you want to update state. Defaults to () => true
|
||||
* @param options.combineLatestWith An observable that you want to combine with the current state for callbacks. Defaults to null
|
||||
* @param options.msTimeout A timeout for how long you are willing to wait for a `combineLatestWith` option to complete. Defaults to 1000ms. Only applies if `combineLatestWith` is set.
|
||||
*
|
||||
* @returns A promise that must be awaited before your next action to ensure the update has been written to state.
|
||||
* Resolves to the new state. If `shouldUpdate` returns false, the promise will resolve to the current state.
|
||||
*/
|
||||
readonly update: <TCombine>(
|
||||
configureState: (state: T | null, dependencies: TCombine) => T | null,
|
||||
options?: StateUpdateOptions<T, TCombine>,
|
||||
) => Promise<T | null>;
|
||||
}
|
||||
export { ActiveUserState, SingleUserState, CombinedState } from "@bitwarden/state";
|
||||
|
||||
@@ -172,7 +172,11 @@ export abstract class CoreSyncService implements SyncService {
|
||||
notification.collectionIds != null &&
|
||||
notification.collectionIds.length > 0
|
||||
) {
|
||||
const collections = await this.collectionService.getAll();
|
||||
const collections = await firstValueFrom(
|
||||
this.collectionService
|
||||
.encryptedCollections$(userId)
|
||||
.pipe(map((collections) => collections ?? [])),
|
||||
);
|
||||
if (collections != null) {
|
||||
for (let i = 0; i < collections.length; i++) {
|
||||
if (notification.collectionIds.indexOf(collections[i].id) > -1) {
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
export { createMigrationBuilder, waitForMigrations, CURRENT_VERSION } from "./migrate";
|
||||
// Compatibility re-export for @bitwarden/common/state-migrations
|
||||
export * from "@bitwarden/state";
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import { mock, MockProxy } from "jest-mock-extended";
|
||||
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
|
||||
import { LogService } from "../platform/abstractions/log.service";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
|
||||
import { AbstractStorageService } from "../platform/abstractions/storage.service";
|
||||
|
||||
import { currentVersion } from "./migrate";
|
||||
|
||||
describe("currentVersion", () => {
|
||||
let storage: MockProxy<AbstractStorageService>;
|
||||
let logService: MockProxy<LogService>;
|
||||
|
||||
beforeEach(() => {
|
||||
storage = mock();
|
||||
logService = mock();
|
||||
});
|
||||
|
||||
it("should return -1 if no version", async () => {
|
||||
storage.get.mockReturnValueOnce(null);
|
||||
expect(await currentVersion(storage, logService)).toEqual(-1);
|
||||
});
|
||||
|
||||
it("should return version", async () => {
|
||||
storage.get.calledWith("stateVersion").mockReturnValueOnce(1 as any);
|
||||
expect(await currentVersion(storage, logService)).toEqual(1);
|
||||
});
|
||||
|
||||
it("should return version from global", async () => {
|
||||
storage.get.calledWith("stateVersion").mockReturnValueOnce(null);
|
||||
storage.get.calledWith("global").mockReturnValueOnce({ stateVersion: 1 } as any);
|
||||
expect(await currentVersion(storage, logService)).toEqual(1);
|
||||
});
|
||||
|
||||
it("should prefer root version to global", async () => {
|
||||
storage.get.calledWith("stateVersion").mockReturnValue(1 as any);
|
||||
storage.get.calledWith("global").mockReturnValue({ stateVersion: 2 } as any);
|
||||
expect(await currentVersion(storage, logService)).toEqual(1);
|
||||
});
|
||||
});
|
||||
@@ -1,218 +0,0 @@
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
|
||||
import { LogService } from "../platform/abstractions/log.service";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
|
||||
import { AbstractStorageService } from "../platform/abstractions/storage.service";
|
||||
|
||||
import { MigrationBuilder } from "./migration-builder";
|
||||
import { EverHadUserKeyMigrator } from "./migrations/10-move-ever-had-user-key-to-state-providers";
|
||||
import { OrganizationKeyMigrator } from "./migrations/11-move-org-keys-to-state-providers";
|
||||
import { MoveEnvironmentStateToProviders } from "./migrations/12-move-environment-state-to-providers";
|
||||
import { ProviderKeyMigrator } from "./migrations/13-move-provider-keys-to-state-providers";
|
||||
import { MoveBiometricClientKeyHalfToStateProviders } from "./migrations/14-move-biometric-client-key-half-state-to-providers";
|
||||
import { FolderMigrator } from "./migrations/15-move-folder-state-to-state-provider";
|
||||
import { LastSyncMigrator } from "./migrations/16-move-last-sync-to-state-provider";
|
||||
import { EnablePasskeysMigrator } from "./migrations/17-move-enable-passkeys-to-state-providers";
|
||||
import { AutofillSettingsKeyMigrator } from "./migrations/18-move-autofill-settings-to-state-providers";
|
||||
import { RequirePasswordOnStartMigrator } from "./migrations/19-migrate-require-password-on-start";
|
||||
import { PrivateKeyMigrator } from "./migrations/20-move-private-key-to-state-providers";
|
||||
import { CollectionMigrator } from "./migrations/21-move-collections-state-to-state-provider";
|
||||
import { CollapsedGroupingsMigrator } from "./migrations/22-move-collapsed-groupings-to-state-provider";
|
||||
import { MoveBiometricPromptsToStateProviders } from "./migrations/23-move-biometric-prompts-to-state-providers";
|
||||
import { SmOnboardingTasksMigrator } from "./migrations/24-move-sm-onboarding-key-to-state-providers";
|
||||
import { ClearClipboardDelayMigrator } from "./migrations/25-move-clear-clipboard-to-autofill-settings-state-provider";
|
||||
import { RevertLastSyncMigrator } from "./migrations/26-revert-move-last-sync-to-state-provider";
|
||||
import { BadgeSettingsMigrator } from "./migrations/27-move-badge-settings-to-state-providers";
|
||||
import { MoveBiometricUnlockToStateProviders } from "./migrations/28-move-biometric-unlock-to-state-providers";
|
||||
import { UserNotificationSettingsKeyMigrator } from "./migrations/29-move-user-notification-settings-to-state-provider";
|
||||
import { PolicyMigrator } from "./migrations/30-move-policy-state-to-state-provider";
|
||||
import { EnableContextMenuMigrator } from "./migrations/31-move-enable-context-menu-to-autofill-settings-state-provider";
|
||||
import { PreferredLanguageMigrator } from "./migrations/32-move-preferred-language";
|
||||
import { AppIdMigrator } from "./migrations/33-move-app-id-to-state-providers";
|
||||
import { DomainSettingsMigrator } from "./migrations/34-move-domain-settings-to-state-providers";
|
||||
import { MoveThemeToStateProviderMigrator } from "./migrations/35-move-theme-to-state-providers";
|
||||
import { VaultSettingsKeyMigrator } from "./migrations/36-move-show-card-and-identity-to-state-provider";
|
||||
import { AvatarColorMigrator } from "./migrations/37-move-avatar-color-to-state-providers";
|
||||
import { TokenServiceStateProviderMigrator } from "./migrations/38-migrate-token-svc-to-state-provider";
|
||||
import { MoveBillingAccountProfileMigrator } from "./migrations/39-move-billing-account-profile-to-state-providers";
|
||||
import { RemoveEverBeenUnlockedMigrator } from "./migrations/4-remove-ever-been-unlocked";
|
||||
import { OrganizationMigrator } from "./migrations/40-move-organization-state-to-state-provider";
|
||||
import { EventCollectionMigrator } from "./migrations/41-move-event-collection-to-state-provider";
|
||||
import { EnableFaviconMigrator } from "./migrations/42-move-enable-favicon-to-domain-settings-state-provider";
|
||||
import { AutoConfirmFingerPrintsMigrator } from "./migrations/43-move-auto-confirm-finger-prints-to-state-provider";
|
||||
import { UserDecryptionOptionsMigrator } from "./migrations/44-move-user-decryption-options-to-state-provider";
|
||||
import { MergeEnvironmentState } from "./migrations/45-merge-environment-state";
|
||||
import { DeleteBiometricPromptCancelledData } from "./migrations/46-delete-orphaned-biometric-prompt-data";
|
||||
import { MoveDesktopSettingsMigrator } from "./migrations/47-move-desktop-settings";
|
||||
import { MoveDdgToStateProviderMigrator } from "./migrations/48-move-ddg-to-state-provider";
|
||||
import { AccountServerConfigMigrator } from "./migrations/49-move-account-server-configs";
|
||||
import { AddKeyTypeToOrgKeysMigrator } from "./migrations/5-add-key-type-to-org-keys";
|
||||
import { KeyConnectorMigrator } from "./migrations/50-move-key-connector-to-state-provider";
|
||||
import { RememberedEmailMigrator } from "./migrations/51-move-remembered-email-to-state-providers";
|
||||
import { DeleteInstalledVersion } from "./migrations/52-delete-installed-version";
|
||||
import { DeviceTrustServiceStateProviderMigrator } from "./migrations/53-migrate-device-trust-svc-to-state-providers";
|
||||
import { SendMigrator } from "./migrations/54-move-encrypted-sends";
|
||||
import { MoveMasterKeyStateToProviderMigrator } from "./migrations/55-move-master-key-state-to-provider";
|
||||
import { AuthRequestMigrator } from "./migrations/56-move-auth-requests";
|
||||
import { CipherServiceMigrator } from "./migrations/57-move-cipher-service-to-state-provider";
|
||||
import { RemoveRefreshTokenMigratedFlagMigrator } from "./migrations/58-remove-refresh-token-migrated-state-provider-flag";
|
||||
import { KdfConfigMigrator } from "./migrations/59-move-kdf-config-to-state-provider";
|
||||
import { RemoveLegacyEtmKeyMigrator } from "./migrations/6-remove-legacy-etm-key";
|
||||
import { KnownAccountsMigrator } from "./migrations/60-known-accounts";
|
||||
import { PinStateMigrator } from "./migrations/61-move-pin-state-to-providers";
|
||||
import { VaultTimeoutSettingsServiceStateProviderMigrator } from "./migrations/62-migrate-vault-timeout-settings-svc-to-state-provider";
|
||||
import { PasswordOptionsMigrator } from "./migrations/63-migrate-password-settings";
|
||||
import { GeneratorHistoryMigrator } from "./migrations/64-migrate-generator-history";
|
||||
import { ForwarderOptionsMigrator } from "./migrations/65-migrate-forwarder-settings";
|
||||
import { MoveFinalDesktopSettingsMigrator } from "./migrations/66-move-final-desktop-settings";
|
||||
import { RemoveUnassignedItemsBannerDismissed } from "./migrations/67-remove-unassigned-items-banner-dismissed";
|
||||
import { MoveLastSyncDate } from "./migrations/68-move-last-sync-date";
|
||||
import { MigrateIncorrectFolderKey } from "./migrations/69-migrate-incorrect-folder-key";
|
||||
import { MoveBiometricAutoPromptToAccount } from "./migrations/7-move-biometric-auto-prompt-to-account";
|
||||
import { RemoveAcBannersDismissed } from "./migrations/70-remove-ac-banner-dismissed";
|
||||
import { RemoveNewCustomizationOptionsCalloutDismissed } from "./migrations/71-remove-new-customization-options-callout-dismissed";
|
||||
import { RemoveAccountDeprovisioningBannerDismissed } from "./migrations/72-remove-account-deprovisioning-banner-dismissed";
|
||||
import { MoveStateVersionMigrator } from "./migrations/8-move-state-version";
|
||||
import { MoveBrowserSettingsToGlobal } from "./migrations/9-move-browser-settings-to-global";
|
||||
import { MinVersionMigrator } from "./migrations/min-version";
|
||||
|
||||
export const MIN_VERSION = 3;
|
||||
export const CURRENT_VERSION = 72;
|
||||
export type MinVersion = typeof MIN_VERSION;
|
||||
|
||||
export function createMigrationBuilder() {
|
||||
return MigrationBuilder.create()
|
||||
.with(MinVersionMigrator)
|
||||
.with(RemoveEverBeenUnlockedMigrator, 3, 4)
|
||||
.with(AddKeyTypeToOrgKeysMigrator, 4, 5)
|
||||
.with(RemoveLegacyEtmKeyMigrator, 5, 6)
|
||||
.with(MoveBiometricAutoPromptToAccount, 6, 7)
|
||||
.with(MoveStateVersionMigrator, 7, 8)
|
||||
.with(MoveBrowserSettingsToGlobal, 8, 9)
|
||||
.with(EverHadUserKeyMigrator, 9, 10)
|
||||
.with(OrganizationKeyMigrator, 10, 11)
|
||||
.with(MoveEnvironmentStateToProviders, 11, 12)
|
||||
.with(ProviderKeyMigrator, 12, 13)
|
||||
.with(MoveBiometricClientKeyHalfToStateProviders, 13, 14)
|
||||
.with(FolderMigrator, 14, 15)
|
||||
.with(LastSyncMigrator, 15, 16)
|
||||
.with(EnablePasskeysMigrator, 16, 17)
|
||||
.with(AutofillSettingsKeyMigrator, 17, 18)
|
||||
.with(RequirePasswordOnStartMigrator, 18, 19)
|
||||
.with(PrivateKeyMigrator, 19, 20)
|
||||
.with(CollectionMigrator, 20, 21)
|
||||
.with(CollapsedGroupingsMigrator, 21, 22)
|
||||
.with(MoveBiometricPromptsToStateProviders, 22, 23)
|
||||
.with(SmOnboardingTasksMigrator, 23, 24)
|
||||
.with(ClearClipboardDelayMigrator, 24, 25)
|
||||
.with(RevertLastSyncMigrator, 25, 26)
|
||||
.with(BadgeSettingsMigrator, 26, 27)
|
||||
.with(MoveBiometricUnlockToStateProviders, 27, 28)
|
||||
.with(UserNotificationSettingsKeyMigrator, 28, 29)
|
||||
.with(PolicyMigrator, 29, 30)
|
||||
.with(EnableContextMenuMigrator, 30, 31)
|
||||
.with(PreferredLanguageMigrator, 31, 32)
|
||||
.with(AppIdMigrator, 32, 33)
|
||||
.with(DomainSettingsMigrator, 33, 34)
|
||||
.with(MoveThemeToStateProviderMigrator, 34, 35)
|
||||
.with(VaultSettingsKeyMigrator, 35, 36)
|
||||
.with(AvatarColorMigrator, 36, 37)
|
||||
.with(TokenServiceStateProviderMigrator, 37, 38)
|
||||
.with(MoveBillingAccountProfileMigrator, 38, 39)
|
||||
.with(OrganizationMigrator, 39, 40)
|
||||
.with(EventCollectionMigrator, 40, 41)
|
||||
.with(EnableFaviconMigrator, 41, 42)
|
||||
.with(AutoConfirmFingerPrintsMigrator, 42, 43)
|
||||
.with(UserDecryptionOptionsMigrator, 43, 44)
|
||||
.with(MergeEnvironmentState, 44, 45)
|
||||
.with(DeleteBiometricPromptCancelledData, 45, 46)
|
||||
.with(MoveDesktopSettingsMigrator, 46, 47)
|
||||
.with(MoveDdgToStateProviderMigrator, 47, 48)
|
||||
.with(AccountServerConfigMigrator, 48, 49)
|
||||
.with(KeyConnectorMigrator, 49, 50)
|
||||
.with(RememberedEmailMigrator, 50, 51)
|
||||
.with(DeleteInstalledVersion, 51, 52)
|
||||
.with(DeviceTrustServiceStateProviderMigrator, 52, 53)
|
||||
.with(SendMigrator, 53, 54)
|
||||
.with(MoveMasterKeyStateToProviderMigrator, 54, 55)
|
||||
.with(AuthRequestMigrator, 55, 56)
|
||||
.with(CipherServiceMigrator, 56, 57)
|
||||
.with(RemoveRefreshTokenMigratedFlagMigrator, 57, 58)
|
||||
.with(KdfConfigMigrator, 58, 59)
|
||||
.with(KnownAccountsMigrator, 59, 60)
|
||||
.with(PinStateMigrator, 60, 61)
|
||||
.with(VaultTimeoutSettingsServiceStateProviderMigrator, 61, 62)
|
||||
.with(PasswordOptionsMigrator, 62, 63)
|
||||
.with(GeneratorHistoryMigrator, 63, 64)
|
||||
.with(ForwarderOptionsMigrator, 64, 65)
|
||||
.with(MoveFinalDesktopSettingsMigrator, 65, 66)
|
||||
.with(RemoveUnassignedItemsBannerDismissed, 66, 67)
|
||||
.with(MoveLastSyncDate, 67, 68)
|
||||
.with(MigrateIncorrectFolderKey, 68, 69)
|
||||
.with(RemoveAcBannersDismissed, 69, 70)
|
||||
.with(RemoveNewCustomizationOptionsCalloutDismissed, 70, 71)
|
||||
.with(RemoveAccountDeprovisioningBannerDismissed, 71, CURRENT_VERSION);
|
||||
}
|
||||
|
||||
export async function currentVersion(
|
||||
storageService: AbstractStorageService,
|
||||
logService: LogService,
|
||||
) {
|
||||
let state = await storageService.get<number>("stateVersion");
|
||||
if (state == null) {
|
||||
// Pre v8
|
||||
state = (await storageService.get<{ stateVersion: number }>("global"))?.stateVersion;
|
||||
}
|
||||
if (state == null) {
|
||||
logService.info("No state version found, assuming empty state.");
|
||||
return -1;
|
||||
}
|
||||
logService.info(`State version: ${state}`);
|
||||
return state;
|
||||
}
|
||||
|
||||
/**
|
||||
* Waits for migrations to have a chance to run and will resolve the promise once they are.
|
||||
*
|
||||
* @param storageService Disk storage where the `stateVersion` will or is already saved in.
|
||||
* @param logService Log service
|
||||
*/
|
||||
export async function waitForMigrations(
|
||||
storageService: AbstractStorageService,
|
||||
logService: LogService,
|
||||
) {
|
||||
const isReady = async () => {
|
||||
const version = await currentVersion(storageService, logService);
|
||||
// The saved version is what we consider the latest
|
||||
// migrations should be complete, the state version
|
||||
// shouldn't become larger than `CURRENT_VERSION` in
|
||||
// any normal usage of the application but it is common
|
||||
// enough in dev scenarios where we want to consider that
|
||||
// ready as well and return true in that scenario.
|
||||
return version >= CURRENT_VERSION;
|
||||
};
|
||||
|
||||
const wait = async (time: number) => {
|
||||
// Wait exponentially
|
||||
const nextTime = time * 2;
|
||||
if (nextTime > 8192) {
|
||||
// Don't wait longer than ~8 seconds in a single wait,
|
||||
// if the migrations still haven't happened. They aren't
|
||||
// likely to.
|
||||
return;
|
||||
}
|
||||
return new Promise<void>((resolve) => {
|
||||
setTimeout(async () => {
|
||||
if (!(await isReady())) {
|
||||
logService.info(`Waiting for migrations to finish, waiting for ${nextTime}ms`);
|
||||
await wait(nextTime);
|
||||
}
|
||||
resolve();
|
||||
}, time);
|
||||
});
|
||||
};
|
||||
|
||||
if (!(await isReady())) {
|
||||
// Wait for 2ms to start with
|
||||
await wait(2);
|
||||
}
|
||||
}
|
||||
@@ -1,143 +0,0 @@
|
||||
import { mock } from "jest-mock-extended";
|
||||
|
||||
// eslint-disable-next-line import/no-restricted-paths
|
||||
import { ClientType } from "../enums";
|
||||
|
||||
import { MigrationBuilder } from "./migration-builder";
|
||||
import { MigrationHelper } from "./migration-helper";
|
||||
import { Migrator } from "./migrator";
|
||||
|
||||
describe("MigrationBuilder", () => {
|
||||
class TestMigrator extends Migrator<0, 1> {
|
||||
async migrate(helper: MigrationHelper): Promise<void> {
|
||||
await helper.set("test", "test");
|
||||
return;
|
||||
}
|
||||
|
||||
async rollback(helper: MigrationHelper): Promise<void> {
|
||||
await helper.set("test", "rollback");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
class TestMigratorWithInstanceMethod extends Migrator<0, 1> {
|
||||
private async instanceMethod(helper: MigrationHelper, value: string) {
|
||||
await helper.set("test", value);
|
||||
}
|
||||
|
||||
async migrate(helper: MigrationHelper): Promise<void> {
|
||||
await this.instanceMethod(helper, "migrate");
|
||||
}
|
||||
|
||||
async rollback(helper: MigrationHelper): Promise<void> {
|
||||
await this.instanceMethod(helper, "rollback");
|
||||
}
|
||||
}
|
||||
|
||||
let sut: MigrationBuilder<number>;
|
||||
|
||||
beforeEach(() => {
|
||||
sut = MigrationBuilder.create();
|
||||
});
|
||||
|
||||
class TestBadMigrator extends Migrator<1, 0> {
|
||||
async migrate(helper: MigrationHelper): Promise<void> {
|
||||
await helper.set("test", "test");
|
||||
}
|
||||
|
||||
async rollback(helper: MigrationHelper): Promise<void> {
|
||||
await helper.set("test", "rollback");
|
||||
}
|
||||
}
|
||||
|
||||
it("should throw if instantiated incorrectly", () => {
|
||||
expect(() => MigrationBuilder.create().with(TestMigrator, null, null)).toThrow();
|
||||
expect(() =>
|
||||
MigrationBuilder.create().with(TestMigrator, 0, 1).with(TestBadMigrator, 1, 0),
|
||||
).toThrow();
|
||||
});
|
||||
|
||||
it("should be able to create a new MigrationBuilder", () => {
|
||||
expect(sut).toBeInstanceOf(MigrationBuilder);
|
||||
});
|
||||
|
||||
it("should be able to add a migrator", () => {
|
||||
const newBuilder = sut.with(TestMigrator, 0, 1);
|
||||
const migrations = newBuilder["migrations"];
|
||||
expect(migrations.length).toBe(1);
|
||||
expect(migrations[0]).toMatchObject({ migrator: expect.any(TestMigrator), direction: "up" });
|
||||
});
|
||||
|
||||
it("should be able to add a rollback", () => {
|
||||
const newBuilder = sut.with(TestMigrator, 0, 1).rollback(TestMigrator, 1, 0);
|
||||
const migrations = newBuilder["migrations"];
|
||||
expect(migrations.length).toBe(2);
|
||||
expect(migrations[1]).toMatchObject({ migrator: expect.any(TestMigrator), direction: "down" });
|
||||
});
|
||||
|
||||
const clientTypes = Object.values(ClientType);
|
||||
|
||||
describe.each(clientTypes)("for client %s", (clientType) => {
|
||||
describe("migrate", () => {
|
||||
let migrator: TestMigrator;
|
||||
let rollback_migrator: TestMigrator;
|
||||
|
||||
beforeEach(() => {
|
||||
sut = sut.with(TestMigrator, 0, 1).rollback(TestMigrator, 1, 0);
|
||||
migrator = (sut as any).migrations[0].migrator;
|
||||
rollback_migrator = (sut as any).migrations[1].migrator;
|
||||
});
|
||||
|
||||
it("should migrate", async () => {
|
||||
const helper = new MigrationHelper(0, mock(), mock(), "general", clientType);
|
||||
const spy = jest.spyOn(migrator, "migrate");
|
||||
await sut.migrate(helper);
|
||||
expect(spy).toBeCalledWith(helper);
|
||||
});
|
||||
|
||||
it("should rollback", async () => {
|
||||
const helper = new MigrationHelper(1, mock(), mock(), "general", clientType);
|
||||
const spy = jest.spyOn(rollback_migrator, "rollback");
|
||||
await sut.migrate(helper);
|
||||
expect(spy).toBeCalledWith(helper);
|
||||
});
|
||||
|
||||
it("should update version on migrate", async () => {
|
||||
const helper = new MigrationHelper(0, mock(), mock(), "general", clientType);
|
||||
const spy = jest.spyOn(migrator, "updateVersion");
|
||||
await sut.migrate(helper);
|
||||
expect(spy).toBeCalledWith(helper, "up");
|
||||
});
|
||||
|
||||
it("should update version on rollback", async () => {
|
||||
const helper = new MigrationHelper(1, mock(), mock(), "general", clientType);
|
||||
const spy = jest.spyOn(rollback_migrator, "updateVersion");
|
||||
await sut.migrate(helper);
|
||||
expect(spy).toBeCalledWith(helper, "down");
|
||||
});
|
||||
|
||||
it("should not run the migrator if the current version does not match the from version", async () => {
|
||||
const helper = new MigrationHelper(3, mock(), mock(), "general", clientType);
|
||||
const migrate = jest.spyOn(migrator, "migrate");
|
||||
const rollback = jest.spyOn(rollback_migrator, "rollback");
|
||||
await sut.migrate(helper);
|
||||
expect(migrate).not.toBeCalled();
|
||||
expect(rollback).not.toBeCalled();
|
||||
});
|
||||
|
||||
it("should not update version if the current version does not match the from version", async () => {
|
||||
const helper = new MigrationHelper(3, mock(), mock(), "general", clientType);
|
||||
const migrate = jest.spyOn(migrator, "updateVersion");
|
||||
const rollback = jest.spyOn(rollback_migrator, "updateVersion");
|
||||
await sut.migrate(helper);
|
||||
expect(migrate).not.toBeCalled();
|
||||
expect(rollback).not.toBeCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it("should be able to call instance methods", async () => {
|
||||
const helper = new MigrationHelper(0, mock(), mock(), "general", clientType);
|
||||
await sut.with(TestMigratorWithInstanceMethod, 0, 1).migrate(helper);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,106 +1 @@
|
||||
import { MigrationHelper } from "./migration-helper";
|
||||
import { Direction, Migrator, VersionFrom, VersionTo } from "./migrator";
|
||||
|
||||
export class MigrationBuilder<TCurrent extends number = 0> {
|
||||
/** Create a new MigrationBuilder with an empty buffer of migrations to perform.
|
||||
*
|
||||
* Add migrations to the buffer with {@link with} and {@link rollback}.
|
||||
* @returns A new MigrationBuilder.
|
||||
*/
|
||||
static create(): MigrationBuilder<0> {
|
||||
return new MigrationBuilder([]);
|
||||
}
|
||||
|
||||
private constructor(
|
||||
private migrations: readonly { migrator: Migrator<number, number>; direction: Direction }[],
|
||||
) {}
|
||||
|
||||
/** Add a migrator to the MigrationBuilder. Types are updated such that the chained MigrationBuilder must currently be
|
||||
* at state version equal to the from version of the migrator. Return as MigrationBuilder<TTo> where TTo is the to
|
||||
* version of the migrator, so that the next migrator can be chained.
|
||||
*
|
||||
* @param migrate A migrator class or a tuple of a migrator class, the from version, and the to version. A tuple is
|
||||
* required to instantiate version numbers unless a default constructor is defined.
|
||||
* @returns A new MigrationBuilder with the to version of the migrator as the current version.
|
||||
*/
|
||||
with<
|
||||
TMigrator extends Migrator<number, number>,
|
||||
TFrom extends VersionFrom<TMigrator> & TCurrent,
|
||||
TTo extends VersionTo<TMigrator>,
|
||||
>(
|
||||
...migrate: [new () => TMigrator] | [new (from: TFrom, to: TTo) => TMigrator, TFrom, TTo]
|
||||
): MigrationBuilder<TTo> {
|
||||
return this.addMigrator(migrate, "up");
|
||||
}
|
||||
|
||||
/** Add a migrator to rollback on the MigrationBuilder's list of migrations. As with {@link with}, types of
|
||||
* MigrationBuilder and Migrator must align. However, this time the migration is reversed so TCurrent of the
|
||||
* MigrationBuilder must be equal to the to version of the migrator. Return as MigrationBuilder<TFrom> where TFrom
|
||||
* is the from version of the migrator, so that the next migrator can be chained.
|
||||
*
|
||||
* @param migrate A migrator class or a tuple of a migrator class, the from version, and the to version. A tuple is
|
||||
* required to instantiate version numbers unless a default constructor is defined.
|
||||
* @returns A new MigrationBuilder with the from version of the migrator as the current version.
|
||||
*/
|
||||
rollback<
|
||||
TMigrator extends Migrator<number, number>,
|
||||
TFrom extends VersionFrom<TMigrator>,
|
||||
TTo extends VersionTo<TMigrator> & TCurrent,
|
||||
>(
|
||||
...migrate: [new () => TMigrator] | [new (from: TFrom, to: TTo) => TMigrator, TTo, TFrom]
|
||||
): MigrationBuilder<TFrom> {
|
||||
if (migrate.length === 3) {
|
||||
migrate = [migrate[0], migrate[2], migrate[1]];
|
||||
}
|
||||
return this.addMigrator(migrate, "down");
|
||||
}
|
||||
|
||||
/** Execute the migrations as defined in the MigrationBuilder's migrator buffer */
|
||||
migrate(helper: MigrationHelper): Promise<void> {
|
||||
return this.migrations.reduce(
|
||||
(promise, migrator) =>
|
||||
promise.then(async () => {
|
||||
await this.runMigrator(migrator.migrator, helper, migrator.direction);
|
||||
}),
|
||||
Promise.resolve(),
|
||||
);
|
||||
}
|
||||
|
||||
private addMigrator<
|
||||
TMigrator extends Migrator<number, number>,
|
||||
TFrom extends VersionFrom<TMigrator> & TCurrent,
|
||||
TTo extends VersionTo<TMigrator>,
|
||||
>(
|
||||
migrate: [new () => TMigrator] | [new (from: TFrom, to: TTo) => TMigrator, TFrom, TTo],
|
||||
direction: Direction = "up",
|
||||
) {
|
||||
const newMigration =
|
||||
migrate.length === 1
|
||||
? { migrator: new migrate[0](), direction }
|
||||
: { migrator: new migrate[0](migrate[1], migrate[2]), direction };
|
||||
|
||||
return new MigrationBuilder<TTo>([...this.migrations, newMigration]);
|
||||
}
|
||||
|
||||
private async runMigrator(
|
||||
migrator: Migrator<number, number>,
|
||||
helper: MigrationHelper,
|
||||
direction: Direction,
|
||||
): Promise<void> {
|
||||
const shouldMigrate = await migrator.shouldMigrate(helper, direction);
|
||||
helper.info(
|
||||
`Migrator ${migrator.constructor.name} (to version ${migrator.toVersion}) should migrate: ${shouldMigrate} - ${direction}`,
|
||||
);
|
||||
if (shouldMigrate) {
|
||||
const method = direction === "up" ? migrator.migrate : migrator.rollback;
|
||||
await method.bind(migrator)(helper);
|
||||
helper.info(
|
||||
`Migrator ${migrator.constructor.name} (to version ${migrator.toVersion}) migrated - ${direction}`,
|
||||
);
|
||||
await migrator.updateVersion(helper, direction);
|
||||
helper.info(
|
||||
`Migrator ${migrator.constructor.name} (to version ${migrator.toVersion}) updated version - ${direction}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
export { MigrationBuilder } from "@bitwarden/state";
|
||||
|
||||
@@ -1,385 +0,0 @@
|
||||
import { MockProxy, mock } from "jest-mock-extended";
|
||||
|
||||
import { FakeStorageService } from "../../spec/fake-storage.service";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed client type enum
|
||||
import { ClientType } from "../enums";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
|
||||
import { LogService } from "../platform/abstractions/log.service";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
|
||||
import { AbstractStorageService } from "../platform/abstractions/storage.service";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to generate unique strings for injection
|
||||
import { Utils } from "../platform/misc/utils";
|
||||
|
||||
import { MigrationHelper, MigrationHelperType } from "./migration-helper";
|
||||
import { Migrator } from "./migrator";
|
||||
|
||||
const exampleJSON = {
|
||||
authenticatedAccounts: [
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
],
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760": {
|
||||
otherStuff: "otherStuff1",
|
||||
},
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
|
||||
otherStuff: "otherStuff2",
|
||||
},
|
||||
global_serviceName_key: "global_serviceName_key",
|
||||
user_userId_serviceName_key: "user_userId_serviceName_key",
|
||||
global_account_accounts: {
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760": {
|
||||
otherStuff: "otherStuff3",
|
||||
},
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
|
||||
otherStuff: "otherStuff4",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
describe("RemoveLegacyEtmKeyMigrator", () => {
|
||||
let storage: MockProxy<AbstractStorageService>;
|
||||
let logService: MockProxy<LogService>;
|
||||
let sut: MigrationHelper;
|
||||
|
||||
const clientTypes = Object.values(ClientType);
|
||||
|
||||
describe.each(clientTypes)("for client %s", (clientType) => {
|
||||
beforeEach(() => {
|
||||
logService = mock();
|
||||
storage = mock();
|
||||
storage.get.mockImplementation((key) => (exampleJSON as any)[key]);
|
||||
|
||||
sut = new MigrationHelper(0, storage, logService, "general", clientType);
|
||||
});
|
||||
|
||||
describe("get", () => {
|
||||
it("should delegate to storage.get", async () => {
|
||||
await sut.get("key");
|
||||
expect(storage.get).toHaveBeenCalledWith("key");
|
||||
});
|
||||
});
|
||||
|
||||
describe("set", () => {
|
||||
it("should delegate to storage.save", async () => {
|
||||
await sut.set("key", "value");
|
||||
expect(storage.save).toHaveBeenCalledWith("key", "value");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getAccounts", () => {
|
||||
it("should return all accounts", async () => {
|
||||
const accounts = await sut.getAccounts();
|
||||
expect(accounts).toEqual([
|
||||
{
|
||||
userId: "c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
account: { otherStuff: "otherStuff1" },
|
||||
},
|
||||
{
|
||||
userId: "23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
account: { otherStuff: "otherStuff2" },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("should handle missing authenticatedAccounts", async () => {
|
||||
storage.get.mockImplementation((key) =>
|
||||
key === "authenticatedAccounts" ? undefined : (exampleJSON as any)[key],
|
||||
);
|
||||
const accounts = await sut.getAccounts();
|
||||
expect(accounts).toEqual([]);
|
||||
});
|
||||
|
||||
it("handles global scoped known accounts for version 60 and after", async () => {
|
||||
sut.currentVersion = 60;
|
||||
const accounts = await sut.getAccounts();
|
||||
expect(accounts).toEqual([
|
||||
// Note, still gets values stored in state service objects, just grabs user ids from global
|
||||
{
|
||||
userId: "c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
account: { otherStuff: "otherStuff1" },
|
||||
},
|
||||
{
|
||||
userId: "23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
account: { otherStuff: "otherStuff2" },
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getKnownUserIds", () => {
|
||||
it("returns all user ids", async () => {
|
||||
const userIds = await sut.getKnownUserIds();
|
||||
expect(userIds).toEqual([
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
]);
|
||||
});
|
||||
|
||||
it("returns all user ids when version is 60 or greater", async () => {
|
||||
sut.currentVersion = 60;
|
||||
const userIds = await sut.getKnownUserIds();
|
||||
expect(userIds).toEqual([
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getFromGlobal", () => {
|
||||
it("should return the correct value", async () => {
|
||||
sut.currentVersion = 9;
|
||||
const value = await sut.getFromGlobal({
|
||||
stateDefinition: { name: "serviceName" },
|
||||
key: "key",
|
||||
});
|
||||
expect(value).toEqual("global_serviceName_key");
|
||||
});
|
||||
|
||||
it("should throw if the current version is less than 9", () => {
|
||||
expect(() =>
|
||||
sut.getFromGlobal({ stateDefinition: { name: "serviceName" }, key: "key" }),
|
||||
).toThrowError("No key builder should be used for versions prior to 9.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("setToGlobal", () => {
|
||||
it("should set the correct value", async () => {
|
||||
sut.currentVersion = 9;
|
||||
await sut.setToGlobal(
|
||||
{ stateDefinition: { name: "serviceName" }, key: "key" },
|
||||
"new_value",
|
||||
);
|
||||
expect(storage.save).toHaveBeenCalledWith("global_serviceName_key", "new_value");
|
||||
});
|
||||
|
||||
it("should throw if the current version is less than 9", () => {
|
||||
expect(() =>
|
||||
sut.setToGlobal(
|
||||
{ stateDefinition: { name: "serviceName" }, key: "key" },
|
||||
"global_serviceName_key",
|
||||
),
|
||||
).toThrowError("No key builder should be used for versions prior to 9.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getFromUser", () => {
|
||||
it("should return the correct value", async () => {
|
||||
sut.currentVersion = 9;
|
||||
const value = await sut.getFromUser("userId", {
|
||||
stateDefinition: { name: "serviceName" },
|
||||
key: "key",
|
||||
});
|
||||
expect(value).toEqual("user_userId_serviceName_key");
|
||||
});
|
||||
|
||||
it("should throw if the current version is less than 9", () => {
|
||||
expect(() =>
|
||||
sut.getFromUser("userId", { stateDefinition: { name: "serviceName" }, key: "key" }),
|
||||
).toThrowError("No key builder should be used for versions prior to 9.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("setToUser", () => {
|
||||
it("should set the correct value", async () => {
|
||||
sut.currentVersion = 9;
|
||||
await sut.setToUser(
|
||||
"userId",
|
||||
{ stateDefinition: { name: "serviceName" }, key: "key" },
|
||||
"new_value",
|
||||
);
|
||||
expect(storage.save).toHaveBeenCalledWith("user_userId_serviceName_key", "new_value");
|
||||
});
|
||||
|
||||
it("should throw if the current version is less than 9", () => {
|
||||
expect(() =>
|
||||
sut.setToUser(
|
||||
"userId",
|
||||
{ stateDefinition: { name: "serviceName" }, key: "key" },
|
||||
"new_value",
|
||||
),
|
||||
).toThrowError("No key builder should be used for versions prior to 9.");
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
/** Helper to create well-mocked migration helpers in migration tests */
|
||||
export function mockMigrationHelper(
|
||||
storageJson: any,
|
||||
stateVersion = 0,
|
||||
type: MigrationHelperType = "general",
|
||||
clientType: ClientType = ClientType.Web,
|
||||
): MockProxy<MigrationHelper> {
|
||||
const logService: MockProxy<LogService> = mock();
|
||||
const storage: MockProxy<AbstractStorageService> = mock();
|
||||
storage.get.mockImplementation((key) => (storageJson as any)[key]);
|
||||
storage.save.mockImplementation(async (key, value) => {
|
||||
(storageJson as any)[key] = value;
|
||||
});
|
||||
const helper = new MigrationHelper(stateVersion, storage, logService, type, clientType);
|
||||
|
||||
const mockHelper = mock<MigrationHelper>();
|
||||
mockHelper.get.mockImplementation((key) => helper.get(key));
|
||||
mockHelper.set.mockImplementation((key, value) => helper.set(key, value));
|
||||
mockHelper.getFromGlobal.mockImplementation((keyDefinition) =>
|
||||
helper.getFromGlobal(keyDefinition),
|
||||
);
|
||||
mockHelper.setToGlobal.mockImplementation((keyDefinition, value) =>
|
||||
helper.setToGlobal(keyDefinition, value),
|
||||
);
|
||||
mockHelper.getFromUser.mockImplementation((userId, keyDefinition) =>
|
||||
helper.getFromUser(userId, keyDefinition),
|
||||
);
|
||||
mockHelper.setToUser.mockImplementation((userId, keyDefinition, value) =>
|
||||
helper.setToUser(userId, keyDefinition, value),
|
||||
);
|
||||
mockHelper.getAccounts.mockImplementation(() => helper.getAccounts());
|
||||
mockHelper.getKnownUserIds.mockImplementation(() => helper.getKnownUserIds());
|
||||
mockHelper.removeFromGlobal.mockImplementation((keyDefinition) =>
|
||||
helper.removeFromGlobal(keyDefinition),
|
||||
);
|
||||
mockHelper.remove.mockImplementation((key) => helper.remove(key));
|
||||
|
||||
mockHelper.type = helper.type;
|
||||
mockHelper.clientType = helper.clientType;
|
||||
|
||||
return mockHelper;
|
||||
}
|
||||
|
||||
export type InitialDataHint<TUsers extends readonly string[]> = {
|
||||
/**
|
||||
* A string array of the users id who are authenticated
|
||||
*/
|
||||
authenticatedAccounts?: TUsers;
|
||||
/**
|
||||
* Global data
|
||||
*/
|
||||
global?: unknown;
|
||||
/**
|
||||
* Other top level data
|
||||
*/
|
||||
[key: string]: unknown;
|
||||
} & {
|
||||
/**
|
||||
* A users data
|
||||
*/
|
||||
[userData in TUsers[number]]?: unknown;
|
||||
};
|
||||
|
||||
type InjectedData = {
|
||||
propertyName: string;
|
||||
propertyValue: string;
|
||||
originalPath: string[];
|
||||
};
|
||||
|
||||
// This is a slight lie, technically the type is `Record<string | symbol, unknown>
|
||||
// but for the purposes of things in the migrations this is enough.
|
||||
function isStringRecord(object: unknown | undefined): object is Record<string, unknown> {
|
||||
return object && typeof object === "object" && !Array.isArray(object);
|
||||
}
|
||||
|
||||
function injectData(data: Record<string, unknown>, path: string[]): InjectedData[] {
|
||||
if (!data) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const injectedData: InjectedData[] = [];
|
||||
|
||||
// Traverse keys for other objects
|
||||
const keys = Object.keys(data);
|
||||
for (const key of keys) {
|
||||
const currentProperty = data[key];
|
||||
if (isStringRecord(currentProperty)) {
|
||||
injectedData.push(...injectData(currentProperty, [...path, key]));
|
||||
}
|
||||
}
|
||||
|
||||
const propertyName = `__injectedProperty__${Utils.newGuid()}`;
|
||||
const propertyValue = `__injectedValue__${Utils.newGuid()}`;
|
||||
|
||||
injectedData.push({
|
||||
propertyName: propertyName,
|
||||
propertyValue: propertyValue,
|
||||
// Track the path it was originally injected in just for a better error
|
||||
originalPath: path,
|
||||
});
|
||||
data[propertyName] = propertyValue;
|
||||
return injectedData;
|
||||
}
|
||||
|
||||
function expectInjectedData(
|
||||
data: Record<string, unknown>,
|
||||
injectedData: InjectedData[],
|
||||
): [data: Record<string, unknown>, leftoverInjectedData: InjectedData[]] {
|
||||
const keys = Object.keys(data);
|
||||
for (const key of keys) {
|
||||
const propertyValue = data[key];
|
||||
// Injected data does not have to be found exactly where it was injected,
|
||||
// just that it exists at all.
|
||||
const injectedIndex = injectedData.findIndex(
|
||||
(d) =>
|
||||
d.propertyName === key &&
|
||||
typeof propertyValue === "string" &&
|
||||
propertyValue === d.propertyValue,
|
||||
);
|
||||
|
||||
if (injectedIndex !== -1) {
|
||||
// We found something we injected, remove it
|
||||
injectedData.splice(injectedIndex, 1);
|
||||
delete data[key];
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isStringRecord(propertyValue)) {
|
||||
const [updatedData, leftoverInjectedData] = expectInjectedData(propertyValue, injectedData);
|
||||
data[key] = updatedData;
|
||||
injectedData = leftoverInjectedData;
|
||||
}
|
||||
}
|
||||
|
||||
return [data, injectedData];
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the {@link Migrator.migrate} method of your migrator. You may pass in your test data and get back the data after the migration.
|
||||
* This also injects extra properties at every level of your state and makes sure that it can be found.
|
||||
* @param migrator Your migrator to use to do the migration
|
||||
* @param initalData The data to start with
|
||||
* @returns State after your migration has ran.
|
||||
*/
|
||||
export async function runMigrator<
|
||||
TMigrator extends Migrator<number, number>,
|
||||
const TUsers extends readonly string[],
|
||||
>(
|
||||
migrator: TMigrator,
|
||||
initalData?: InitialDataHint<TUsers>,
|
||||
direction: "migrate" | "rollback" = "migrate",
|
||||
): Promise<Record<string, unknown>> {
|
||||
const clonedData = JSON.parse(JSON.stringify(initalData ?? {}));
|
||||
|
||||
// Inject fake data at every level of the object
|
||||
const allInjectedData = injectData(clonedData, []);
|
||||
|
||||
const fakeStorageService = new FakeStorageService(clonedData);
|
||||
const helper = new MigrationHelper(
|
||||
migrator.fromVersion,
|
||||
fakeStorageService,
|
||||
mock(),
|
||||
"general",
|
||||
ClientType.Web,
|
||||
);
|
||||
|
||||
// Run their migrations
|
||||
if (direction === "rollback") {
|
||||
await migrator.rollback(helper);
|
||||
} else {
|
||||
await migrator.migrate(helper);
|
||||
}
|
||||
const [data, leftoverInjectedData] = expectInjectedData(
|
||||
fakeStorageService.internalStore,
|
||||
allInjectedData,
|
||||
);
|
||||
expect(leftoverInjectedData).toHaveLength(0);
|
||||
|
||||
return data;
|
||||
}
|
||||
@@ -1,261 +1 @@
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to provide client type to migrations
|
||||
import { ClientType } from "../enums";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to print log messages
|
||||
import { LogService } from "../platform/abstractions/log.service";
|
||||
// eslint-disable-next-line import/no-restricted-paths -- Needed to interface with storage locations
|
||||
import { AbstractStorageService } from "../platform/abstractions/storage.service";
|
||||
|
||||
export type StateDefinitionLike = { name: string };
|
||||
export type KeyDefinitionLike = {
|
||||
stateDefinition: StateDefinitionLike;
|
||||
key: string;
|
||||
};
|
||||
|
||||
export type MigrationHelperType = "general" | "web-disk-local";
|
||||
|
||||
export class MigrationHelper {
|
||||
constructor(
|
||||
public currentVersion: number,
|
||||
private storageService: AbstractStorageService,
|
||||
public logService: LogService,
|
||||
type: MigrationHelperType,
|
||||
public clientType: ClientType,
|
||||
) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
/**
|
||||
* On some clients, migrations are ran multiple times without direct action from the migration writer.
|
||||
*
|
||||
* All clients will run through migrations at least once, this run is referred to as `"general"`. If a migration is
|
||||
* ran more than that single time, they will get a unique name if that the write can make conditional logic based on which
|
||||
* migration run this is.
|
||||
*
|
||||
* @remarks The preferrable way of writing migrations is ALWAYS to be defensive and reflect on the data you are given back. This
|
||||
* should really only be used when reflecting on the data given isn't enough.
|
||||
*/
|
||||
type: MigrationHelperType;
|
||||
|
||||
/**
|
||||
* Gets a value from the storage service at the given key.
|
||||
*
|
||||
* This is a brute force method to just get a value from the storage service. If you can use {@link getFromGlobal} or {@link getFromUser}, you should.
|
||||
* @param key location
|
||||
* @returns the value at the location
|
||||
*/
|
||||
get<T>(key: string): Promise<T> {
|
||||
return this.storageService.get<T>(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a value in the storage service at the given key.
|
||||
*
|
||||
* This is a brute force method to just set a value in the storage service. If you can use {@link setToGlobal} or {@link setToUser}, you should.
|
||||
* @param key location
|
||||
* @param value the value to set
|
||||
* @returns
|
||||
*/
|
||||
set<T>(key: string, value: T): Promise<void> {
|
||||
this.logService.info(`Setting ${key}`);
|
||||
return this.storageService.save(key, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a value in the storage service at the given key.
|
||||
*
|
||||
* This is a brute force method to just remove a value in the storage service. If you can use {@link removeFromGlobal} or {@link removeFromUser}, you should.
|
||||
* @param key location
|
||||
* @returns void
|
||||
*/
|
||||
remove(key: string): Promise<void> {
|
||||
this.logService.info(`Removing ${key}`);
|
||||
return this.storageService.remove(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a globally scoped value from a location derived through the key definition
|
||||
*
|
||||
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
|
||||
* use {@link get} for those.
|
||||
* @param keyDefinition unique key definition
|
||||
* @returns value from store
|
||||
*/
|
||||
getFromGlobal<T>(keyDefinition: KeyDefinitionLike): Promise<T> {
|
||||
return this.get<T>(this.getGlobalKey(keyDefinition));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a globally scoped value to a location derived through the key definition
|
||||
*
|
||||
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
|
||||
* use {@link set} for those.
|
||||
* @param keyDefinition unique key definition
|
||||
* @param value value to store
|
||||
* @returns void
|
||||
*/
|
||||
setToGlobal<T>(keyDefinition: KeyDefinitionLike, value: T): Promise<void> {
|
||||
return this.set(this.getGlobalKey(keyDefinition), value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a globally scoped location derived through the key definition
|
||||
*
|
||||
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
|
||||
* use {@link remove} for those.
|
||||
* @param keyDefinition unique key definition
|
||||
* @returns void
|
||||
*/
|
||||
removeFromGlobal(keyDefinition: KeyDefinitionLike): Promise<void> {
|
||||
return this.remove(this.getGlobalKey(keyDefinition));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a user scoped value from a location derived through the user id and key definition
|
||||
*
|
||||
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
|
||||
* use {@link get} for those.
|
||||
* @param userId userId to use in the key
|
||||
* @param keyDefinition unique key definition
|
||||
* @returns value from store
|
||||
*/
|
||||
getFromUser<T>(userId: string, keyDefinition: KeyDefinitionLike): Promise<T> {
|
||||
return this.get<T>(this.getUserKey(userId, keyDefinition));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a user scoped value to a location derived through the user id and key definition
|
||||
*
|
||||
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
|
||||
* use {@link set} for those.
|
||||
* @param userId userId to use in the key
|
||||
* @param keyDefinition unique key definition
|
||||
* @param value value to store
|
||||
* @returns void
|
||||
*/
|
||||
setToUser<T>(userId: string, keyDefinition: KeyDefinitionLike, value: T): Promise<void> {
|
||||
return this.set(this.getUserKey(userId, keyDefinition), value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a user scoped location derived through the key definition
|
||||
*
|
||||
* This is for use with the state providers framework, DO NOT use for values stored with {@link StateService},
|
||||
* use {@link remove} for those.
|
||||
* @param keyDefinition unique key definition
|
||||
* @returns void
|
||||
*/
|
||||
removeFromUser(userId: string, keyDefinition: KeyDefinitionLike): Promise<void> {
|
||||
return this.remove(this.getUserKey(userId, keyDefinition));
|
||||
}
|
||||
|
||||
info(message: string): void {
|
||||
this.logService.info(message);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to read all Account objects stored by the State Service.
|
||||
*
|
||||
* This is useful from creating migrations off of this paradigm, but should not be used once a value is migrated to a state provider.
|
||||
*
|
||||
* @returns a list of all accounts that have been authenticated with state service, cast the expected type.
|
||||
*/
|
||||
async getAccounts<ExpectedAccountType>(): Promise<
|
||||
{ userId: string; account: ExpectedAccountType }[]
|
||||
> {
|
||||
const userIds = await this.getKnownUserIds();
|
||||
return Promise.all(
|
||||
userIds.map(async (userId) => ({
|
||||
userId,
|
||||
account: await this.get<ExpectedAccountType>(userId),
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to read known users ids.
|
||||
*/
|
||||
async getKnownUserIds(): Promise<string[]> {
|
||||
if (this.currentVersion < 60) {
|
||||
return knownAccountUserIdsBuilderPre60(this.storageService);
|
||||
} else {
|
||||
return knownAccountUserIdsBuilder(this.storageService);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a user storage key appropriate for the current version.
|
||||
*
|
||||
* @param userId userId to use in the key
|
||||
* @param keyDefinition state and key to use in the key
|
||||
* @returns
|
||||
*/
|
||||
private getUserKey(userId: string, keyDefinition: KeyDefinitionLike): string {
|
||||
if (this.currentVersion < 9) {
|
||||
return userKeyBuilderPre9();
|
||||
} else {
|
||||
return userKeyBuilder(userId, keyDefinition);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a global storage key appropriate for the current version.
|
||||
*
|
||||
* @param keyDefinition state and key to use in the key
|
||||
* @returns
|
||||
*/
|
||||
private getGlobalKey(keyDefinition: KeyDefinitionLike): string {
|
||||
if (this.currentVersion < 9) {
|
||||
return globalKeyBuilderPre9();
|
||||
} else {
|
||||
return globalKeyBuilder(keyDefinition);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When this is updated, rename this function to `userKeyBuilderXToY` where `X` is the version number it
|
||||
* became relevant, and `Y` prior to the version it was updated.
|
||||
*
|
||||
* Be sure to update the map in `MigrationHelper` to point to the appropriate function for the current version.
|
||||
* @param userId The userId of the user you want the key to be for.
|
||||
* @param keyDefinition the key definition of which data the key should point to.
|
||||
* @returns
|
||||
*/
|
||||
function userKeyBuilder(userId: string, keyDefinition: KeyDefinitionLike): string {
|
||||
return `user_${userId}_${keyDefinition.stateDefinition.name}_${keyDefinition.key}`;
|
||||
}
|
||||
|
||||
function userKeyBuilderPre9(): string {
|
||||
throw Error("No key builder should be used for versions prior to 9.");
|
||||
}
|
||||
|
||||
/**
|
||||
* When this is updated, rename this function to `globalKeyBuilderXToY` where `X` is the version number
|
||||
* it became relevant, and `Y` prior to the version it was updated.
|
||||
*
|
||||
* Be sure to update the map in `MigrationHelper` to point to the appropriate function for the current version.
|
||||
* @param keyDefinition the key definition of which data the key should point to.
|
||||
* @returns
|
||||
*/
|
||||
function globalKeyBuilder(keyDefinition: KeyDefinitionLike): string {
|
||||
return `global_${keyDefinition.stateDefinition.name}_${keyDefinition.key}`;
|
||||
}
|
||||
|
||||
function globalKeyBuilderPre9(): string {
|
||||
throw Error("No key builder should be used for versions prior to 9.");
|
||||
}
|
||||
|
||||
async function knownAccountUserIdsBuilderPre60(
|
||||
storageService: AbstractStorageService,
|
||||
): Promise<string[]> {
|
||||
return (await storageService.get<string[]>("authenticatedAccounts")) ?? [];
|
||||
}
|
||||
|
||||
async function knownAccountUserIdsBuilder(
|
||||
storageService: AbstractStorageService,
|
||||
): Promise<string[]> {
|
||||
const accounts = await storageService.get<Record<string, unknown>>(
|
||||
globalKeyBuilder({ stateDefinition: { name: "account" }, key: "accounts" }),
|
||||
);
|
||||
return Object.keys(accounts ?? {});
|
||||
}
|
||||
export { MigrationHelper } from "@bitwarden/state";
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
import { MockProxy, any } from "jest-mock-extended";
|
||||
|
||||
import { MigrationHelper } from "../migration-helper";
|
||||
import { mockMigrationHelper } from "../migration-helper.spec";
|
||||
|
||||
import { EverHadUserKeyMigrator } from "./10-move-ever-had-user-key-to-state-providers";
|
||||
|
||||
function exampleJSON() {
|
||||
return {
|
||||
global: {
|
||||
otherStuff: "otherStuff1",
|
||||
},
|
||||
authenticatedAccounts: [
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
|
||||
],
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760": {
|
||||
profile: {
|
||||
everHadUserKey: false,
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
},
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
|
||||
profile: {
|
||||
everHadUserKey: true,
|
||||
otherStuff: "otherStuff4",
|
||||
},
|
||||
otherStuff: "otherStuff5",
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function rollbackJSON() {
|
||||
return {
|
||||
"user_c493ed01-4e08-4e88-abc7-332f380ca760_crypto_everHadUserKey": false,
|
||||
"user_23e61a5f-2ece-4f5e-b499-f0bc489482a9_crypto_everHadUserKey": true,
|
||||
"user_fd005ea6-a16a-45ef-ba4a-a194269bfd73_crypto_everHadUserKey": false,
|
||||
global: {
|
||||
otherStuff: "otherStuff1",
|
||||
},
|
||||
authenticatedAccounts: [
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
|
||||
],
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760": {
|
||||
profile: {
|
||||
everHadUserKey: false,
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
},
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9": {
|
||||
profile: {
|
||||
everHadUserKey: true,
|
||||
otherStuff: "otherStuff4",
|
||||
},
|
||||
otherStuff: "otherStuff5",
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe("EverHadUserKeyMigrator", () => {
|
||||
let helper: MockProxy<MigrationHelper>;
|
||||
let sut: EverHadUserKeyMigrator;
|
||||
const keyDefinitionLike = {
|
||||
key: "everHadUserKey",
|
||||
stateDefinition: {
|
||||
name: "crypto",
|
||||
},
|
||||
};
|
||||
|
||||
describe("migrate", () => {
|
||||
beforeEach(() => {
|
||||
helper = mockMigrationHelper(exampleJSON(), 9);
|
||||
sut = new EverHadUserKeyMigrator(9, 10);
|
||||
});
|
||||
|
||||
it("should remove everHadUserKey from all accounts", async () => {
|
||||
await sut.migrate(helper);
|
||||
expect(helper.set).toHaveBeenCalledWith("c493ed01-4e08-4e88-abc7-332f380ca760", {
|
||||
profile: {
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
});
|
||||
expect(helper.set).toHaveBeenCalledWith("23e61a5f-2ece-4f5e-b499-f0bc489482a9", {
|
||||
profile: {
|
||||
otherStuff: "otherStuff4",
|
||||
},
|
||||
otherStuff: "otherStuff5",
|
||||
});
|
||||
});
|
||||
|
||||
it("should set everHadUserKey provider value for each account", async () => {
|
||||
await sut.migrate(helper);
|
||||
|
||||
expect(helper.setToUser).toHaveBeenCalledWith(
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
keyDefinitionLike,
|
||||
false,
|
||||
);
|
||||
|
||||
expect(helper.setToUser).toHaveBeenCalledWith(
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
keyDefinitionLike,
|
||||
true,
|
||||
);
|
||||
|
||||
expect(helper.setToUser).toHaveBeenCalledWith(
|
||||
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
|
||||
keyDefinitionLike,
|
||||
false,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("rollback", () => {
|
||||
beforeEach(() => {
|
||||
helper = mockMigrationHelper(rollbackJSON(), 10);
|
||||
sut = new EverHadUserKeyMigrator(9, 10);
|
||||
});
|
||||
|
||||
it.each([
|
||||
"c493ed01-4e08-4e88-abc7-332f380ca760",
|
||||
"23e61a5f-2ece-4f5e-b499-f0bc489482a9",
|
||||
"fd005ea6-a16a-45ef-ba4a-a194269bfd73",
|
||||
])("should null out new values", async (userId) => {
|
||||
await sut.rollback(helper);
|
||||
|
||||
expect(helper.setToUser).toHaveBeenCalledWith(userId, keyDefinitionLike, null);
|
||||
});
|
||||
|
||||
it("should add explicit value back to accounts", async () => {
|
||||
await sut.rollback(helper);
|
||||
|
||||
expect(helper.set).toHaveBeenCalledWith("c493ed01-4e08-4e88-abc7-332f380ca760", {
|
||||
profile: {
|
||||
everHadUserKey: false,
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
});
|
||||
expect(helper.set).toHaveBeenCalledWith("23e61a5f-2ece-4f5e-b499-f0bc489482a9", {
|
||||
profile: {
|
||||
everHadUserKey: true,
|
||||
otherStuff: "otherStuff4",
|
||||
},
|
||||
otherStuff: "otherStuff5",
|
||||
});
|
||||
});
|
||||
|
||||
it("should not try to restore values to missing accounts", async () => {
|
||||
await sut.rollback(helper);
|
||||
|
||||
expect(helper.set).not.toHaveBeenCalledWith("fd005ea6-a16a-45ef-ba4a-a194269bfd73", any());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,48 +0,0 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { KeyDefinitionLike, MigrationHelper } from "../migration-helper";
|
||||
import { Migrator } from "../migrator";
|
||||
|
||||
type ExpectedAccountType = {
|
||||
profile?: {
|
||||
everHadUserKey?: boolean;
|
||||
};
|
||||
};
|
||||
|
||||
const USER_EVER_HAD_USER_KEY: KeyDefinitionLike = {
|
||||
key: "everHadUserKey",
|
||||
stateDefinition: {
|
||||
name: "crypto",
|
||||
},
|
||||
};
|
||||
|
||||
export class EverHadUserKeyMigrator extends Migrator<9, 10> {
|
||||
async migrate(helper: MigrationHelper): Promise<void> {
|
||||
const accounts = await helper.getAccounts<ExpectedAccountType>();
|
||||
async function migrateAccount(userId: string, account: ExpectedAccountType): Promise<void> {
|
||||
const value = account?.profile?.everHadUserKey;
|
||||
await helper.setToUser(userId, USER_EVER_HAD_USER_KEY, value ?? false);
|
||||
if (value != null) {
|
||||
delete account.profile.everHadUserKey;
|
||||
}
|
||||
await helper.set(userId, account);
|
||||
}
|
||||
|
||||
await Promise.all([...accounts.map(({ userId, account }) => migrateAccount(userId, account))]);
|
||||
}
|
||||
async rollback(helper: MigrationHelper): Promise<void> {
|
||||
const accounts = await helper.getAccounts<ExpectedAccountType>();
|
||||
async function rollbackAccount(userId: string, account: ExpectedAccountType): Promise<void> {
|
||||
const value = await helper.getFromUser(userId, USER_EVER_HAD_USER_KEY);
|
||||
if (account) {
|
||||
account.profile = Object.assign(account.profile ?? {}, {
|
||||
everHadUserKey: value,
|
||||
});
|
||||
await helper.set(userId, account);
|
||||
}
|
||||
await helper.setToUser(userId, USER_EVER_HAD_USER_KEY, null);
|
||||
}
|
||||
|
||||
await Promise.all([...accounts.map(({ userId, account }) => rollbackAccount(userId, account))]);
|
||||
}
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
import { MockProxy, any } from "jest-mock-extended";
|
||||
|
||||
import { MigrationHelper } from "../migration-helper";
|
||||
import { mockMigrationHelper } from "../migration-helper.spec";
|
||||
|
||||
import { OrganizationKeyMigrator } from "./11-move-org-keys-to-state-providers";
|
||||
|
||||
function exampleJSON() {
|
||||
return {
|
||||
global: {
|
||||
otherStuff: "otherStuff1",
|
||||
},
|
||||
authenticatedAccounts: ["user-1", "user-2", "user-3"],
|
||||
"user-1": {
|
||||
keys: {
|
||||
organizationKeys: {
|
||||
encrypted: {
|
||||
"org-id-1": {
|
||||
type: "organization",
|
||||
key: "org-key-1",
|
||||
},
|
||||
"org-id-2": {
|
||||
type: "provider",
|
||||
key: "org-key-2",
|
||||
providerId: "provider-id-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
},
|
||||
"user-2": {
|
||||
keys: {
|
||||
otherStuff: "otherStuff4",
|
||||
},
|
||||
otherStuff: "otherStuff5",
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function rollbackJSON() {
|
||||
return {
|
||||
"user_user-1_crypto_organizationKeys": {
|
||||
"org-id-1": {
|
||||
type: "organization",
|
||||
key: "org-key-1",
|
||||
},
|
||||
"org-id-2": {
|
||||
type: "provider",
|
||||
key: "org-key-2",
|
||||
providerId: "provider-id-2",
|
||||
},
|
||||
},
|
||||
"user_user-2_crypto_organizationKeys": null as any,
|
||||
global: {
|
||||
otherStuff: "otherStuff1",
|
||||
},
|
||||
authenticatedAccounts: ["user-1", "user-2", "user-3"],
|
||||
"user-1": {
|
||||
keys: {
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
},
|
||||
"user-2": {
|
||||
keys: {
|
||||
otherStuff: "otherStuff4",
|
||||
},
|
||||
otherStuff: "otherStuff5",
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe("OrganizationKeysMigrator", () => {
|
||||
let helper: MockProxy<MigrationHelper>;
|
||||
let sut: OrganizationKeyMigrator;
|
||||
const keyDefinitionLike = {
|
||||
key: "organizationKeys",
|
||||
stateDefinition: {
|
||||
name: "crypto",
|
||||
},
|
||||
};
|
||||
|
||||
describe("migrate", () => {
|
||||
beforeEach(() => {
|
||||
helper = mockMigrationHelper(exampleJSON(), 10);
|
||||
sut = new OrganizationKeyMigrator(10, 11);
|
||||
});
|
||||
|
||||
it("should remove organizationKeys from all accounts", async () => {
|
||||
await sut.migrate(helper);
|
||||
expect(helper.set).toHaveBeenCalledTimes(1);
|
||||
expect(helper.set).toHaveBeenCalledWith("user-1", {
|
||||
keys: {
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
});
|
||||
});
|
||||
|
||||
it("should set organizationKeys value for each account", async () => {
|
||||
await sut.migrate(helper);
|
||||
|
||||
expect(helper.setToUser).toHaveBeenCalledTimes(1);
|
||||
expect(helper.setToUser).toHaveBeenCalledWith("user-1", keyDefinitionLike, {
|
||||
"org-id-1": {
|
||||
type: "organization",
|
||||
key: "org-key-1",
|
||||
},
|
||||
"org-id-2": {
|
||||
type: "provider",
|
||||
key: "org-key-2",
|
||||
providerId: "provider-id-2",
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("rollback", () => {
|
||||
beforeEach(() => {
|
||||
helper = mockMigrationHelper(rollbackJSON(), 11);
|
||||
sut = new OrganizationKeyMigrator(10, 11);
|
||||
});
|
||||
|
||||
it.each(["user-1", "user-2", "user-3"])("should null out new values %s", async (userId) => {
|
||||
await sut.rollback(helper);
|
||||
|
||||
expect(helper.setToUser).toHaveBeenCalledWith(userId, keyDefinitionLike, null);
|
||||
});
|
||||
|
||||
it("should add explicit value back to accounts", async () => {
|
||||
await sut.rollback(helper);
|
||||
|
||||
expect(helper.set).toHaveBeenCalledTimes(1);
|
||||
expect(helper.set).toHaveBeenCalledWith("user-1", {
|
||||
keys: {
|
||||
organizationKeys: {
|
||||
encrypted: {
|
||||
"org-id-1": {
|
||||
type: "organization",
|
||||
key: "org-key-1",
|
||||
},
|
||||
"org-id-2": {
|
||||
type: "provider",
|
||||
key: "org-key-2",
|
||||
providerId: "provider-id-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
otherStuff: "overStuff2",
|
||||
},
|
||||
otherStuff: "otherStuff3",
|
||||
});
|
||||
});
|
||||
|
||||
it("should not try to restore values to missing accounts", async () => {
|
||||
await sut.rollback(helper);
|
||||
|
||||
expect(helper.set).not.toHaveBeenCalledWith("user-3", any());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,61 +0,0 @@
|
||||
// FIXME: Update this file to be type safe and remove this and next line
|
||||
// @ts-strict-ignore
|
||||
import { KeyDefinitionLike, MigrationHelper } from "../migration-helper";
|
||||
import { Migrator } from "../migrator";
|
||||
|
||||
type OrgKeyDataType = {
|
||||
type: "organization" | "provider";
|
||||
key: string;
|
||||
providerId?: string;
|
||||
};
|
||||
|
||||
type ExpectedAccountType = {
|
||||
keys?: {
|
||||
organizationKeys?: {
|
||||
encrypted?: Record<string, OrgKeyDataType>;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
const USER_ENCRYPTED_ORGANIZATION_KEYS: KeyDefinitionLike = {
|
||||
key: "organizationKeys",
|
||||
stateDefinition: {
|
||||
name: "crypto",
|
||||
},
|
||||
};
|
||||
|
||||
export class OrganizationKeyMigrator extends Migrator<10, 11> {
|
||||
async migrate(helper: MigrationHelper): Promise<void> {
|
||||
const accounts = await helper.getAccounts<ExpectedAccountType>();
|
||||
async function migrateAccount(userId: string, account: ExpectedAccountType): Promise<void> {
|
||||
const value = account?.keys?.organizationKeys?.encrypted;
|
||||
if (value != null) {
|
||||
await helper.setToUser(userId, USER_ENCRYPTED_ORGANIZATION_KEYS, value);
|
||||
delete account.keys.organizationKeys;
|
||||
await helper.set(userId, account);
|
||||
}
|
||||
}
|
||||
|
||||
await Promise.all([...accounts.map(({ userId, account }) => migrateAccount(userId, account))]);
|
||||
}
|
||||
async rollback(helper: MigrationHelper): Promise<void> {
|
||||
const accounts = await helper.getAccounts<ExpectedAccountType>();
|
||||
async function rollbackAccount(userId: string, account: ExpectedAccountType): Promise<void> {
|
||||
const value = await helper.getFromUser<Record<string, OrgKeyDataType>>(
|
||||
userId,
|
||||
USER_ENCRYPTED_ORGANIZATION_KEYS,
|
||||
);
|
||||
if (account && value) {
|
||||
account.keys = Object.assign(account.keys ?? {}, {
|
||||
organizationKeys: {
|
||||
encrypted: value,
|
||||
},
|
||||
});
|
||||
await helper.set(userId, account);
|
||||
}
|
||||
await helper.setToUser(userId, USER_ENCRYPTED_ORGANIZATION_KEYS, null);
|
||||
}
|
||||
|
||||
await Promise.all([...accounts.map(({ userId, account }) => rollbackAccount(userId, account))]);
|
||||
}
|
||||
}
|
||||
@@ -1,157 +0,0 @@
|
||||
import { runMigrator } from "../migration-helper.spec";
|
||||
|
||||
import { MoveEnvironmentStateToProviders } from "./12-move-environment-state-to-providers";
|
||||
|
||||
describe("MoveEnvironmentStateToProviders", () => {
|
||||
const migrator = new MoveEnvironmentStateToProviders(11, 12);
|
||||
|
||||
it("can migrate all data", async () => {
|
||||
const output = await runMigrator(migrator, {
|
||||
authenticatedAccounts: ["user1", "user2"] as const,
|
||||
global: {
|
||||
region: "US",
|
||||
environmentUrls: {
|
||||
base: "example.com",
|
||||
},
|
||||
extra: "data",
|
||||
},
|
||||
user1: {
|
||||
extra: "data",
|
||||
settings: {
|
||||
extra: "data",
|
||||
region: "US",
|
||||
environmentUrls: {
|
||||
base: "example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
user2: {
|
||||
extra: "data",
|
||||
settings: {
|
||||
region: "EU",
|
||||
environmentUrls: {
|
||||
base: "other.example.com",
|
||||
},
|
||||
extra: "data",
|
||||
},
|
||||
},
|
||||
extra: "data",
|
||||
});
|
||||
|
||||
expect(output).toEqual({
|
||||
authenticatedAccounts: ["user1", "user2"],
|
||||
global: {
|
||||
extra: "data",
|
||||
},
|
||||
global_environment_region: "US",
|
||||
global_environment_urls: {
|
||||
base: "example.com",
|
||||
},
|
||||
user1: {
|
||||
extra: "data",
|
||||
settings: {
|
||||
extra: "data",
|
||||
},
|
||||
},
|
||||
user2: {
|
||||
extra: "data",
|
||||
settings: {
|
||||
extra: "data",
|
||||
},
|
||||
},
|
||||
extra: "data",
|
||||
user_user1_environment_region: "US",
|
||||
user_user2_environment_region: "EU",
|
||||
user_user1_environment_urls: {
|
||||
base: "example.com",
|
||||
},
|
||||
user_user2_environment_urls: {
|
||||
base: "other.example.com",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("handles missing parts", async () => {
|
||||
const output = await runMigrator(migrator, {
|
||||
authenticatedAccounts: ["user1", "user2"],
|
||||
global: {
|
||||
extra: "data",
|
||||
},
|
||||
user1: {
|
||||
extra: "data",
|
||||
settings: {
|
||||
extra: "data",
|
||||
},
|
||||
},
|
||||
user2: null,
|
||||
});
|
||||
|
||||
expect(output).toEqual({
|
||||
authenticatedAccounts: ["user1", "user2"],
|
||||
global: {
|
||||
extra: "data",
|
||||
},
|
||||
user1: {
|
||||
extra: "data",
|
||||
settings: {
|
||||
extra: "data",
|
||||
},
|
||||
},
|
||||
user2: null,
|
||||
});
|
||||
});
|
||||
|
||||
it("can migrate only global data", async () => {
|
||||
const output = await runMigrator(migrator, {
|
||||
authenticatedAccounts: [] as const,
|
||||
global: {
|
||||
region: "Self-Hosted",
|
||||
},
|
||||
});
|
||||
|
||||
expect(output).toEqual({
|
||||
authenticatedAccounts: [],
|
||||
global_environment_region: "Self-Hosted",
|
||||
global: {},
|
||||
});
|
||||
});
|
||||
|
||||
it("can migrate only user state", async () => {
|
||||
const output = await runMigrator(migrator, {
|
||||
authenticatedAccounts: ["user1"] as const,
|
||||
global: null,
|
||||
user1: {
|
||||
settings: {
|
||||
region: "Self-Hosted",
|
||||
environmentUrls: {
|
||||
base: "some-base-url",
|
||||
api: "some-api-url",
|
||||
identity: "some-identity-url",
|
||||
icons: "some-icons-url",
|
||||
notifications: "some-notifications-url",
|
||||
events: "some-events-url",
|
||||
webVault: "some-webVault-url",
|
||||
keyConnector: "some-keyConnector-url",
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(output).toEqual({
|
||||
authenticatedAccounts: ["user1"] as const,
|
||||
global: null,
|
||||
user1: { settings: {} },
|
||||
user_user1_environment_region: "Self-Hosted",
|
||||
user_user1_environment_urls: {
|
||||
base: "some-base-url",
|
||||
api: "some-api-url",
|
||||
identity: "some-identity-url",
|
||||
icons: "some-icons-url",
|
||||
notifications: "some-notifications-url",
|
||||
events: "some-events-url",
|
||||
webVault: "some-webVault-url",
|
||||
keyConnector: "some-keyConnector-url",
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user