1
0
mirror of https://github.com/bitwarden/browser synced 2025-12-15 07:43:35 +00:00

Merge branch 'main' into task/DEVOPS-1683

This commit is contained in:
Alex Urbina
2023-12-13 14:07:46 -06:00
61 changed files with 2924 additions and 390 deletions

View File

@@ -71,7 +71,7 @@ jobs:
uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0
with:
repository: bitwarden/clients
ref: master
ref: main
token: ${{ steps.retrieve-secrets.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
- name: Import GPG key

View File

@@ -673,7 +673,7 @@ class OverlayBackground implements OverlayBackgroundInterface {
*/
private setupExtensionMessageListeners() {
BrowserApi.messageListener("overlay.background", this.handleExtensionMessage);
chrome.runtime.onConnect.addListener(this.handlePortOnConnect);
BrowserApi.addListener(chrome.runtime.onConnect, this.handlePortOnConnect);
}
/**

View File

@@ -10,6 +10,10 @@ import {
settingsServiceFactory,
SettingsServiceInitOptions,
} from "../../../background/service-factories/settings-service.factory";
import {
configServiceFactory,
ConfigServiceInitOptions,
} from "../../../platform/background/service-factories/config-service.factory";
import {
CachedServices,
factory,
@@ -43,7 +47,8 @@ export type AutoFillServiceInitOptions = AutoFillServiceOptions &
EventCollectionServiceInitOptions &
LogServiceInitOptions &
SettingsServiceInitOptions &
UserVerificationServiceInitOptions;
UserVerificationServiceInitOptions &
ConfigServiceInitOptions;
export function autofillServiceFactory(
cache: { autofillService?: AbstractAutoFillService } & CachedServices,
@@ -62,6 +67,7 @@ export function autofillServiceFactory(
await logServiceFactory(cache, opts),
await settingsServiceFactory(cache, opts),
await userVerificationServiceFactory(cache, opts),
await configServiceFactory(cache, opts),
),
);
}

View File

@@ -15,7 +15,7 @@ import OverlayBackground from "./overlay.background";
import TabsBackground from "./tabs.background";
describe("TabsBackground", () => {
let tabsBackgorund: TabsBackground;
let tabsBackground: TabsBackground;
const mainBackground = mock<MainBackground>({
messagingService: {
send: jest.fn(),
@@ -25,7 +25,7 @@ describe("TabsBackground", () => {
const overlayBackground = mock<OverlayBackground>();
beforeEach(() => {
tabsBackgorund = new TabsBackground(mainBackground, notificationBackground, overlayBackground);
tabsBackground = new TabsBackground(mainBackground, notificationBackground, overlayBackground);
});
afterEach(() => {
@@ -35,11 +35,11 @@ describe("TabsBackground", () => {
describe("init", () => {
it("sets up a window on focusChanged listener", () => {
const handleWindowOnFocusChangedSpy = jest.spyOn(
tabsBackgorund as any,
tabsBackground as any,
"handleWindowOnFocusChanged",
);
tabsBackgorund.init();
tabsBackground.init();
expect(chrome.windows.onFocusChanged.addListener).toHaveBeenCalledWith(
handleWindowOnFocusChangedSpy,
@@ -49,7 +49,7 @@ describe("TabsBackground", () => {
describe("tab event listeners", () => {
beforeEach(() => {
tabsBackgorund.init();
tabsBackground["setupTabEventListeners"]();
});
describe("window onFocusChanged event", () => {
@@ -64,7 +64,7 @@ describe("TabsBackground", () => {
triggerWindowOnFocusedChangedEvent(10);
await flushPromises();
expect(tabsBackgorund["focusedWindowId"]).toBe(10);
expect(tabsBackground["focusedWindowId"]).toBe(10);
});
it("updates the current tab data", async () => {
@@ -144,7 +144,7 @@ describe("TabsBackground", () => {
beforeEach(() => {
mainBackground.onUpdatedRan = false;
tabsBackgorund["focusedWindowId"] = focusedWindowId;
tabsBackground["focusedWindowId"] = focusedWindowId;
tab = mock<chrome.tabs.Tab>({
windowId: focusedWindowId,
active: true,

View File

@@ -20,6 +20,14 @@ export default class TabsBackground {
return;
}
this.updateCurrentTabData();
this.setupTabEventListeners();
}
/**
* Sets up the tab and window event listeners.
*/
private setupTabEventListeners() {
chrome.windows.onFocusChanged.addListener(this.handleWindowOnFocusChanged);
chrome.tabs.onActivated.addListener(this.handleTabOnActivated);
chrome.tabs.onReplaced.addListener(this.handleTabOnReplaced);
@@ -33,7 +41,7 @@ export default class TabsBackground {
* @param windowId - The ID of the window that was focused.
*/
private handleWindowOnFocusChanged = async (windowId: number) => {
if (!windowId) {
if (windowId == null || windowId < 0) {
return;
}
@@ -116,8 +124,10 @@ export default class TabsBackground {
* for the current tab. Also updates the overlay ciphers.
*/
private updateCurrentTabData = async () => {
await this.main.refreshBadge();
await this.main.refreshMenu();
await this.overlayBackground.updateOverlayCiphers();
await Promise.all([
this.main.refreshBadge(),
this.main.refreshMenu(),
this.overlayBackground.updateOverlayCiphers(),
]);
};
}

View File

@@ -17,6 +17,7 @@ type AutofillExtensionMessage = {
direction?: "previous" | "next";
isOpeningFullOverlay?: boolean;
forceCloseOverlay?: boolean;
autofillOverlayVisibility?: number;
};
};
@@ -34,10 +35,12 @@ type AutofillExtensionMessageHandlers = {
updateIsOverlayCiphersPopulated: ({ message }: AutofillExtensionMessageParam) => void;
bgUnlockPopoutOpened: () => void;
bgVaultItemRepromptPopoutOpened: () => void;
updateAutofillOverlayVisibility: ({ message }: AutofillExtensionMessageParam) => void;
};
interface AutofillInit {
init(): void;
destroy(): void;
}
export { AutofillExtensionMessage, AutofillExtensionMessageHandlers, AutofillInit };

View File

@@ -6,7 +6,7 @@ import { flushPromises, sendExtensionRuntimeMessage } from "../jest/testing-util
import AutofillPageDetails from "../models/autofill-page-details";
import AutofillScript from "../models/autofill-script";
import AutofillOverlayContentService from "../services/autofill-overlay-content.service";
import { RedirectFocusDirection } from "../utils/autofill-overlay.enum";
import { AutofillOverlayVisibility, RedirectFocusDirection } from "../utils/autofill-overlay.enum";
import { AutofillExtensionMessage } from "./abstractions/autofill-init";
import AutofillInit from "./autofill-init";
@@ -16,6 +16,11 @@ describe("AutofillInit", () => {
const autofillOverlayContentService = mock<AutofillOverlayContentService>();
beforeEach(() => {
chrome.runtime.connect = jest.fn().mockReturnValue({
onDisconnect: {
addListener: jest.fn(),
},
});
autofillInit = new AutofillInit(autofillOverlayContentService);
});
@@ -477,6 +482,57 @@ describe("AutofillInit", () => {
expect(autofillInit["removeAutofillOverlay"]).toHaveBeenCalled();
});
});
describe("updateAutofillOverlayVisibility", () => {
beforeEach(() => {
autofillInit["autofillOverlayContentService"].autofillOverlayVisibility =
AutofillOverlayVisibility.OnButtonClick;
});
it("skips attempting to update the overlay visibility if the autofillOverlayVisibility data value is not present", () => {
sendExtensionRuntimeMessage({
command: "updateAutofillOverlayVisibility",
data: {},
});
expect(autofillInit["autofillOverlayContentService"].autofillOverlayVisibility).toEqual(
AutofillOverlayVisibility.OnButtonClick,
);
});
it("updates the overlay visibility value", () => {
const message = {
command: "updateAutofillOverlayVisibility",
data: {
autofillOverlayVisibility: AutofillOverlayVisibility.Off,
},
};
sendExtensionRuntimeMessage(message);
expect(autofillInit["autofillOverlayContentService"].autofillOverlayVisibility).toEqual(
message.data.autofillOverlayVisibility,
);
});
});
});
});
describe("destroy", () => {
it("removes the extension message listeners", () => {
autofillInit.destroy();
expect(chrome.runtime.onMessage.removeListener).toHaveBeenCalledWith(
autofillInit["handleExtensionMessage"],
);
});
it("destroys the collectAutofillContentService", () => {
jest.spyOn(autofillInit["collectAutofillContentService"], "destroy");
autofillInit.destroy();
expect(autofillInit["collectAutofillContentService"].destroy).toHaveBeenCalled();
});
});
});

View File

@@ -26,6 +26,7 @@ class AutofillInit implements AutofillInitInterface {
updateIsOverlayCiphersPopulated: ({ message }) => this.updateIsOverlayCiphersPopulated(message),
bgUnlockPopoutOpened: () => this.blurAndRemoveOverlay(),
bgVaultItemRepromptPopoutOpened: () => this.blurAndRemoveOverlay(),
updateAutofillOverlayVisibility: ({ message }) => this.updateAutofillOverlayVisibility(message),
};
/**
@@ -214,6 +215,19 @@ class AutofillInit implements AutofillInitInterface {
);
}
/**
* Updates the autofill overlay visibility.
*
* @param data - Contains the autoFillOverlayVisibility value
*/
private updateAutofillOverlayVisibility({ data }: AutofillExtensionMessage) {
if (!this.autofillOverlayContentService || isNaN(data?.autofillOverlayVisibility)) {
return;
}
this.autofillOverlayContentService.autofillOverlayVisibility = data?.autofillOverlayVisibility;
}
/**
* Sets up the extension message listeners for the content script.
*/
@@ -247,6 +261,16 @@ class AutofillInit implements AutofillInitInterface {
Promise.resolve(messageResponse).then((response) => sendResponse(response));
return true;
};
/**
* Handles destroying the autofill init content script. Removes all
* listeners, timeouts, and object instances to prevent memory leaks.
*/
destroy() {
chrome.runtime.onMessage.removeListener(this.handleExtensionMessage);
this.collectAutofillContentService.destroy();
this.autofillOverlayContentService?.destroy();
}
}
export default AutofillInit;

View File

@@ -1,3 +1,5 @@
import { getFromLocalStorage, setupExtensionDisconnectAction } from "../utils";
if (document.readyState === "loading") {
document.addEventListener("DOMContentLoaded", loadAutofiller);
} else {
@@ -8,27 +10,30 @@ function loadAutofiller() {
let pageHref: string = null;
let filledThisHref = false;
let delayFillTimeout: number;
const activeUserIdKey = "activeUserId";
let activeUserId: string;
chrome.storage.local.get(activeUserIdKey, (obj: any) => {
if (obj == null || obj[activeUserIdKey] == null) {
return;
}
activeUserId = obj[activeUserIdKey];
});
chrome.storage.local.get(activeUserId, (obj: any) => {
if (obj?.[activeUserId]?.settings?.enableAutoFillOnPageLoad === true) {
setInterval(() => doFillIfNeeded(), 500);
}
});
chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => {
if (msg.command === "fillForm" && pageHref === msg.url) {
let doFillInterval: NodeJS.Timeout;
const handleExtensionDisconnect = () => {
clearDoFillInterval();
clearDelayFillTimeout();
};
const handleExtensionMessage = (message: any) => {
if (message.command === "fillForm" && pageHref === message.url) {
filledThisHref = true;
}
});
};
setupExtensionEventListeners();
triggerUserFillOnLoad();
async function triggerUserFillOnLoad() {
const activeUserIdKey = "activeUserId";
const userKeyStorage = await getFromLocalStorage(activeUserIdKey);
const activeUserId = userKeyStorage[activeUserIdKey];
const activeUserStorage = await getFromLocalStorage(activeUserId);
if (activeUserStorage?.[activeUserId]?.settings?.enableAutoFillOnPageLoad === true) {
clearDoFillInterval();
doFillInterval = setInterval(() => doFillIfNeeded(), 500);
}
}
function doFillIfNeeded(force = false) {
if (force || pageHref !== window.location.href) {
@@ -36,9 +41,7 @@ function loadAutofiller() {
// Some websites are slow and rendering all page content. Try to fill again later
// if we haven't already.
filledThisHref = false;
if (delayFillTimeout != null) {
window.clearTimeout(delayFillTimeout);
}
clearDelayFillTimeout();
delayFillTimeout = window.setTimeout(() => {
if (!filledThisHref) {
doFillIfNeeded(true);
@@ -55,4 +58,21 @@ function loadAutofiller() {
chrome.runtime.sendMessage(msg);
}
}
function clearDoFillInterval() {
if (doFillInterval) {
window.clearInterval(doFillInterval);
}
}
function clearDelayFillTimeout() {
if (delayFillTimeout) {
window.clearTimeout(delayFillTimeout);
}
}
function setupExtensionEventListeners() {
setupExtensionDisconnectAction(handleExtensionDisconnect);
chrome.runtime.onMessage.addListener(handleExtensionMessage);
}
}

View File

@@ -1,4 +1,5 @@
import AutofillOverlayContentService from "../services/autofill-overlay-content.service";
import { setupAutofillInitDisconnectAction } from "../utils";
import AutofillInit from "./autofill-init";
@@ -6,6 +7,8 @@ import AutofillInit from "./autofill-init";
if (!windowContext.bitwardenAutofillInit) {
const autofillOverlayContentService = new AutofillOverlayContentService();
windowContext.bitwardenAutofillInit = new AutofillInit(autofillOverlayContentService);
setupAutofillInitDisconnectAction(windowContext);
windowContext.bitwardenAutofillInit.init();
}
})(window);

View File

@@ -1,8 +1,12 @@
import { setupAutofillInitDisconnectAction } from "../utils";
import AutofillInit from "./autofill-init";
(function (windowContext) {
if (!windowContext.bitwardenAutofillInit) {
windowContext.bitwardenAutofillInit = new AutofillInit();
setupAutofillInitDisconnectAction(windowContext);
windowContext.bitwardenAutofillInit.init();
}
})(window);

View File

@@ -1,6 +1,19 @@
window.addEventListener(
"message",
(event) => {
import { setupExtensionDisconnectAction } from "../utils";
const forwardCommands = [
"bgUnlockPopoutOpened",
"addToLockedVaultPendingNotifications",
"unlockCompleted",
"addedCipher",
];
/**
* Handles sending extension messages to the background
* script based on window messages from the page.
*
* @param event - Window message event
*/
const handleWindowMessage = (event: MessageEvent) => {
if (event.source !== window) {
return;
}
@@ -23,19 +36,30 @@ window.addEventListener(
referrer: event.source.location.hostname,
});
}
},
false,
);
};
const forwardCommands = [
"bgUnlockPopoutOpened",
"addToLockedVaultPendingNotifications",
"unlockCompleted",
"addedCipher",
];
chrome.runtime.onMessage.addListener((event) => {
if (forwardCommands.includes(event.command)) {
chrome.runtime.sendMessage(event);
/**
* Handles forwarding any commands that need to trigger
* an action from one service of the extension background
* to another.
*
* @param message - Message from the extension
*/
const handleExtensionMessage = (message: any) => {
if (forwardCommands.includes(message.command)) {
chrome.runtime.sendMessage(message);
}
});
};
/**
* Handles cleaning up any event listeners that were
* added to the window or extension.
*/
const handleExtensionDisconnect = () => {
window.removeEventListener("message", handleWindowMessage);
chrome.runtime.onMessage.removeListener(handleExtensionMessage);
};
window.addEventListener("message", handleWindowMessage, false);
chrome.runtime.onMessage.addListener(handleExtensionMessage);
setupExtensionDisconnectAction(handleExtensionDisconnect);

View File

@@ -4,6 +4,7 @@ import AddLoginRuntimeMessage from "../notification/models/add-login-runtime-mes
import ChangePasswordRuntimeMessage from "../notification/models/change-password-runtime-message";
import { FormData } from "../services/abstractions/autofill.service";
import { GlobalSettings, UserSettings } from "../types";
import { getFromLocalStorage, setupExtensionDisconnectAction } from "../utils";
interface HTMLElementWithFormOpId extends HTMLElement {
formOpId: string;
@@ -122,6 +123,8 @@ async function loadNotificationBar() {
}
}
setupExtensionDisconnectAction(handleExtensionDisconnection);
if (!showNotificationBar) {
return;
}
@@ -999,11 +1002,23 @@ async function loadNotificationBar() {
return theEl === document;
}
function handleExtensionDisconnection(port: chrome.runtime.Port) {
closeBar(false);
clearTimeout(domObservationCollectTimeoutId);
clearTimeout(collectPageDetailsTimeoutId);
clearTimeout(handlePageChangeTimeoutId);
observer?.disconnect();
observer = null;
watchedForms.forEach((wf: WatchedForm) => {
const form = wf.formEl;
form.removeEventListener("submit", formSubmitted, false);
const submitButton = getSubmitButton(
form,
unionSets(logInButtonNames, changePasswordButtonNames),
);
submitButton?.removeEventListener("click", formSubmitted, false);
});
}
// End Helper Functions
}
async function getFromLocalStorage(keys: string | string[]): Promise<Record<string, any>> {
return new Promise((resolve) => {
chrome.storage.local.get(keys, (storage: Record<string, any>) => resolve(storage));
});
}

View File

@@ -0,0 +1,5 @@
const AutofillPort = {
InjectedScript: "autofill-injected-script-port",
} as const;
export { AutofillPort };

View File

@@ -1,5 +1,5 @@
import { EVENTS } from "../../constants";
import { setElementStyles } from "../../utils/utils";
import { setElementStyles } from "../../utils";
import {
BackgroundPortMessageHandlers,
AutofillOverlayIframeService as AutofillOverlayIframeServiceInterface,
@@ -166,9 +166,10 @@ class AutofillOverlayIframeService implements AutofillOverlayIframeServiceInterf
this.updateElementStyles(this.iframe, { opacity: "0", height: "0px", display: "block" });
globalThis.removeEventListener("message", this.handleWindowMessage);
this.port.onMessage.removeListener(this.handlePortMessage);
this.port.onDisconnect.removeListener(this.handlePortDisconnect);
this.port.disconnect();
this.unobserveIframe();
this.port?.onMessage.removeListener(this.handlePortMessage);
this.port?.onDisconnect.removeListener(this.handlePortDisconnect);
this.port?.disconnect();
this.port = null;
};
@@ -369,7 +370,7 @@ class AutofillOverlayIframeService implements AutofillOverlayIframeServiceInterf
* Unobserves the iframe element for mutations to its style attribute.
*/
private unobserveIframe() {
this.iframeMutationObserver.disconnect();
this.iframeMutationObserver?.disconnect();
}
/**

View File

@@ -3,8 +3,8 @@ import "lit/polyfill-support.js";
import { AuthenticationStatus } from "@bitwarden/common/auth/enums/authentication-status";
import { EVENTS } from "../../../constants";
import { buildSvgDomElement } from "../../../utils";
import { logoIcon, logoLockedIcon } from "../../../utils/svg-icons";
import { buildSvgDomElement } from "../../../utils/utils";
import {
InitAutofillOverlayButtonMessage,
OverlayButtonWindowMessageHandlers,

View File

@@ -4,8 +4,8 @@ import { AuthenticationStatus } from "@bitwarden/common/auth/enums/authenticatio
import { OverlayCipherData } from "../../../background/abstractions/overlay.background";
import { EVENTS } from "../../../constants";
import { buildSvgDomElement } from "../../../utils";
import { globeIcon, lockIcon, plusIcon, viewCipherIcon } from "../../../utils/svg-icons";
import { buildSvgDomElement } from "../../../utils/utils";
import {
InitAutofillOverlayListMessage,
OverlayListWindowMessageHandlers,

View File

@@ -10,6 +10,7 @@ import { UriMatchType } from "@bitwarden/common/vault/enums";
import { BrowserApi } from "../../../platform/browser/browser-api";
import { flagEnabled } from "../../../platform/flags";
import { AutofillService } from "../../services/abstractions/autofill.service";
import { AutofillOverlayVisibility } from "../../utils/autofill-overlay.enum";
@Component({
@@ -35,6 +36,7 @@ export class AutofillComponent implements OnInit {
private platformUtilsService: PlatformUtilsService,
private configService: ConfigServiceAbstraction,
private settingsService: SettingsService,
private autofillService: AutofillService,
) {
this.autoFillOverlayVisibilityOptions = [
{
@@ -86,7 +88,10 @@ export class AutofillComponent implements OnInit {
}
async updateAutoFillOverlayVisibility() {
const previousAutoFillOverlayVisibility =
await this.settingsService.getAutoFillOverlayVisibility();
await this.settingsService.setAutoFillOverlayVisibility(this.autoFillOverlayVisibility);
await this.handleUpdatingAutofillOverlayContentScripts(previousAutoFillOverlayVisibility);
}
async updateAutoFillOnPageLoad() {
@@ -144,4 +149,25 @@ export class AutofillComponent implements OnInit {
event.preventDefault();
BrowserApi.createNewTab(this.disablePasswordManagerLink);
}
private async handleUpdatingAutofillOverlayContentScripts(
previousAutoFillOverlayVisibility: number,
) {
const autofillOverlayPreviouslyDisabled =
previousAutoFillOverlayVisibility === AutofillOverlayVisibility.Off;
const autofillOverlayCurrentlyDisabled =
this.autoFillOverlayVisibility === AutofillOverlayVisibility.Off;
if (!autofillOverlayPreviouslyDisabled && !autofillOverlayCurrentlyDisabled) {
const tabs = await BrowserApi.tabsQuery({});
tabs.forEach((tab) =>
BrowserApi.tabSendMessageData(tab, "updateAutofillOverlayVisibility", {
autofillOverlayVisibility: this.autoFillOverlayVisibility,
}),
);
return;
}
await this.autofillService.reloadAutofillScripts();
}
}

View File

@@ -14,6 +14,7 @@ interface AutofillOverlayContentService {
isCurrentlyFilling: boolean;
isOverlayCiphersPopulated: boolean;
pageDetailsUpdateRequired: boolean;
autofillOverlayVisibility: number;
init(): void;
setupAutofillOverlayListenerOnField(
autofillFieldElement: ElementWithOpId<FormFieldElement>,
@@ -27,6 +28,7 @@ interface AutofillOverlayContentService {
redirectOverlayFocusOut(direction: "previous" | "next"): void;
focusMostRecentOverlayField(): void;
blurMostRecentOverlayField(): void;
destroy(): void;
}
export { OpenAutofillOverlayOptions, AutofillOverlayContentService };

View File

@@ -44,10 +44,12 @@ export interface GenerateFillScriptOptions {
}
export abstract class AutofillService {
loadAutofillScriptsOnInstall: () => Promise<void>;
reloadAutofillScripts: () => Promise<void>;
injectAutofillScripts: (
sender: chrome.runtime.MessageSender,
autofillV2?: boolean,
autofillOverlay?: boolean,
tab: chrome.tabs.Tab,
frameId?: number,
triggeringOnPageLoad?: boolean,
) => Promise<void>;
getFormsWithPasswordFields: (pageDetails: AutofillPageDetails) => FormData[];
doAutoFill: (options: AutoFillOptions) => Promise<string | null>;

View File

@@ -22,6 +22,7 @@ interface CollectAutofillContentService {
filterCallback: CallableFunction,
isObservingShadowRoot?: boolean,
): Node[];
destroy(): void;
}
export {

View File

@@ -1609,4 +1609,100 @@ describe("AutofillOverlayContentService", () => {
expect(autofillOverlayContentService["removeAutofillOverlay"]).toHaveBeenCalled();
});
});
describe("destroy", () => {
let autofillFieldElement: ElementWithOpId<FormFieldElement>;
let autofillFieldData: AutofillField;
beforeEach(() => {
document.body.innerHTML = `
<form id="validFormId">
<input type="text" id="username-field" placeholder="username" />
<input type="password" id="password-field" placeholder="password" />
</form>
`;
autofillFieldElement = document.getElementById(
"username-field",
) as ElementWithOpId<FormFieldElement>;
autofillFieldElement.opid = "op-1";
autofillFieldData = createAutofillFieldMock({
opid: "username-field",
form: "validFormId",
placeholder: "username",
elementNumber: 1,
});
autofillOverlayContentService.setupAutofillOverlayListenerOnField(
autofillFieldElement,
autofillFieldData,
);
autofillOverlayContentService["mostRecentlyFocusedField"] = autofillFieldElement;
});
it("disconnects all mutation observers", () => {
autofillOverlayContentService["setupMutationObserver"]();
jest.spyOn(autofillOverlayContentService["bodyElementMutationObserver"], "disconnect");
jest.spyOn(autofillOverlayContentService["documentElementMutationObserver"], "disconnect");
autofillOverlayContentService.destroy();
expect(
autofillOverlayContentService["documentElementMutationObserver"].disconnect,
).toHaveBeenCalled();
expect(
autofillOverlayContentService["bodyElementMutationObserver"].disconnect,
).toHaveBeenCalled();
});
it("clears the user interaction event timeout", () => {
jest.spyOn(autofillOverlayContentService as any, "clearUserInteractionEventTimeout");
autofillOverlayContentService.destroy();
expect(autofillOverlayContentService["clearUserInteractionEventTimeout"]).toHaveBeenCalled();
});
it("de-registers all global event listeners", () => {
jest.spyOn(globalThis.document, "removeEventListener");
jest.spyOn(globalThis, "removeEventListener");
jest.spyOn(autofillOverlayContentService as any, "removeOverlayRepositionEventListeners");
autofillOverlayContentService.destroy();
expect(globalThis.document.removeEventListener).toHaveBeenCalledWith(
EVENTS.VISIBILITYCHANGE,
autofillOverlayContentService["handleVisibilityChangeEvent"],
);
expect(globalThis.removeEventListener).toHaveBeenCalledWith(
EVENTS.FOCUSOUT,
autofillOverlayContentService["handleFormFieldBlurEvent"],
);
expect(
autofillOverlayContentService["removeOverlayRepositionEventListeners"],
).toHaveBeenCalled();
});
it("de-registers any event listeners that are attached to the form field elements", () => {
jest.spyOn(autofillOverlayContentService as any, "removeCachedFormFieldEventListeners");
jest.spyOn(autofillFieldElement, "removeEventListener");
jest.spyOn(autofillOverlayContentService["formFieldElements"], "delete");
autofillOverlayContentService.destroy();
expect(
autofillOverlayContentService["removeCachedFormFieldEventListeners"],
).toHaveBeenCalledWith(autofillFieldElement);
expect(autofillFieldElement.removeEventListener).toHaveBeenCalledWith(
EVENTS.BLUR,
autofillOverlayContentService["handleFormFieldBlurEvent"],
);
expect(autofillFieldElement.removeEventListener).toHaveBeenCalledWith(
EVENTS.KEYUP,
autofillOverlayContentService["handleFormFieldKeyupEvent"],
);
expect(autofillOverlayContentService["formFieldElements"].delete).toHaveBeenCalledWith(
autofillFieldElement,
);
});
});
});

View File

@@ -10,16 +10,12 @@ import AutofillField from "../models/autofill-field";
import AutofillOverlayButtonIframe from "../overlay/iframe-content/autofill-overlay-button-iframe";
import AutofillOverlayListIframe from "../overlay/iframe-content/autofill-overlay-list-iframe";
import { ElementWithOpId, FillableFormFieldElement, FormFieldElement } from "../types";
import { generateRandomCustomElementName, sendExtensionMessage, setElementStyles } from "../utils";
import {
AutofillOverlayElement,
RedirectFocusDirection,
AutofillOverlayVisibility,
} from "../utils/autofill-overlay.enum";
import {
generateRandomCustomElementName,
sendExtensionMessage,
setElementStyles,
} from "../utils/utils";
import {
AutofillOverlayContentService as AutofillOverlayContentServiceInterface,
@@ -32,9 +28,10 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte
isCurrentlyFilling = false;
isOverlayCiphersPopulated = false;
pageDetailsUpdateRequired = false;
autofillOverlayVisibility: number;
private readonly findTabs = tabbable;
private readonly sendExtensionMessage = sendExtensionMessage;
private autofillOverlayVisibility: number;
private formFieldElements: Set<ElementWithOpId<FormFieldElement>> = new Set([]);
private userFilledFields: Record<string, FillableFormFieldElement> = {};
private authStatus: AuthenticationStatus;
private focusableElements: FocusableElement[] = [];
@@ -47,6 +44,7 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte
private userInteractionEventTimeout: NodeJS.Timeout;
private overlayElementsMutationObserver: MutationObserver;
private bodyElementMutationObserver: MutationObserver;
private documentElementMutationObserver: MutationObserver;
private mutationObserverIterations = 0;
private mutationObserverIterationsResetTimeout: NodeJS.Timeout;
private autofillFieldKeywordsMap: WeakMap<AutofillField, string> = new WeakMap();
@@ -86,6 +84,8 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte
return;
}
this.formFieldElements.add(formFieldElement);
if (!this.autofillOverlayVisibility) {
await this.getAutofillOverlayVisibility();
}
@@ -901,10 +901,10 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte
this.handleBodyElementMutationObserverUpdate,
);
const documentElementMutationObserver = new MutationObserver(
this.documentElementMutationObserver = new MutationObserver(
this.handleDocumentElementMutationObserverUpdate,
);
documentElementMutationObserver.observe(globalThis.document.documentElement, {
this.documentElementMutationObserver.observe(globalThis.document.documentElement, {
childList: true,
});
};
@@ -1117,6 +1117,28 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte
const documentRoot = element.getRootNode() as ShadowRoot | Document;
return documentRoot?.activeElement;
}
/**
* Destroys the autofill overlay content service. This method will
* disconnect the mutation observers and remove all event listeners.
*/
destroy() {
this.documentElementMutationObserver?.disconnect();
this.clearUserInteractionEventTimeout();
this.formFieldElements.forEach((formFieldElement) => {
this.removeCachedFormFieldEventListeners(formFieldElement);
formFieldElement.removeEventListener(EVENTS.BLUR, this.handleFormFieldBlurEvent);
formFieldElement.removeEventListener(EVENTS.KEYUP, this.handleFormFieldKeyupEvent);
this.formFieldElements.delete(formFieldElement);
});
globalThis.document.removeEventListener(
EVENTS.VISIBILITYCHANGE,
this.handleVisibilityChangeEvent,
);
globalThis.removeEventListener(EVENTS.FOCUSOUT, this.handleFormFieldBlurEvent);
this.removeAutofillOverlay();
this.removeOverlayRepositionEventListeners();
}
}
export default AutofillOverlayContentService;

View File

@@ -2,7 +2,9 @@ import { mock, mockReset } from "jest-mock-extended";
import { UserVerificationService } from "@bitwarden/common/auth/services/user-verification/user-verification.service";
import { EventType } from "@bitwarden/common/enums";
import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
import { ConfigService } from "@bitwarden/common/platform/services/config/config.service";
import { EventCollectionService } from "@bitwarden/common/services/event/event-collection.service";
import { SettingsService } from "@bitwarden/common/services/settings.service";
import {
@@ -24,6 +26,7 @@ import { TotpService } from "@bitwarden/common/vault/services/totp.service";
import { BrowserApi } from "../../platform/browser/browser-api";
import { BrowserStateService } from "../../platform/services/browser-state.service";
import { AutofillPort } from "../enums/autofill-port.enums";
import {
createAutofillFieldMock,
createAutofillPageDetailsMock,
@@ -54,6 +57,7 @@ describe("AutofillService", () => {
const logService = mock<LogService>();
const settingsService = mock<SettingsService>();
const userVerificationService = mock<UserVerificationService>();
const configService = mock<ConfigService>();
beforeEach(() => {
autofillService = new AutofillService(
@@ -64,6 +68,7 @@ describe("AutofillService", () => {
logService,
settingsService,
userVerificationService,
configService,
);
});
@@ -72,6 +77,72 @@ describe("AutofillService", () => {
mockReset(cipherService);
});
describe("loadAutofillScriptsOnInstall", () => {
let tab1: chrome.tabs.Tab;
let tab2: chrome.tabs.Tab;
let tab3: chrome.tabs.Tab;
beforeEach(() => {
tab1 = createChromeTabMock({ id: 1, url: "https://some-url.com" });
tab2 = createChromeTabMock({ id: 2, url: "http://some-url.com" });
tab3 = createChromeTabMock({ id: 3, url: "chrome-extension://some-extension-route" });
jest.spyOn(BrowserApi, "tabsQuery").mockResolvedValueOnce([tab1, tab2]);
});
it("queries all browser tabs and injects the autofill scripts into them", async () => {
jest.spyOn(autofillService, "injectAutofillScripts");
await autofillService.loadAutofillScriptsOnInstall();
expect(BrowserApi.tabsQuery).toHaveBeenCalledWith({});
expect(autofillService.injectAutofillScripts).toHaveBeenCalledWith(tab1, 0, false);
expect(autofillService.injectAutofillScripts).toHaveBeenCalledWith(tab2, 0, false);
});
it("skips injecting scripts into tabs that do not have an http(s) protocol", async () => {
jest.spyOn(autofillService, "injectAutofillScripts");
await autofillService.loadAutofillScriptsOnInstall();
expect(BrowserApi.tabsQuery).toHaveBeenCalledWith({});
expect(autofillService.injectAutofillScripts).not.toHaveBeenCalledWith(tab3);
});
it("sets up an extension runtime onConnect listener", async () => {
await autofillService.loadAutofillScriptsOnInstall();
// eslint-disable-next-line no-restricted-syntax
expect(chrome.runtime.onConnect.addListener).toHaveBeenCalledWith(expect.any(Function));
});
});
describe("reloadAutofillScripts", () => {
it("disconnects and removes all autofill script ports", () => {
const port1 = mock<chrome.runtime.Port>({
disconnect: jest.fn(),
});
const port2 = mock<chrome.runtime.Port>({
disconnect: jest.fn(),
});
autofillService["autofillScriptPortsSet"] = new Set([port1, port2]);
autofillService.reloadAutofillScripts();
expect(port1.disconnect).toHaveBeenCalled();
expect(port2.disconnect).toHaveBeenCalled();
expect(autofillService["autofillScriptPortsSet"].size).toBe(0);
});
it("re-injects the autofill scripts in all tabs", () => {
autofillService["autofillScriptPortsSet"] = new Set([mock<chrome.runtime.Port>()]);
jest.spyOn(autofillService as any, "injectAutofillScriptsInAllTabs");
autofillService.reloadAutofillScripts();
expect(autofillService["injectAutofillScriptsInAllTabs"]).toHaveBeenCalled();
});
});
describe("injectAutofillScripts", () => {
const autofillV1Script = "autofill.js";
const autofillV2BootstrapScript = "bootstrap-autofill.js";
@@ -83,12 +154,12 @@ describe("AutofillService", () => {
beforeEach(() => {
tabMock = createChromeTabMock();
sender = { tab: tabMock };
sender = { tab: tabMock, frameId: 1 };
jest.spyOn(BrowserApi, "executeScriptInTab").mockImplementation();
});
it("accepts an extension message sender and injects the autofill scripts into the tab of the sender", async () => {
await autofillService.injectAutofillScripts(sender);
await autofillService.injectAutofillScripts(sender.tab, sender.frameId, true);
[autofillV1Script, ...defaultAutofillScripts].forEach((scriptName) => {
expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, {
@@ -105,7 +176,11 @@ describe("AutofillService", () => {
});
it("will inject the bootstrap-autofill script if the enableAutofillV2 flag is set", async () => {
await autofillService.injectAutofillScripts(sender, true);
jest
.spyOn(configService, "getFeatureFlag")
.mockImplementation((flag) => Promise.resolve(flag === FeatureFlag.AutofillV2));
await autofillService.injectAutofillScripts(sender.tab, sender.frameId);
expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, {
file: `content/${autofillV2BootstrapScript}`,
@@ -120,11 +195,16 @@ describe("AutofillService", () => {
});
it("will inject the bootstrap-autofill-overlay script if the enableAutofillOverlay flag is set and the user has the autofill overlay enabled", async () => {
jest
.spyOn(configService, "getFeatureFlag")
.mockImplementation((flag) =>
Promise.resolve(flag === FeatureFlag.AutofillOverlay || flag === FeatureFlag.AutofillV2),
);
jest
.spyOn(autofillService["settingsService"], "getAutoFillOverlayVisibility")
.mockResolvedValue(AutofillOverlayVisibility.OnFieldFocus);
await autofillService.injectAutofillScripts(sender, true, true);
await autofillService.injectAutofillScripts(sender.tab, sender.frameId);
expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, {
file: `content/${autofillOverlayBootstrapScript}`,
@@ -144,18 +224,25 @@ describe("AutofillService", () => {
});
it("will inject the bootstrap-autofill script if the enableAutofillOverlay flag is set but the user does not have the autofill overlay enabled", async () => {
jest
.spyOn(configService, "getFeatureFlag")
.mockImplementation((flag) =>
Promise.resolve(flag === FeatureFlag.AutofillOverlay || flag === FeatureFlag.AutofillV2),
);
jest
.spyOn(autofillService["settingsService"], "getAutoFillOverlayVisibility")
.mockResolvedValue(AutofillOverlayVisibility.Off);
await autofillService.injectAutofillScripts(sender, true, true);
await autofillService.injectAutofillScripts(sender.tab, sender.frameId);
expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, {
file: `content/${autofillV2BootstrapScript}`,
frameId: sender.frameId,
...defaultExecuteScriptOptions,
});
expect(BrowserApi.executeScriptInTab).not.toHaveBeenCalledWith(tabMock.id, {
file: `content/${autofillV1Script}`,
frameId: sender.frameId,
...defaultExecuteScriptOptions,
});
});
@@ -4436,4 +4523,58 @@ describe("AutofillService", () => {
expect(autofillService["currentlyOpeningPasswordRepromptPopout"]).toBe(false);
});
});
describe("handleInjectedScriptPortConnection", () => {
it("ignores port connections that do not have the correct port name", () => {
const port = mock<chrome.runtime.Port>({
name: "some-invalid-port-name",
onDisconnect: { addListener: jest.fn() },
}) as any;
autofillService["handleInjectedScriptPortConnection"](port);
expect(port.onDisconnect.addListener).not.toHaveBeenCalled();
expect(autofillService["autofillScriptPortsSet"].size).toBe(0);
});
it("adds the connect port to the set of injected script ports and sets up an onDisconnect listener", () => {
const port = mock<chrome.runtime.Port>({
name: AutofillPort.InjectedScript,
onDisconnect: { addListener: jest.fn() },
}) as any;
jest.spyOn(autofillService as any, "handleInjectScriptPortOnDisconnect");
autofillService["handleInjectedScriptPortConnection"](port);
expect(port.onDisconnect.addListener).toHaveBeenCalledWith(
autofillService["handleInjectScriptPortOnDisconnect"],
);
expect(autofillService["autofillScriptPortsSet"].size).toBe(1);
});
});
describe("handleInjectScriptPortOnDisconnect", () => {
it("ignores port disconnections that do not have the correct port name", () => {
autofillService["autofillScriptPortsSet"].add(mock<chrome.runtime.Port>());
autofillService["handleInjectScriptPortOnDisconnect"](
mock<chrome.runtime.Port>({
name: "some-invalid-port-name",
}),
);
expect(autofillService["autofillScriptPortsSet"].size).toBe(1);
});
it("removes the port from the set of injected script ports", () => {
const port = mock<chrome.runtime.Port>({
name: AutofillPort.InjectedScript,
}) as any;
autofillService["autofillScriptPortsSet"].add(port);
autofillService["handleInjectScriptPortOnDisconnect"](port);
expect(autofillService["autofillScriptPortsSet"].size).toBe(0);
});
});
});

View File

@@ -2,6 +2,8 @@ import { EventCollectionService } from "@bitwarden/common/abstractions/event/eve
import { SettingsService } from "@bitwarden/common/abstractions/settings.service";
import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction";
import { EventType } from "@bitwarden/common/enums";
import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum";
import { ConfigServiceAbstraction } from "@bitwarden/common/platform/abstractions/config/config.service.abstraction";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
import { CipherService } from "@bitwarden/common/vault/abstractions/cipher.service";
import { TotpService } from "@bitwarden/common/vault/abstractions/totp.service";
@@ -13,6 +15,7 @@ import { FieldView } from "@bitwarden/common/vault/models/view/field.view";
import { BrowserApi } from "../../platform/browser/browser-api";
import { BrowserStateService } from "../../platform/services/abstractions/browser-state.service";
import { openVaultItemPasswordRepromptPopout } from "../../vault/popup/utils/vault-popout-window";
import { AutofillPort } from "../enums/autofill-port.enums";
import AutofillField from "../models/autofill-field";
import AutofillPageDetails from "../models/autofill-page-details";
import AutofillScript from "../models/autofill-script";
@@ -35,6 +38,7 @@ export default class AutofillService implements AutofillServiceInterface {
private openVaultItemPasswordRepromptPopout = openVaultItemPasswordRepromptPopout;
private openPasswordRepromptPopoutDebounce: NodeJS.Timeout;
private currentlyOpeningPasswordRepromptPopout = false;
private autofillScriptPortsSet = new Set<chrome.runtime.Port>();
constructor(
private cipherService: CipherService,
@@ -44,23 +48,54 @@ export default class AutofillService implements AutofillServiceInterface {
private logService: LogService,
private settingsService: SettingsService,
private userVerificationService: UserVerificationService,
private configService: ConfigServiceAbstraction,
) {}
/**
* Triggers on installation of the extension Handles injecting
* content scripts into all tabs that are currently open, and
* sets up a listener to ensure content scripts can identify
* if the extension context has been disconnected.
*/
async loadAutofillScriptsOnInstall() {
BrowserApi.addListener(chrome.runtime.onConnect, this.handleInjectedScriptPortConnection);
this.injectAutofillScriptsInAllTabs();
}
/**
* Triggers a complete reload of all autofill scripts on tabs open within
* the user's browsing session. This is done by first disconnecting all
* existing autofill content script ports, which cleans up existing object
* instances, and then re-injecting the autofill scripts into all tabs.
*/
async reloadAutofillScripts() {
this.autofillScriptPortsSet.forEach((port) => {
port.disconnect();
this.autofillScriptPortsSet.delete(port);
});
this.injectAutofillScriptsInAllTabs();
}
/**
* Injects the autofill scripts into the current tab and all frames
* found within the tab. Temporarily, will conditionally inject
* the refactor of the core autofill script if the feature flag
* is enabled.
* @param {chrome.runtime.MessageSender} sender
* @param {boolean} autofillV2
* @param {boolean} autofillOverlay
* @returns {Promise<void>}
* @param {chrome.tabs.Tab} tab
* @param {number} frameId
* @param {boolean} triggeringOnPageLoad
*/
async injectAutofillScripts(
sender: chrome.runtime.MessageSender,
autofillV2 = false,
autofillOverlay = false,
) {
tab: chrome.tabs.Tab,
frameId = 0,
triggeringOnPageLoad = true,
): Promise<void> {
const autofillV2 = await this.configService.getFeatureFlag<boolean>(FeatureFlag.AutofillV2);
const autofillOverlay = await this.configService.getFeatureFlag<boolean>(
FeatureFlag.AutofillOverlay,
);
let mainAutofillScript = "autofill.js";
const isUsingAutofillOverlay =
@@ -73,20 +108,24 @@ export default class AutofillService implements AutofillServiceInterface {
: "bootstrap-autofill.js";
}
const injectedScripts = [
mainAutofillScript,
"autofiller.js",
"notificationBar.js",
"contextMenuHandler.js",
];
const injectedScripts = [mainAutofillScript];
if (triggeringOnPageLoad) {
injectedScripts.push("autofiller.js");
}
injectedScripts.push("notificationBar.js", "contextMenuHandler.js");
for (const injectedScript of injectedScripts) {
await BrowserApi.executeScriptInTab(sender.tab.id, {
await BrowserApi.executeScriptInTab(tab.id, {
file: `content/${injectedScript}`,
frameId: sender.frameId,
frameId,
runAt: "document_start",
});
}
await BrowserApi.executeScriptInTab(tab.id, {
file: "content/message_handler.js",
runAt: "document_start",
});
}
/**
@@ -1877,4 +1916,47 @@ export default class AutofillService implements AutofillServiceInterface {
return false;
}
/**
* Handles incoming long-lived connections from injected autofill scripts.
* Stores the port in a set to facilitate disconnecting ports if the extension
* needs to re-inject the autofill scripts.
*
* @param port - The port that was connected
*/
private handleInjectedScriptPortConnection = (port: chrome.runtime.Port) => {
if (port.name !== AutofillPort.InjectedScript) {
return;
}
this.autofillScriptPortsSet.add(port);
port.onDisconnect.addListener(this.handleInjectScriptPortOnDisconnect);
};
/**
* Handles disconnecting ports that relate to injected autofill scripts.
* @param port - The port that was disconnected
*/
private handleInjectScriptPortOnDisconnect = (port: chrome.runtime.Port) => {
if (port.name !== AutofillPort.InjectedScript) {
return;
}
this.autofillScriptPortsSet.delete(port);
};
/**
* Queries all open tabs in the user's browsing session
* and injects the autofill scripts into the page.
*/
private async injectAutofillScriptsInAllTabs() {
const tabs = await BrowserApi.tabsQuery({});
for (let index = 0; index < tabs.length; index++) {
const tab = tabs[index];
if (tab.url?.startsWith("http")) {
this.injectAutofillScripts(tab, 0, false);
}
}
}
}

View File

@@ -1249,6 +1249,17 @@ class CollectAutofillContentService implements CollectAutofillContentServiceInte
return attributeValue;
}
/**
* Destroys the CollectAutofillContentService. Clears all
* timeouts and disconnects the mutation observer.
*/
destroy() {
if (this.updateAutofillElementsAfterMutationTimeout) {
clearTimeout(this.updateAutofillElementsAfterMutationTimeout);
}
this.mutationObserver?.disconnect();
}
}
export default CollectAutofillContentService;

View File

@@ -1,10 +1,17 @@
import { AutofillPort } from "../enums/autofill-port.enums";
import { triggerPortOnDisconnectEvent } from "../jest/testing-utils";
import { logoIcon, logoLockedIcon } from "./svg-icons";
import {
buildSvgDomElement,
generateRandomCustomElementName,
sendExtensionMessage,
setElementStyles,
} from "./utils";
getFromLocalStorage,
setupExtensionDisconnectAction,
setupAutofillInitDisconnectAction,
} from "./index";
describe("buildSvgDomElement", () => {
it("returns an SVG DOM element", () => {
@@ -116,3 +123,107 @@ describe("setElementStyles", () => {
expect(testDiv.style.cssText).toEqual(expectedCSSRuleString);
});
});
describe("getFromLocalStorage", () => {
it("returns a promise with the storage object pulled from the extension storage api", async () => {
const localStorage: Record<string, any> = {
testValue: "test",
another: "another",
};
jest.spyOn(chrome.storage.local, "get").mockImplementation((keys, callback) => {
const localStorageObject: Record<string, string> = {};
if (typeof keys === "string") {
localStorageObject[keys] = localStorage[keys];
} else if (Array.isArray(keys)) {
for (const key of keys) {
localStorageObject[key] = localStorage[key];
}
}
callback(localStorageObject);
});
const returnValue = await getFromLocalStorage("testValue");
expect(chrome.storage.local.get).toHaveBeenCalled();
expect(returnValue).toEqual({ testValue: "test" });
});
});
describe("setupExtensionDisconnectAction", () => {
afterEach(() => {
jest.clearAllMocks();
});
it("connects a port to the extension background and sets up an onDisconnect listener", () => {
const onDisconnectCallback = jest.fn();
let port: chrome.runtime.Port;
jest.spyOn(chrome.runtime, "connect").mockImplementation(() => {
port = {
onDisconnect: {
addListener: onDisconnectCallback,
removeListener: jest.fn(),
},
} as unknown as chrome.runtime.Port;
return port;
});
setupExtensionDisconnectAction(onDisconnectCallback);
expect(chrome.runtime.connect).toHaveBeenCalledWith({
name: AutofillPort.InjectedScript,
});
expect(port.onDisconnect.addListener).toHaveBeenCalledWith(expect.any(Function));
});
});
describe("setupAutofillInitDisconnectAction", () => {
afterEach(() => {
jest.clearAllMocks();
});
it("skips setting up the extension disconnect action if the bitwardenAutofillInit object is not populated", () => {
const onDisconnectCallback = jest.fn();
window.bitwardenAutofillInit = undefined;
const portConnectSpy = jest.spyOn(chrome.runtime, "connect").mockImplementation(() => {
return {
onDisconnect: {
addListener: onDisconnectCallback,
removeListener: jest.fn(),
},
} as unknown as chrome.runtime.Port;
});
setupAutofillInitDisconnectAction(window);
expect(portConnectSpy).not.toHaveBeenCalled();
});
it("destroys the autofill init instance when the port is disconnected", () => {
let port: chrome.runtime.Port;
const autofillInitDestroy: CallableFunction = jest.fn();
window.bitwardenAutofillInit = {
destroy: autofillInitDestroy,
} as any;
jest.spyOn(chrome.runtime, "connect").mockImplementation(() => {
port = {
onDisconnect: {
addListener: jest.fn(),
removeListener: jest.fn(),
},
} as unknown as chrome.runtime.Port;
return port;
});
setupAutofillInitDisconnectAction(window);
triggerPortOnDisconnectEvent(port as chrome.runtime.Port);
expect(chrome.runtime.connect).toHaveBeenCalled();
expect(port.onDisconnect.addListener).toHaveBeenCalled();
expect(autofillInitDestroy).toHaveBeenCalled();
expect(window.bitwardenAutofillInit).toBeUndefined();
});
});

View File

@@ -1,3 +1,5 @@
import { AutofillPort } from "../enums/autofill-port.enums";
/**
* Generates a random string of characters that formatted as a custom element name.
*/
@@ -103,9 +105,57 @@ function setElementStyles(
}
}
/**
* Get data from local storage based on the keys provided.
*
* @param keys - String or array of strings of keys to get from local storage
*/
async function getFromLocalStorage(keys: string | string[]): Promise<Record<string, any>> {
return new Promise((resolve) => {
chrome.storage.local.get(keys, (storage: Record<string, any>) => resolve(storage));
});
}
/**
* Sets up a long-lived connection with the extension background
* and triggers an onDisconnect event if the extension context
* is invalidated.
*
* @param callback - Callback function to run when the extension disconnects
*/
function setupExtensionDisconnectAction(callback: (port: chrome.runtime.Port) => void) {
const port = chrome.runtime.connect({ name: AutofillPort.InjectedScript });
const onDisconnectCallback = (disconnectedPort: chrome.runtime.Port) => {
callback(disconnectedPort);
port.onDisconnect.removeListener(onDisconnectCallback);
};
port.onDisconnect.addListener(onDisconnectCallback);
}
/**
* Handles setup of the extension disconnect action for the autofill init class
* in both instances where the overlay might or might not be initialized.
*
* @param windowContext - The global window context
*/
function setupAutofillInitDisconnectAction(windowContext: Window) {
if (!windowContext.bitwardenAutofillInit) {
return;
}
const onDisconnectCallback = () => {
windowContext.bitwardenAutofillInit.destroy();
delete windowContext.bitwardenAutofillInit;
};
setupExtensionDisconnectAction(onDisconnectCallback);
}
export {
generateRandomCustomElementName,
buildSvgDomElement,
sendExtensionMessage,
setElementStyles,
getFromLocalStorage,
setupExtensionDisconnectAction,
setupAutofillInitDisconnectAction,
};

View File

@@ -1,5 +1,8 @@
import { firstValueFrom } from "rxjs";
import { NotificationsService } from "@bitwarden/common/abstractions/notifications.service";
import { VaultTimeoutService } from "@bitwarden/common/abstractions/vault-timeout/vault-timeout.service";
import { AccountService } from "@bitwarden/common/auth/abstractions/account.service";
import { VaultTimeoutAction } from "@bitwarden/common/enums/vault-timeout-action.enum";
import { BrowserStateService } from "../platform/services/abstractions/browser-state.service";
@@ -7,7 +10,7 @@ import { BrowserStateService } from "../platform/services/abstractions/browser-s
const IdleInterval = 60 * 5; // 5 minutes
export default class IdleBackground {
private idle: any;
private idle: typeof chrome.idle | typeof browser.idle | null;
private idleTimer: number = null;
private idleState = "active";
@@ -15,6 +18,7 @@ export default class IdleBackground {
private vaultTimeoutService: VaultTimeoutService,
private stateService: BrowserStateService,
private notificationsService: NotificationsService,
private accountService: AccountService,
) {
this.idle = chrome.idle || (browser != null ? browser.idle : null);
}
@@ -39,21 +43,27 @@ export default class IdleBackground {
}
if (this.idle.onStateChanged) {
this.idle.onStateChanged.addListener(async (newState: string) => {
this.idle.onStateChanged.addListener(
async (newState: chrome.idle.IdleState | browser.idle.IdleState) => {
if (newState === "locked") {
// Need to check if any of the current users have their timeout set to `onLocked`
const allUsers = await firstValueFrom(this.accountService.accounts$);
for (const userId in allUsers) {
// If the screen is locked or the screensaver activates
const timeout = await this.stateService.getVaultTimeout();
const timeout = await this.stateService.getVaultTimeout({ userId: userId });
if (timeout === -2) {
// On System Lock vault timeout option
const action = await this.stateService.getVaultTimeoutAction();
const action = await this.stateService.getVaultTimeoutAction({ userId: userId });
if (action === VaultTimeoutAction.LogOut) {
await this.vaultTimeoutService.logOut();
await this.vaultTimeoutService.logOut(userId);
} else {
await this.vaultTimeoutService.lock();
await this.vaultTimeoutService.lock(userId);
}
}
}
});
}
},
);
}
}

View File

@@ -566,6 +566,7 @@ export default class MainBackground {
this.logService,
this.settingsService,
this.userVerificationService,
this.configService,
);
this.auditService = new AuditService(this.cryptoFunctionService, this.apiService);
@@ -729,6 +730,7 @@ export default class MainBackground {
this.vaultTimeoutService,
this.stateService,
this.notificationsService,
this.accountService,
);
this.webRequestBackground = new WebRequestBackground(
this.platformUtilsService,

View File

@@ -1,5 +1,4 @@
import { NotificationsService } from "@bitwarden/common/abstractions/notifications.service";
import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum";
import { ConfigServiceAbstraction } from "@bitwarden/common/platform/abstractions/config/config.service.abstraction";
import { I18nService } from "@bitwarden/common/platform/abstractions/i18n.service";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
@@ -97,9 +96,9 @@ export default class RuntimeBackground {
await closeUnlockPopout();
}
await this.notificationsService.updateConnection(msg.command === "loggedIn");
await this.main.refreshBadge();
await this.main.refreshMenu(false);
this.notificationsService.updateConnection(msg.command === "unlocked");
this.systemService.cancelProcessReload();
if (item) {
@@ -133,11 +132,7 @@ export default class RuntimeBackground {
await this.main.openPopup();
break;
case "triggerAutofillScriptInjection":
await this.autofillService.injectAutofillScripts(
sender,
await this.configService.getFeatureFlag<boolean>(FeatureFlag.AutofillV2),
await this.configService.getFeatureFlag<boolean>(FeatureFlag.AutofillOverlay),
);
await this.autofillService.injectAutofillScripts(sender.tab, sender.frameId);
break;
case "bgCollectPageDetails":
await this.main.collectPageDetailsForContentScript(sender.tab, msg.sender, sender.frameId);
@@ -325,6 +320,8 @@ export default class RuntimeBackground {
private async checkOnInstalled() {
setTimeout(async () => {
this.autofillService.loadAutofillScriptsOnInstall();
if (this.onInstalledReason != null) {
if (this.onInstalledReason === "install") {
BrowserApi.createNewTab("https://bitwarden.com/browser-start/");

View File

@@ -24,12 +24,6 @@
"matches": ["http://*/*", "https://*/*", "file:///*"],
"run_at": "document_start"
},
{
"all_frames": false,
"js": ["content/message_handler.js"],
"matches": ["http://*/*", "https://*/*", "file:///*"],
"run_at": "document_start"
},
{
"all_frames": true,
"css": ["content/autofill.css"],

View File

@@ -23,11 +23,12 @@ const runtime = {
removeListener: jest.fn(),
},
sendMessage: jest.fn(),
getManifest: jest.fn(),
getManifest: jest.fn(() => ({ version: 2 })),
getURL: jest.fn((path) => `chrome-extension://id/${path}`),
connect: jest.fn(),
onConnect: {
addListener: jest.fn(),
removeListener: jest.fn(),
},
};

View File

@@ -0,0 +1,394 @@
import { NO_ERRORS_SCHEMA } from "@angular/core";
import { ComponentFixture, TestBed, fakeAsync, tick } from "@angular/core/testing";
import { ActivatedRoute } from "@angular/router";
import { MockProxy, mock } from "jest-mock-extended";
import { of } from "rxjs";
import { LockComponent as BaseLockComponent } from "@bitwarden/angular/auth/components/lock.component";
import { I18nPipe } from "@bitwarden/angular/platform/pipes/i18n.pipe";
import { ApiService } from "@bitwarden/common/abstractions/api.service";
import { VaultTimeoutSettingsService } from "@bitwarden/common/abstractions/vault-timeout/vault-timeout-settings.service";
import { VaultTimeoutService } from "@bitwarden/common/abstractions/vault-timeout/vault-timeout.service";
import { PolicyApiServiceAbstraction } from "@bitwarden/common/admin-console/abstractions/policy/policy-api.service.abstraction";
import { InternalPolicyService } from "@bitwarden/common/admin-console/abstractions/policy/policy.service.abstraction";
import { DeviceTrustCryptoServiceAbstraction } from "@bitwarden/common/auth/abstractions/device-trust-crypto.service.abstraction";
import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction";
import { BroadcasterService } from "@bitwarden/common/platform/abstractions/broadcaster.service";
import { CryptoService } from "@bitwarden/common/platform/abstractions/crypto.service";
import { EnvironmentService } from "@bitwarden/common/platform/abstractions/environment.service";
import { I18nService } from "@bitwarden/common/platform/abstractions/i18n.service";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
import { MessagingService } from "@bitwarden/common/platform/abstractions/messaging.service";
import { PlatformUtilsService } from "@bitwarden/common/platform/abstractions/platform-utils.service";
import { PasswordStrengthServiceAbstraction } from "@bitwarden/common/tools/password-strength";
import { DialogService } from "@bitwarden/components";
import { ElectronStateService } from "../platform/services/electron-state.service.abstraction";
import { LockComponent } from "./lock.component";
// ipc mock global
const isWindowVisibleMock = jest.fn();
(global as any).ipc = {
platform: {
biometric: {
enabled: jest.fn(),
},
isWindowVisible: isWindowVisibleMock,
},
};
describe("LockComponent", () => {
let component: LockComponent;
let fixture: ComponentFixture<LockComponent>;
let stateServiceMock: MockProxy<ElectronStateService>;
let messagingServiceMock: MockProxy<MessagingService>;
let broadcasterServiceMock: MockProxy<BroadcasterService>;
let platformUtilsServiceMock: MockProxy<PlatformUtilsService>;
let activatedRouteMock: MockProxy<ActivatedRoute>;
beforeEach(() => {
stateServiceMock = mock<ElectronStateService>();
stateServiceMock.activeAccount$ = of(null);
messagingServiceMock = mock<MessagingService>();
broadcasterServiceMock = mock<BroadcasterService>();
platformUtilsServiceMock = mock<PlatformUtilsService>();
activatedRouteMock = mock<ActivatedRoute>();
activatedRouteMock.queryParams = mock<ActivatedRoute["queryParams"]>();
TestBed.configureTestingModule({
declarations: [LockComponent, I18nPipe],
providers: [
{
provide: I18nService,
useValue: mock<I18nService>(),
},
{
provide: PlatformUtilsService,
useValue: platformUtilsServiceMock,
},
{
provide: MessagingService,
useValue: messagingServiceMock,
},
{
provide: CryptoService,
useValue: mock<CryptoService>(),
},
{
provide: VaultTimeoutService,
useValue: mock<VaultTimeoutService>(),
},
{
provide: VaultTimeoutSettingsService,
useValue: mock<VaultTimeoutSettingsService>(),
},
{
provide: EnvironmentService,
useValue: mock<EnvironmentService>(),
},
{
provide: ElectronStateService,
useValue: stateServiceMock,
},
{
provide: ApiService,
useValue: mock<ApiService>(),
},
{
provide: ActivatedRoute,
useValue: activatedRouteMock,
},
{
provide: BroadcasterService,
useValue: broadcasterServiceMock,
},
{
provide: PolicyApiServiceAbstraction,
useValue: mock<PolicyApiServiceAbstraction>(),
},
{
provide: InternalPolicyService,
useValue: mock<InternalPolicyService>(),
},
{
provide: PasswordStrengthServiceAbstraction,
useValue: mock<PasswordStrengthServiceAbstraction>(),
},
{
provide: LogService,
useValue: mock<LogService>(),
},
{
provide: DialogService,
useValue: mock<DialogService>(),
},
{
provide: DeviceTrustCryptoServiceAbstraction,
useValue: mock<DeviceTrustCryptoServiceAbstraction>(),
},
{
provide: UserVerificationService,
useValue: mock<UserVerificationService>(),
},
],
schemas: [NO_ERRORS_SCHEMA],
}).compileComponents();
});
beforeEach(() => {
fixture = TestBed.createComponent(LockComponent);
component = fixture.componentInstance;
fixture.detectChanges();
jest.clearAllMocks();
});
describe("ngOnInit", () => {
it("should call super.ngOnInit() once", async () => {
const superNgOnInitSpy = jest.spyOn(BaseLockComponent.prototype, "ngOnInit");
await component.ngOnInit();
expect(superNgOnInitSpy).toHaveBeenCalledTimes(1);
});
it('should set "autoPromptBiometric" to true if "stateService.getDisableAutoBiometricsPrompt()" resolves to false', async () => {
stateServiceMock.getDisableAutoBiometricsPrompt.mockResolvedValue(false);
await component.ngOnInit();
expect(component["autoPromptBiometric"]).toBe(true);
});
it('should set "autoPromptBiometric" to false if "stateService.getDisableAutoBiometricsPrompt()" resolves to true', async () => {
stateServiceMock.getDisableAutoBiometricsPrompt.mockResolvedValue(true);
await component.ngOnInit();
expect(component["autoPromptBiometric"]).toBe(false);
});
it('should set "biometricReady" to true if "stateService.getBiometricReady()" resolves to true', async () => {
component["canUseBiometric"] = jest.fn().mockResolvedValue(true);
await component.ngOnInit();
expect(component["biometricReady"]).toBe(true);
});
it('should set "biometricReady" to false if "stateService.getBiometricReady()" resolves to false', async () => {
component["canUseBiometric"] = jest.fn().mockResolvedValue(false);
await component.ngOnInit();
expect(component["biometricReady"]).toBe(false);
});
it("should call displayBiometricUpdateWarning", async () => {
component["displayBiometricUpdateWarning"] = jest.fn();
await component.ngOnInit();
expect(component["displayBiometricUpdateWarning"]).toHaveBeenCalledTimes(1);
});
it("should call delayedAskForBiometric", async () => {
component["delayedAskForBiometric"] = jest.fn();
await component.ngOnInit();
expect(component["delayedAskForBiometric"]).toHaveBeenCalledTimes(1);
expect(component["delayedAskForBiometric"]).toHaveBeenCalledWith(500);
});
it("should call delayedAskForBiometric when queryParams change", async () => {
activatedRouteMock.queryParams = of({ promptBiometric: true });
component["delayedAskForBiometric"] = jest.fn();
await component.ngOnInit();
expect(component["delayedAskForBiometric"]).toHaveBeenCalledTimes(1);
expect(component["delayedAskForBiometric"]).toHaveBeenCalledWith(500);
});
it("should call messagingService.send", async () => {
await component.ngOnInit();
expect(messagingServiceMock.send).toHaveBeenCalledWith("getWindowIsFocused");
});
describe("broadcasterService.subscribe", () => {
it('should call onWindowHidden() when "broadcasterService.subscribe" is called with "windowHidden"', async () => {
component["onWindowHidden"] = jest.fn();
await component.ngOnInit();
broadcasterServiceMock.subscribe.mock.calls[0][1]({ command: "windowHidden" });
expect(component["onWindowHidden"]).toHaveBeenCalledTimes(1);
});
it('should call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is true and deferFocus is false', async () => {
component["focusInput"] = jest.fn();
component["deferFocus"] = null;
await component.ngOnInit();
broadcasterServiceMock.subscribe.mock.calls[0][1]({
command: "windowIsFocused",
windowIsFocused: true,
} as any);
expect(component["deferFocus"]).toBe(false);
expect(component["focusInput"]).toHaveBeenCalledTimes(1);
});
it('should not call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is true and deferFocus is true', async () => {
component["focusInput"] = jest.fn();
component["deferFocus"] = null;
await component.ngOnInit();
broadcasterServiceMock.subscribe.mock.calls[0][1]({
command: "windowIsFocused",
windowIsFocused: false,
} as any);
expect(component["deferFocus"]).toBe(true);
expect(component["focusInput"]).toHaveBeenCalledTimes(0);
});
it('should call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is true and deferFocus is true', async () => {
component["focusInput"] = jest.fn();
component["deferFocus"] = true;
await component.ngOnInit();
broadcasterServiceMock.subscribe.mock.calls[0][1]({
command: "windowIsFocused",
windowIsFocused: true,
} as any);
expect(component["deferFocus"]).toBe(false);
expect(component["focusInput"]).toHaveBeenCalledTimes(1);
});
it('should not call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is false and deferFocus is true', async () => {
component["focusInput"] = jest.fn();
component["deferFocus"] = true;
await component.ngOnInit();
broadcasterServiceMock.subscribe.mock.calls[0][1]({
command: "windowIsFocused",
windowIsFocused: false,
} as any);
expect(component["deferFocus"]).toBe(true);
expect(component["focusInput"]).toHaveBeenCalledTimes(0);
});
});
});
describe("ngOnDestroy", () => {
it("should call super.ngOnDestroy()", () => {
const superNgOnDestroySpy = jest.spyOn(BaseLockComponent.prototype, "ngOnDestroy");
component.ngOnDestroy();
expect(superNgOnDestroySpy).toHaveBeenCalledTimes(1);
});
it("should call broadcasterService.unsubscribe()", () => {
component.ngOnDestroy();
expect(broadcasterServiceMock.unsubscribe).toHaveBeenCalledTimes(1);
});
});
describe("focusInput", () => {
it('should call "focus" on #pin input if pinEnabled is true', () => {
component["pinEnabled"] = true;
global.document.getElementById = jest.fn().mockReturnValue({ focus: jest.fn() });
component["focusInput"]();
expect(global.document.getElementById).toHaveBeenCalledWith("pin");
});
it('should call "focus" on #masterPassword input if pinEnabled is false', () => {
component["pinEnabled"] = false;
global.document.getElementById = jest.fn().mockReturnValue({ focus: jest.fn() });
component["focusInput"]();
expect(global.document.getElementById).toHaveBeenCalledWith("masterPassword");
});
});
describe("delayedAskForBiometric", () => {
beforeEach(() => {
component["supportsBiometric"] = true;
component["autoPromptBiometric"] = true;
});
it('should wait for "delay" milliseconds', fakeAsync(async () => {
const delaySpy = jest.spyOn(global, "setTimeout");
component["delayedAskForBiometric"](5000);
tick(4000);
component["biometricAsked"] = false;
tick(1000);
component["biometricAsked"] = true;
expect(delaySpy).toHaveBeenCalledWith(expect.any(Function), 5000);
}));
it('should return; if "params" is defined and "params.promptBiometric" is false', fakeAsync(async () => {
component["delayedAskForBiometric"](5000, { promptBiometric: false });
tick(5000);
expect(component["biometricAsked"]).toBe(false);
}));
it('should not return; if "params" is defined and "params.promptBiometric" is true', fakeAsync(async () => {
component["delayedAskForBiometric"](5000, { promptBiometric: true });
tick(5000);
expect(component["biometricAsked"]).toBe(true);
}));
it('should not return; if "params" is undefined', fakeAsync(async () => {
component["delayedAskForBiometric"](5000);
tick(5000);
expect(component["biometricAsked"]).toBe(true);
}));
it('should return; if "supportsBiometric" is false', fakeAsync(async () => {
component["supportsBiometric"] = false;
component["delayedAskForBiometric"](5000);
tick(5000);
expect(component["biometricAsked"]).toBe(false);
}));
it('should return; if "autoPromptBiometric" is false', fakeAsync(async () => {
component["autoPromptBiometric"] = false;
component["delayedAskForBiometric"](5000);
tick(5000);
expect(component["biometricAsked"]).toBe(false);
}));
it("should call unlockBiometric() if biometricAsked is false and window is visible", fakeAsync(async () => {
isWindowVisibleMock.mockResolvedValue(true);
component["unlockBiometric"] = jest.fn();
component["biometricAsked"] = false;
component["delayedAskForBiometric"](5000);
tick(5000);
expect(component["unlockBiometric"]).toHaveBeenCalledTimes(1);
}));
it("should not call unlockBiometric() if biometricAsked is false and window is not visible", fakeAsync(async () => {
isWindowVisibleMock.mockResolvedValue(false);
component["unlockBiometric"] = jest.fn();
component["biometricAsked"] = false;
component["delayedAskForBiometric"](5000);
tick(5000);
expect(component["unlockBiometric"]).toHaveBeenCalledTimes(0);
}));
it("should not call unlockBiometric() if biometricAsked is true", fakeAsync(async () => {
isWindowVisibleMock.mockResolvedValue(true);
component["unlockBiometric"] = jest.fn();
component["biometricAsked"] = true;
component["delayedAskForBiometric"](5000);
tick(5000);
expect(component["unlockBiometric"]).toHaveBeenCalledTimes(0);
}));
});
describe("canUseBiometric", () => {
it("should call getUserId() on stateService", async () => {
stateServiceMock.getUserId.mockResolvedValue("userId");
await component["canUseBiometric"]();
expect(ipc.platform.biometric.enabled).toHaveBeenCalledWith("userId");
});
});
it('onWindowHidden() should set "showPassword" to false', () => {
component["showPassword"] = true;
component["onWindowHidden"]();
expect(component["showPassword"]).toBe(false);
});
});

View File

@@ -1,5 +1,6 @@
import { Component, NgZone } from "@angular/core";
import { ActivatedRoute, Router } from "@angular/router";
import { switchMap } from "rxjs";
import { LockComponent as BaseLockComponent } from "@bitwarden/angular/auth/components/lock.component";
import { ApiService } from "@bitwarden/common/abstractions/api.service";
@@ -31,6 +32,8 @@ const BroadcasterSubscriptionId = "LockComponent";
export class LockComponent extends BaseLockComponent {
private deferFocus: boolean = null;
protected biometricReady = false;
private biometricAsked = false;
private autoPromptBiometric = false;
constructor(
router: Router,
@@ -78,23 +81,14 @@ export class LockComponent extends BaseLockComponent {
async ngOnInit() {
await super.ngOnInit();
const autoPromptBiometric = !(await this.stateService.getDisableAutoBiometricsPrompt());
this.autoPromptBiometric = !(await this.stateService.getDisableAutoBiometricsPrompt());
this.biometricReady = await this.canUseBiometric();
await this.displayBiometricUpdateWarning();
// eslint-disable-next-line rxjs-angular/prefer-takeuntil
this.route.queryParams.subscribe((params) => {
setTimeout(async () => {
if (!params.promptBiometric || !this.supportsBiometric || !autoPromptBiometric) {
return;
}
this.delayedAskForBiometric(500);
this.route.queryParams.pipe(switchMap((params) => this.delayedAskForBiometric(500, params)));
if (await ipc.platform.isWindowVisible()) {
this.unlockBiometric();
}
}, 1000);
});
this.broadcasterService.subscribe(BroadcasterSubscriptionId, async (message: any) => {
this.ngZone.run(() => {
switch (message.command) {
@@ -128,6 +122,23 @@ export class LockComponent extends BaseLockComponent {
this.showPassword = false;
}
private async delayedAskForBiometric(delay: number, params?: any) {
await new Promise((resolve) => setTimeout(resolve, delay));
if (params && !params.promptBiometric) {
return;
}
if (!this.supportsBiometric || !this.autoPromptBiometric || this.biometricAsked) {
return;
}
this.biometricAsked = true;
if (await ipc.platform.isWindowVisible()) {
this.unlockBiometric();
}
}
private async canUseBiometric() {
const userId = await this.stateService.getUserId();
return await ipc.platform.biometric.enabled(userId);

View File

@@ -0,0 +1,25 @@
import { WebAuthnLoginAssertionResponseRequest } from "@bitwarden/common/auth/services/webauthn-login/request/webauthn-login-assertion-response.request";
/**
* Request sent to the server to save a newly created prf key set for a credential.
*/
export class EnableCredentialEncryptionRequest {
/**
* The response received from the authenticator.
*/
deviceResponse: WebAuthnLoginAssertionResponseRequest;
/**
* An encrypted token containing information the server needs to verify the credential.
*/
token: string;
/** Used for vault encryption. See {@link RotateableKeySet.encryptedUserKey } */
encryptedUserKey?: string;
/** Used for vault encryption. See {@link RotateableKeySet.encryptedPublicKey } */
encryptedPublicKey?: string;
/** Used for vault encryption. See {@link RotateableKeySet.encryptedPrivateKey } */
encryptedPrivateKey?: string;
}

View File

@@ -3,7 +3,7 @@ import { Utils } from "@bitwarden/common/platform/misc/utils";
import { WebauthnLoginAuthenticatorResponseRequest } from "./webauthn-login-authenticator-response.request";
/**
* The response received from an authentiator after a successful attestation.
* The response received from an authenticator after a successful attestation.
* This request is used to save newly created webauthn login credentials to the server.
*/
export class WebauthnLoginAttestationResponseRequest extends WebauthnLoginAuthenticatorResponseRequest {

View File

@@ -2,8 +2,10 @@ import { Injectable } from "@angular/core";
import { ApiService } from "@bitwarden/common/abstractions/api.service";
import { SecretVerificationRequest } from "@bitwarden/common/auth/models/request/secret-verification.request";
import { CredentialAssertionOptionsResponse } from "@bitwarden/common/auth/services/webauthn-login/response/credential-assertion-options.response";
import { ListResponse } from "@bitwarden/common/models/response/list.response";
import { EnableCredentialEncryptionRequest } from "./request/enable-credential-encryption.request";
import { SaveCredentialRequest } from "./request/save-credential.request";
import { WebauthnLoginCredentialCreateOptionsResponse } from "./response/webauthn-login-credential-create-options.response";
import { WebauthnLoginCredentialResponse } from "./response/webauthn-login-credential.response";
@@ -15,10 +17,29 @@ export class WebAuthnLoginAdminApiService {
async getCredentialCreateOptions(
request: SecretVerificationRequest,
): Promise<WebauthnLoginCredentialCreateOptionsResponse> {
const response = await this.apiService.send("POST", "/webauthn/options", request, true, true);
const response = await this.apiService.send(
"POST",
"/webauthn/attestation-options",
request,
true,
true,
);
return new WebauthnLoginCredentialCreateOptionsResponse(response);
}
async getCredentialAssertionOptions(
request: SecretVerificationRequest,
): Promise<CredentialAssertionOptionsResponse> {
const response = await this.apiService.send(
"POST",
"/webauthn/assertion-options",
request,
true,
true,
);
return new CredentialAssertionOptionsResponse(response);
}
async saveCredential(request: SaveCredentialRequest): Promise<boolean> {
await this.apiService.send("POST", "/webauthn", request, true, true);
return true;
@@ -31,4 +52,8 @@ export class WebAuthnLoginAdminApiService {
async deleteCredential(credentialId: string, request: SecretVerificationRequest): Promise<void> {
await this.apiService.send("POST", `/webauthn/${credentialId}/delete`, request, true, true);
}
async updateCredential(request: EnableCredentialEncryptionRequest): Promise<void> {
await this.apiService.send("PUT", `/webauthn`, request, true, true);
}
}

View File

@@ -1,12 +1,21 @@
import { randomBytes } from "crypto";
import { mock, MockProxy } from "jest-mock-extended";
import { RotateableKeySet } from "@bitwarden/auth";
import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction";
import { WebAuthnLoginPrfCryptoServiceAbstraction } from "@bitwarden/common/auth/abstractions/webauthn/webauthn-login-prf-crypto.service.abstraction";
import { WebAuthnLoginCredentialAssertionView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion.view";
import { WebAuthnLoginAssertionResponseRequest } from "@bitwarden/common/auth/services/webauthn-login/request/webauthn-login-assertion-response.request";
import { Utils } from "@bitwarden/common/platform/misc/utils";
import { EncString } from "@bitwarden/common/platform/models/domain/enc-string";
import { PrfKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key";
import { CredentialCreateOptionsView } from "../../views/credential-create-options.view";
import { PendingWebauthnLoginCredentialView } from "../../views/pending-webauthn-login-credential.view";
import { RotateableKeySetService } from "../rotateable-key-set.service";
import { EnableCredentialEncryptionRequest } from "./request/enable-credential-encryption.request";
import { WebAuthnLoginAdminApiService } from "./webauthn-login-admin-api.service";
import { WebauthnLoginAdminService } from "./webauthn-login-admin.service";
@@ -18,10 +27,13 @@ describe("WebauthnAdminService", () => {
let credentials: MockProxy<CredentialsContainer>;
let service!: WebauthnLoginAdminService;
let originalAuthenticatorAssertionResponse!: AuthenticatorAssertionResponse | any;
beforeAll(() => {
// Polyfill missing class
window.PublicKeyCredential = class {} as any;
window.AuthenticatorAttestationResponse = class {} as any;
window.AuthenticatorAssertionResponse = class {} as any;
apiService = mock<WebAuthnLoginAdminApiService>();
userVerificationService = mock<UserVerificationService>();
rotateableKeySetService = mock<RotateableKeySetService>();
@@ -34,6 +46,20 @@ describe("WebauthnAdminService", () => {
webAuthnLoginPrfCryptoService,
credentials,
);
// Save original global class
originalAuthenticatorAssertionResponse = global.AuthenticatorAssertionResponse;
// Mock the global AuthenticatorAssertionResponse class b/c the class is only available in secure contexts
global.AuthenticatorAssertionResponse = MockAuthenticatorAssertionResponse;
});
beforeEach(() => {
jest.clearAllMocks();
});
afterAll(() => {
// Restore global after all tests are done
global.AuthenticatorAssertionResponse = originalAuthenticatorAssertionResponse;
});
describe("createCredential", () => {
@@ -70,6 +96,94 @@ describe("WebauthnAdminService", () => {
expect(result.supportsPrf).toBe(true);
});
});
describe("enableCredentialEncryption", () => {
it("should call the necessary methods to update the credential", async () => {
// Arrange
const response = new MockPublicKeyCredential();
const prfKeySet = new RotateableKeySet<PrfKey>(
new EncString("test_encryptedUserKey"),
new EncString("test_encryptedPublicKey"),
new EncString("test_encryptedPrivateKey"),
);
const assertionOptions: WebAuthnLoginCredentialAssertionView =
new WebAuthnLoginCredentialAssertionView(
"enable_credential_encryption_test_token",
new WebAuthnLoginAssertionResponseRequest(response),
{} as PrfKey,
);
const request = new EnableCredentialEncryptionRequest();
request.token = assertionOptions.token;
request.deviceResponse = assertionOptions.deviceResponse;
request.encryptedUserKey = prfKeySet.encryptedUserKey.encryptedString;
request.encryptedPublicKey = prfKeySet.encryptedPublicKey.encryptedString;
request.encryptedPrivateKey = prfKeySet.encryptedPrivateKey.encryptedString;
// Mock the necessary methods and services
const createKeySetMock = jest
.spyOn(rotateableKeySetService, "createKeySet")
.mockResolvedValue(prfKeySet);
const updateCredentialMock = jest.spyOn(apiService, "updateCredential").mockResolvedValue();
// Act
await service.enableCredentialEncryption(assertionOptions);
// Assert
expect(createKeySetMock).toHaveBeenCalledWith(assertionOptions.prfKey);
expect(updateCredentialMock).toHaveBeenCalledWith(request);
});
it("should throw error when PRF Key is undefined", async () => {
// Arrange
const response = new MockPublicKeyCredential();
const assertionOptions: WebAuthnLoginCredentialAssertionView =
new WebAuthnLoginCredentialAssertionView(
"enable_credential_encryption_test_token",
new WebAuthnLoginAssertionResponseRequest(response),
undefined,
);
// Mock the necessary methods and services
const createKeySetMock = jest
.spyOn(rotateableKeySetService, "createKeySet")
.mockResolvedValue(null);
const updateCredentialMock = jest.spyOn(apiService, "updateCredential").mockResolvedValue();
// Act
try {
await service.enableCredentialEncryption(assertionOptions);
} catch (error) {
// Assert
expect(error).toEqual(new Error("invalid credential"));
expect(createKeySetMock).not.toHaveBeenCalled();
expect(updateCredentialMock).not.toHaveBeenCalled();
}
});
it("should throw error when WehAuthnLoginCredentialAssertionView is undefined", async () => {
// Arrange
const assertionOptions: WebAuthnLoginCredentialAssertionView = undefined;
// Mock the necessary methods and services
const createKeySetMock = jest
.spyOn(rotateableKeySetService, "createKeySet")
.mockResolvedValue(null);
const updateCredentialMock = jest.spyOn(apiService, "updateCredential").mockResolvedValue();
// Act
try {
await service.enableCredentialEncryption(assertionOptions);
} catch (error) {
// Assert
expect(error).toEqual(new Error("invalid credential"));
expect(createKeySetMock).not.toHaveBeenCalled();
expect(updateCredentialMock).not.toHaveBeenCalled();
}
});
});
});
function createCredentialCreateOptions(): CredentialCreateOptionsView {
@@ -115,3 +229,58 @@ function createDeviceResponse({ prf = false }: { prf?: boolean } = {}): PublicKe
return credential;
}
/**
* Mocks for the PublicKeyCredential and AuthenticatorAssertionResponse classes copied from webauthn-login.service.spec.ts
*/
// AuthenticatorAssertionResponse && PublicKeyCredential are only available in secure contexts
// so we need to mock them and assign them to the global object to make them available
// for the tests
class MockAuthenticatorAssertionResponse implements AuthenticatorAssertionResponse {
clientDataJSON: ArrayBuffer = randomBytes(32).buffer;
authenticatorData: ArrayBuffer = randomBytes(196).buffer;
signature: ArrayBuffer = randomBytes(72).buffer;
userHandle: ArrayBuffer = randomBytes(16).buffer;
clientDataJSONB64Str = Utils.fromBufferToUrlB64(this.clientDataJSON);
authenticatorDataB64Str = Utils.fromBufferToUrlB64(this.authenticatorData);
signatureB64Str = Utils.fromBufferToUrlB64(this.signature);
userHandleB64Str = Utils.fromBufferToUrlB64(this.userHandle);
}
class MockPublicKeyCredential implements PublicKeyCredential {
authenticatorAttachment = "cross-platform";
id = "mockCredentialId";
type = "public-key";
rawId: ArrayBuffer = randomBytes(32).buffer;
rawIdB64Str = Utils.fromBufferToUrlB64(this.rawId);
response: MockAuthenticatorAssertionResponse = new MockAuthenticatorAssertionResponse();
// Use random 64 character hex string (32 bytes - matters for symmetric key creation)
// to represent the prf key binary data and convert to ArrayBuffer
// Creating the array buffer from a known hex value allows us to
// assert on the value in tests
private prfKeyArrayBuffer: ArrayBuffer = Utils.hexStringToArrayBuffer(
"1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
);
getClientExtensionResults(): any {
return {
prf: {
results: {
first: this.prfKeyArrayBuffer,
},
},
};
}
static isConditionalMediationAvailable(): Promise<boolean> {
return Promise.resolve(false);
}
static isUserVerifyingPlatformAuthenticatorAvailable(): Promise<boolean> {
return Promise.resolve(false);
}
}

View File

@@ -4,6 +4,8 @@ import { BehaviorSubject, filter, from, map, Observable, shareReplay, switchMap,
import { PrfKeySet } from "@bitwarden/auth";
import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction";
import { WebAuthnLoginPrfCryptoServiceAbstraction } from "@bitwarden/common/auth/abstractions/webauthn/webauthn-login-prf-crypto.service.abstraction";
import { WebAuthnLoginCredentialAssertionOptionsView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion-options.view";
import { WebAuthnLoginCredentialAssertionView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion.view";
import { Verification } from "@bitwarden/common/auth/types/verification";
import { LogService } from "@bitwarden/common/platform/abstractions/log.service";
@@ -12,6 +14,7 @@ import { PendingWebauthnLoginCredentialView } from "../../views/pending-webauthn
import { WebauthnLoginCredentialView } from "../../views/webauthn-login-credential.view";
import { RotateableKeySetService } from "../rotateable-key-set.service";
import { EnableCredentialEncryptionRequest } from "./request/enable-credential-encryption.request";
import { SaveCredentialRequest } from "./request/save-credential.request";
import { WebauthnLoginAttestationResponseRequest } from "./request/webauthn-login-attestation-response.request";
import { WebAuthnLoginAdminApiService } from "./webauthn-login-admin-api.service";
@@ -52,14 +55,31 @@ export class WebauthnLoginAdminService {
}
/**
* Get the credential attestation options needed for initiating the WebAuthnLogin credentail creation process.
* Get the credential assertion options needed for initiating the WebAuthnLogin credential update process.
* The options contains assertion options and other data for the authenticator.
* This method requires user verification.
*
* @param verification User verification data to be used for the request.
* @returns The credential assertion options and a token to be used for the credential update request.
*/
async getCredentialAssertOptions(
verification: Verification,
): Promise<WebAuthnLoginCredentialAssertionOptionsView> {
const request = await this.userVerificationService.buildRequest(verification);
const response = await this.apiService.getCredentialAssertionOptions(request);
return new WebAuthnLoginCredentialAssertionOptionsView(response.options, response.token);
}
/**
* Get the credential attestation options needed for initiating the WebAuthnLogin credential creation process.
* The options contains a challenge and other data for the authenticator.
* This method requires user verification.
*
* @param verification User verification data to be used for the request.
* @returns The credential attestation options and a token to be used for the credential creation request.
*/
async getCredentialCreateOptions(
async getCredentialAttestationOptions(
verification: Verification,
): Promise<CredentialCreateOptionsView> {
const request = await this.userVerificationService.buildRequest(verification);
@@ -169,6 +189,36 @@ export class WebauthnLoginAdminService {
this.refresh();
}
/**
* Enable encryption for a credential that has already been saved to the server.
* This will update the KeySet associated with the credential in the database.
* We short circuit the process here incase the WebAuthnLoginCredential doesn't support PRF or
* if there was a problem with the Credential Assertion.
*
* @param assertionOptions Options received from the server using `getCredentialAssertOptions`.
* @returns void
*/
async enableCredentialEncryption(
assertionOptions: WebAuthnLoginCredentialAssertionView,
): Promise<void> {
if (assertionOptions === undefined || assertionOptions?.prfKey === undefined) {
throw new Error("invalid credential");
}
const prfKeySet: PrfKeySet = await this.rotateableKeySetService.createKeySet(
assertionOptions.prfKey,
);
const request = new EnableCredentialEncryptionRequest();
request.token = assertionOptions.token;
request.deviceResponse = assertionOptions.deviceResponse;
request.encryptedUserKey = prfKeySet.encryptedUserKey.encryptedString;
request.encryptedPublicKey = prfKeySet.encryptedPublicKey.encryptedString;
request.encryptedPrivateKey = prfKeySet.encryptedPrivateKey.encryptedString;
await this.apiService.updateCredential(request);
this.refresh();
}
/**
* List of webauthn credentials saved on the server.
*

View File

@@ -94,7 +94,7 @@ export class CreateCredentialDialogComponent implements OnInit {
}
try {
this.credentialOptions = await this.webauthnService.getCredentialCreateOptions(
this.credentialOptions = await this.webauthnService.getCredentialAttestationOptions(
this.formGroup.value.userVerification.secret,
);
} catch (error) {

View File

@@ -0,0 +1,34 @@
<form [formGroup]="formGroup" [bitSubmit]="submit">
<bit-dialog dialogSize="large" [loading]="loading$ | async">
<span bitDialogTitle
>{{ "enablePasskeyEncryption" | i18n }}
<span *ngIf="credential" class="tw-text-sm tw-normal-case tw-text-muted">{{
credential.name
}}</span>
</span>
<ng-container bitDialogContent>
<ng-container *ngIf="!credential">
<i class="bwi bwi-spinner bwi-spin tw-ml-1" aria-hidden="true"></i>
</ng-container>
<ng-container *ngIf="credential">
<p bitTypography="body1">{{ "useForVaultEncryptionInfo" | i18n }}</p>
<ng-container formGroupName="userVerification">
<app-user-verification
formControlName="secret"
[(invalidSecret)]="invalidSecret"
></app-user-verification>
</ng-container>
</ng-container>
</ng-container>
<ng-container bitDialogFooter>
<button type="submit" bitButton bitFormButton buttonType="primary">
{{ "submit" | i18n }}
</button>
<button type="button" bitButton bitFormButton buttonType="secondary" bitDialogClose>
{{ "cancel" | i18n }}
</button>
</ng-container>
</bit-dialog>
</form>

View File

@@ -0,0 +1,91 @@
import { DIALOG_DATA, DialogConfig, DialogRef } from "@angular/cdk/dialog";
import { Component, Inject, OnDestroy, OnInit } from "@angular/core";
import { FormBuilder, Validators } from "@angular/forms";
import { Subject } from "rxjs";
import { takeUntil } from "rxjs/operators";
import { WebAuthnLoginServiceAbstraction } from "@bitwarden/common/auth/abstractions/webauthn/webauthn-login.service.abstraction";
import { WebAuthnLoginCredentialAssertionOptionsView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion-options.view";
import { Verification } from "@bitwarden/common/auth/types/verification";
import { ErrorResponse } from "@bitwarden/common/models/response/error.response";
import { DialogService } from "@bitwarden/components/src/dialog/dialog.service";
import { WebauthnLoginAdminService } from "../../../core/services/webauthn-login/webauthn-login-admin.service";
import { WebauthnLoginCredentialView } from "../../../core/views/webauthn-login-credential.view";
export interface EnableEncryptionDialogParams {
credentialId: string;
}
@Component({
templateUrl: "enable-encryption-dialog.component.html",
})
export class EnableEncryptionDialogComponent implements OnInit, OnDestroy {
private destroy$ = new Subject<void>();
protected invalidSecret = false;
protected formGroup = this.formBuilder.group({
userVerification: this.formBuilder.group({
secret: [null as Verification | null, Validators.required],
}),
});
protected credential?: WebauthnLoginCredentialView;
protected credentialOptions?: WebAuthnLoginCredentialAssertionOptionsView;
protected loading$ = this.webauthnService.loading$;
constructor(
@Inject(DIALOG_DATA) private params: EnableEncryptionDialogParams,
private formBuilder: FormBuilder,
private dialogRef: DialogRef,
private webauthnService: WebauthnLoginAdminService,
private webauthnLoginService: WebAuthnLoginServiceAbstraction,
) {}
ngOnInit(): void {
this.webauthnService
.getCredential$(this.params.credentialId)
.pipe(takeUntil(this.destroy$))
.subscribe((credential: any) => (this.credential = credential));
}
submit = async () => {
if (this.credential === undefined) {
return;
}
this.dialogRef.disableClose = true;
try {
this.credentialOptions = await this.webauthnService.getCredentialAssertOptions(
this.formGroup.value.userVerification.secret,
);
await this.webauthnService.enableCredentialEncryption(
await this.webauthnLoginService.assertCredential(this.credentialOptions),
);
} catch (error) {
if (error instanceof ErrorResponse && error.statusCode === 400) {
this.invalidSecret = true;
}
throw error;
}
this.dialogRef.close();
};
ngOnDestroy(): void {
this.destroy$.next();
this.destroy$.complete();
}
}
/**
* Strongly typed helper to open a EnableEncryptionDialogComponent
* @param dialogService Instance of the dialog service that will be used to open the dialog
* @param config Configuration for the dialog
*/
export const openEnableCredentialDialogComponent = (
dialogService: DialogService,
config: DialogConfig<EnableEncryptionDialogParams>,
) => {
return dialogService.open<unknown>(EnableEncryptionDialogComponent, config);
};

View File

@@ -39,8 +39,16 @@
<span bitTypography="body1" class="tw-text-muted">{{ "usedForEncryption" | i18n }}</span>
</ng-container>
<ng-container *ngIf="credential.prfStatus === WebauthnLoginCredentialPrfStatus.Supported">
<button
type="button"
bitLink
[disabled]="loading"
[attr.aria-label]="('enablePasskeyEncryption' | i18n) + ' ' + credential.name"
(click)="enableEncryption(credential.id)"
>
<i class="bwi bwi-lock-encrypted"></i>
<span bitTypography="body1" class="tw-text-muted">{{ "encryptionNotEnabled" | i18n }}</span>
{{ "enablePasskeyEncryption" | i18n }}
</button>
</ng-container>
<span
*ngIf="credential.prfStatus === WebauthnLoginCredentialPrfStatus.Unsupported"

View File

@@ -11,6 +11,7 @@ import { WebauthnLoginCredentialView } from "../../core/views/webauthn-login-cre
import { openCreateCredentialDialog } from "./create-credential-dialog/create-credential-dialog.component";
import { openDeleteCredentialDialogComponent } from "./delete-credential-dialog/delete-credential-dialog.component";
import { openEnableCredentialDialogComponent } from "./enable-encryption-dialog/enable-encryption-dialog.component";
@Component({
selector: "app-webauthn-login-settings",
@@ -83,4 +84,8 @@ export class WebauthnLoginSettingsComponent implements OnInit, OnDestroy {
protected deleteCredential(credentialId: string) {
openDeleteCredentialDialogComponent(this.dialogService, { data: { credentialId } });
}
protected enableEncryption(credentialId: string) {
openEnableCredentialDialogComponent(this.dialogService, { data: { credentialId } });
}
}

View File

@@ -8,6 +8,7 @@ import { UserVerificationModule } from "../../shared/components/user-verificatio
import { CreateCredentialDialogComponent } from "./create-credential-dialog/create-credential-dialog.component";
import { DeleteCredentialDialogComponent } from "./delete-credential-dialog/delete-credential-dialog.component";
import { EnableEncryptionDialogComponent } from "./enable-encryption-dialog/enable-encryption-dialog.component";
import { WebauthnLoginSettingsComponent } from "./webauthn-login-settings.component";
@NgModule({
@@ -16,6 +17,7 @@ import { WebauthnLoginSettingsComponent } from "./webauthn-login-settings.compon
WebauthnLoginSettingsComponent,
CreateCredentialDialogComponent,
DeleteCredentialDialogComponent,
EnableEncryptionDialogComponent,
],
exports: [WebauthnLoginSettingsComponent],
})

View File

@@ -674,8 +674,8 @@
"encryptionNotSupported": {
"message": "Encryption not supported"
},
"encryptionNotEnabled": {
"message": "Encryption not enabled"
"enablePasskeyEncryption": {
"message": "Set up encryption"
},
"usedForEncryption": {
"message": "Used for encryption"

View File

@@ -59,7 +59,7 @@ export class FakeStorageService implements AbstractStorageService {
return Promise.resolve(this.store[key] != null);
}
save<T>(key: string, obj: T, options?: StorageOptions): Promise<void> {
this.mock.save(key, options);
this.mock.save(key, obj, options);
this.store[key] = obj;
this.updatesSubject.next({ key: key, updateType: "save" });
return Promise.resolve();

View File

@@ -69,6 +69,10 @@ export function trackEmissions<T>(observable: Observable<T>): T[] {
case "boolean":
emissions.push(value);
break;
case "symbol":
// Cheating types to make symbols work at all
emissions.push(value.toString() as T);
break;
default: {
emissions.push(clone(value));
}
@@ -85,7 +89,7 @@ function clone(value: any): any {
}
}
export async function awaitAsync(ms = 0) {
export async function awaitAsync(ms = 1) {
if (ms < 1) {
await Promise.resolve();
} else {

View File

@@ -2,7 +2,7 @@
* need to update test environment so trackEmissions works appropriately
* @jest-environment ../shared/test.environment.ts
*/
import { any, mock } from "jest-mock-extended";
import { any, anySymbol, mock } from "jest-mock-extended";
import { BehaviorSubject, firstValueFrom, of, timeout } from "rxjs";
import { Jsonify } from "type-fest";
@@ -11,7 +11,7 @@ import { FakeStorageService } from "../../../../spec/fake-storage.service";
import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service";
import { AuthenticationStatus } from "../../../auth/enums/authentication-status";
import { UserId } from "../../../types/guid";
import { KeyDefinition } from "../key-definition";
import { KeyDefinition, userKeyBuilder } from "../key-definition";
import { StateDefinition } from "../state-definition";
import { DefaultActiveUserState } from "./default-active-user-state";
@@ -32,9 +32,10 @@ class TestState {
}
const testStateDefinition = new StateDefinition("fake", "disk");
const cleanupDelayMs = 10;
const testKeyDefinition = new KeyDefinition<TestState>(testStateDefinition, "fake", {
deserializer: TestState.fromJSON,
cleanupDelayMs,
});
describe("DefaultActiveUserState", () => {
@@ -56,10 +57,14 @@ describe("DefaultActiveUserState", () => {
);
});
const makeUserId = (id: string) => {
return id != null ? (`00000000-0000-1000-a000-00000000000${id}` as UserId) : undefined;
};
const changeActiveUser = async (id: string) => {
const userId = id != null ? `00000000-0000-1000-a000-00000000000${id}` : undefined;
const userId = makeUserId(id);
activeAccountSubject.next({
id: userId as UserId,
id: userId,
email: `test${id}@example.com`,
name: `Test User ${id}`,
status: AuthenticationStatus.Unlocked,
@@ -90,7 +95,7 @@ describe("DefaultActiveUserState", () => {
const emissions = trackEmissions(userState.state$);
// User signs in
changeActiveUser("1");
await changeActiveUser("1");
await awaitAsync();
// Service does an update
@@ -111,17 +116,17 @@ describe("DefaultActiveUserState", () => {
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
1,
"user_00000000-0000-1000-a000-000000000001_fake_fake",
any(),
any(), // options
);
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
2,
"user_00000000-0000-1000-a000-000000000001_fake_fake",
any(),
any(), // options
);
expect(diskStorageService.mock.get).toHaveBeenNthCalledWith(
3,
"user_00000000-0000-1000-a000-000000000002_fake_fake",
any(),
any(), // options
);
// Should only have saved data for the first user
@@ -129,7 +134,8 @@ describe("DefaultActiveUserState", () => {
expect(diskStorageService.mock.save).toHaveBeenNthCalledWith(
1,
"user_00000000-0000-1000-a000-000000000001_fake_fake",
any(),
updatedState,
any(), // options
);
});
@@ -183,15 +189,17 @@ describe("DefaultActiveUserState", () => {
});
it("should not emit a previous users value if that user is no longer active", async () => {
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": {
const user1Data: Jsonify<TestState> = {
date: "2020-09-21T13:14:17.648Z",
array: ["value"],
} as Jsonify<TestState>,
"user_00000000-0000-1000-a000-000000000002_fake_fake": {
};
const user2Data: Jsonify<TestState> = {
date: "2020-09-21T13:14:17.648Z",
array: [],
} as Jsonify<TestState>,
};
diskStorageService.internalUpdateStore({
"user_00000000-0000-1000-a000-000000000001_fake_fake": user1Data,
"user_00000000-0000-1000-a000-000000000002_fake_fake": user2Data,
});
// This starts one subscription on the observable for tracking emissions throughout
@@ -203,7 +211,7 @@ describe("DefaultActiveUserState", () => {
// This should always return a value right await
const value = await firstValueFrom(userState.state$);
expect(value).toBeTruthy();
expect(value).toEqual(user1Data);
// Make it such that there is no active user
await changeActiveUser(undefined);
@@ -222,20 +230,34 @@ describe("DefaultActiveUserState", () => {
rejectedError = err;
});
expect(resolvedValue).toBeFalsy();
expect(rejectedError).toBeTruthy();
expect(resolvedValue).toBeUndefined();
expect(rejectedError).not.toBeUndefined();
expect(rejectedError.message).toBe("Timeout has occurred");
// We need to figure out if something should be emitted
// when there becomes no active user, if we don't want that to emit
// this value is correct.
expect(emissions).toHaveLength(2);
expect(emissions).toEqual([user1Data]);
});
it("should not emit twice if there are two listeners", async () => {
await changeActiveUser("1");
const emissions = trackEmissions(userState.state$);
const emissions2 = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
]);
expect(emissions2).toEqual([
null, // Initial value
]);
});
describe("update", () => {
const newData = { date: new Date(), array: ["test"] };
beforeEach(async () => {
changeActiveUser("1");
await changeActiveUser("1");
});
it("should save on update", async () => {
@@ -315,6 +337,8 @@ describe("DefaultActiveUserState", () => {
return initialData;
});
await awaitAsync();
await userState.update((state, dependencies) => {
expect(state).toEqual(initialData);
return newData;
@@ -329,4 +353,303 @@ describe("DefaultActiveUserState", () => {
]);
});
});
describe("update races", () => {
const newData = { date: new Date(), array: ["test"] };
const userId = makeUserId("1");
beforeEach(async () => {
await changeActiveUser("1");
await awaitAsync();
});
test("subscriptions during an update should receive the current and latest", async () => {
const oldData = { date: new Date(2019, 1, 1), array: ["oldValue1"] };
await userState.update(() => {
return oldData;
});
const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
emissions2 = trackEmissions(userState.state$);
await originalSave(key, obj);
});
const val = await userState.update(() => {
return newData;
});
await awaitAsync(10);
expect(val).toEqual(newData);
expect(emissions).toEqual([initialData, newData]);
expect(emissions2).toEqual([initialData, newData]);
});
test("subscription during an aborted update should receive the last value", async () => {
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const val = await userState.update(
(state) => {
return newData;
},
{
shouldUpdate: () => {
emissions2 = trackEmissions(userState.state$);
return false;
},
},
);
await awaitAsync();
expect(val).toEqual(initialData);
expect(emissions).toEqual([initialData]);
expect(emissions2).toEqual([initialData]);
});
test("updates should wait until previous update is complete", async () => {
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async (key: string, obj: any) => {
let resolved = false;
await Promise.race([
userState.update(() => {
// deadlocks
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(false);
})
.mockImplementation((...args) => {
return originalSave(...args);
});
await userState.update(() => {
return newData;
});
});
test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => {
expect(userState["stateSubject"].value).toEqual(anySymbol()); // FAKE_DEFAULT
const val = await userState.update((state) => {
return newData;
});
expect(val).toEqual(newData);
const call = diskStorageService.mock.save.mock.calls[0];
expect(call[0]).toEqual(`user_${userId}_fake_fake`);
expect(call[1]).toEqual(newData);
});
it("does not await updates if the active user changes", async () => {
const initialUserId = (await firstValueFrom(accountService.activeAccount$)).id;
expect(initialUserId).toBe(userId);
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async (key: string, obj: any) => {
let resolved = false;
await changeActiveUser("2");
await Promise.race([
userState.update(() => {
// should not deadlock because we updated the user
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(true);
})
.mockImplementation((...args) => {
return originalSave(...args);
});
await userState.update(() => {
return newData;
});
});
it("stores updates for users in the correct place when active user changes mid-update", async () => {
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const user2Data = { date: new Date(), array: ["user 2 data"] };
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async (key: string, obj: any) => {
let resolved = false;
await changeActiveUser("2");
await Promise.race([
userState.update(() => {
// should not deadlock because we updated the user
resolved = true;
return user2Data;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(true);
await originalSave(key, obj);
})
.mockImplementation((...args) => {
return originalSave(...args);
});
await userState.update(() => {
return newData;
});
await awaitAsync();
expect(diskStorageService.mock.save).toHaveBeenCalledTimes(2);
const innerCall = diskStorageService.mock.save.mock.calls[0];
expect(innerCall[0]).toEqual(`user_${makeUserId("2")}_fake_fake`);
expect(innerCall[1]).toEqual(user2Data);
const outerCall = diskStorageService.mock.save.mock.calls[1];
expect(outerCall[0]).toEqual(`user_${makeUserId("1")}_fake_fake`);
expect(outerCall[1]).toEqual(newData);
});
});
describe("cleanup", () => {
const newData = { date: new Date(), array: ["test"] };
const userId = makeUserId("1");
let userKey: string;
beforeEach(async () => {
await changeActiveUser("1");
userKey = userKeyBuilder(userId, testKeyDefinition);
});
async function assertClean() {
const emissions = trackEmissions(userState["stateSubject"]);
const initial = structuredClone(emissions);
diskStorageService.save(userKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(emissions).toEqual(initial); // no longer listening to storage updates
}
it("should cleanup after last subscriber", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync(); // storage updates are behind a promise
subscription.unsubscribe();
expect(userState["subscriberCount"].getValue()).toBe(0);
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
await assertClean();
});
it("should not cleanup if there are still subscribers", async () => {
const subscription1 = userState.state$.subscribe();
const sub2Emissions: TestState[] = [];
const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v));
await awaitAsync(); // storage updates are behind a promise
subscription1.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(userState["subscriberCount"].getValue()).toBe(1);
// Still be listening to storage updates
diskStorageService.save(userKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(sub2Emissions).toEqual([null, newData]);
subscription2.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
await assertClean();
});
it("can re-initialize after cleanup", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
const emissions = trackEmissions(userState.state$);
await awaitAsync();
diskStorageService.save(userKey, newData);
await awaitAsync();
expect(emissions).toEqual([null, newData]);
});
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
expect(userState["subscriberCount"].getValue()).toBe(0);
// Do not wait long enough for cleanup
await awaitAsync(cleanupDelayMs / 2);
expect(userState["stateSubject"].value).toEqual(newData); // digging in to check that it hasn't been cleared
expect(userState["storageUpdateSubscription"]).not.toBeNull(); // still listening to storage updates
});
it("state$ observables are durable to cleanup", async () => {
const observable = userState.state$;
let subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
expect(await firstValueFrom(observable)).toEqual(newData);
});
});
});

View File

@@ -4,12 +4,12 @@ import {
map,
shareReplay,
switchMap,
tap,
defer,
firstValueFrom,
combineLatestWith,
filter,
timeout,
Subscription,
tap,
} from "rxjs";
import { AccountService } from "../../../auth/abstractions/account.service";
@@ -31,13 +31,22 @@ const FAKE_DEFAULT = Symbol("fakeDefault");
export class DefaultActiveUserState<T> implements ActiveUserState<T> {
[activeMarker]: true;
private formattedKey$: Observable<string>;
private updatePromise: Promise<T> | null = null;
private storageUpdateSubscription: Subscription;
private activeAccountUpdateSubscription: Subscription;
private subscriberCount = new BehaviorSubject<number>(0);
private stateObservable: Observable<T>;
private reinitialize = false;
protected stateSubject: BehaviorSubject<T | typeof FAKE_DEFAULT> = new BehaviorSubject<
T | typeof FAKE_DEFAULT
>(FAKE_DEFAULT);
private stateSubject$ = this.stateSubject.asObservable();
state$: Observable<T>;
get state$() {
this.stateObservable = this.stateObservable ?? this.initializeObservable();
return this.stateObservable;
}
constructor(
protected keyDefinition: KeyDefinition<T>,
@@ -51,62 +60,12 @@ export class DefaultActiveUserState<T> implements ActiveUserState<T> {
? userKeyBuilder(account.id, this.keyDefinition)
: null,
),
tap(() => {
// We have a new key, so we should forget about previous update promises
this.updatePromise = null;
}),
shareReplay({ bufferSize: 1, refCount: false }),
);
const activeAccountData$ = this.formattedKey$.pipe(
switchMap(async (key) => {
if (key == null) {
return FAKE_DEFAULT;
}
return await getStoredValue(
key,
this.chosenStorageLocation,
this.keyDefinition.deserializer,
);
}),
// Share the execution
shareReplay({ refCount: false, bufferSize: 1 }),
);
const storageUpdates$ = this.chosenStorageLocation.updates$.pipe(
combineLatestWith(this.formattedKey$),
filter(([update, key]) => key !== null && update.key === key),
switchMap(async ([update, key]) => {
if (update.updateType === "remove") {
return null;
}
const data = await getStoredValue(
key,
this.chosenStorageLocation,
this.keyDefinition.deserializer,
);
return data;
}),
);
// Whomever subscribes to this data, should be notified of updated data
// if someone calls my update() method, or the active user changes.
this.state$ = defer(() => {
const accountChangeSubscription = activeAccountData$.subscribe((data) => {
this.stateSubject.next(data);
});
const storageUpdateSubscription = storageUpdates$.subscribe((data) => {
this.stateSubject.next(data);
});
return this.stateSubject$.pipe(
tap({
complete: () => {
accountChangeSubscription.unsubscribe();
storageUpdateSubscription.unsubscribe();
},
}),
);
})
// I fake the generic here because I am filtering out the other union type
// and this makes it so that typescript understands the true type
.pipe(filter<T>((value) => value != FAKE_DEFAULT));
}
async update<TCombine>(
@@ -114,8 +73,34 @@ export class DefaultActiveUserState<T> implements ActiveUserState<T> {
options: StateUpdateOptions<T, TCombine> = {},
): Promise<T> {
options = populateOptionsWithDefault(options);
try {
if (this.updatePromise != null) {
await this.updatePromise;
}
this.updatePromise = this.internalUpdate(configureState, options);
const newState = await this.updatePromise;
return newState;
} finally {
this.updatePromise = null;
}
}
// TODO: this should be removed
async getFromState(): Promise<T> {
const key = await this.createKey();
const currentState = await this.getGuaranteedState(key);
return await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer);
}
createDerived<TTo>(converter: Converter<T, TTo>): DerivedUserState<TTo> {
return new DefaultDerivedUserState<T, TTo>(converter, this.encryptService, this);
}
private async internalUpdate<TCombine>(
configureState: (state: T, dependency: TCombine) => T,
options: StateUpdateOptions<T, TCombine>,
) {
const key = await this.createKey();
const currentState = await this.getStateForUpdate(key);
const combinedDependencies =
options.combineLatestWith != null
? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout)))
@@ -130,13 +115,59 @@ export class DefaultActiveUserState<T> implements ActiveUserState<T> {
return newState;
}
async getFromState(): Promise<T> {
const key = await this.createKey();
return await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer);
private initializeObservable() {
this.storageUpdateSubscription = this.chosenStorageLocation.updates$
.pipe(
combineLatestWith(this.formattedKey$),
filter(([update, key]) => key !== null && update.key === key),
switchMap(async ([update, key]) => {
if (update.updateType === "remove") {
return null;
}
return await this.getState(key);
}),
)
.subscribe((v) => this.stateSubject.next(v));
this.activeAccountUpdateSubscription = this.formattedKey$
.pipe(
switchMap(async (key) => {
if (key == null) {
return FAKE_DEFAULT;
}
return await this.getState(key);
}),
)
.subscribe((v) => this.stateSubject.next(v));
this.subscriberCount.subscribe((count) => {
if (count === 0 && this.stateObservable != null) {
this.triggerCleanup();
}
});
return new Observable<T>((subscriber) => {
this.incrementSubscribers();
// reinitialize listeners after cleanup
if (this.reinitialize) {
this.reinitialize = false;
this.initializeObservable();
}
createDerived<TTo>(converter: Converter<T, TTo>): DerivedUserState<TTo> {
return new DefaultDerivedUserState<T, TTo>(converter, this.encryptService, this);
const prevUnsubscribe = subscriber.unsubscribe.bind(subscriber);
subscriber.unsubscribe = () => {
this.decrementSubscribers();
prevUnsubscribe();
};
return this.stateSubject
.pipe(
// Filter out fake default, which is used to indicate that state is not ready to be emitted yet.
filter((i) => i !== FAKE_DEFAULT),
)
.subscribe(subscriber);
});
}
protected async createKey(): Promise<string> {
@@ -147,22 +178,47 @@ export class DefaultActiveUserState<T> implements ActiveUserState<T> {
return formattedKey;
}
protected async getGuaranteedState(key: string) {
/** For use in update methods, does not wait for update to complete before yielding state.
* The expectation is that that await is already done
*/
protected async getStateForUpdate(key: string) {
const currentValue = this.stateSubject.getValue();
return currentValue === FAKE_DEFAULT ? await this.seedInitial(key) : currentValue;
return currentValue === FAKE_DEFAULT
? await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer)
: currentValue;
}
private async seedInitial(key: string): Promise<T> {
const value = await getStoredValue(
key,
this.chosenStorageLocation,
this.keyDefinition.deserializer,
);
this.stateSubject.next(value);
return value;
/** To be used in observables. Awaits updates to ensure they are complete */
private async getState(key: string): Promise<T> {
if (this.updatePromise != null) {
await this.updatePromise;
}
return await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer);
}
protected saveToStorage(key: string, data: T): Promise<void> {
return this.chosenStorageLocation.save(key, data);
}
private incrementSubscribers() {
this.subscriberCount.next(this.subscriberCount.value + 1);
}
private decrementSubscribers() {
this.subscriberCount.next(this.subscriberCount.value - 1);
}
private triggerCleanup() {
setTimeout(() => {
if (this.subscriberCount.value === 0) {
this.updatePromise = null;
this.storageUpdateSubscription?.unsubscribe();
this.activeAccountUpdateSubscription?.unsubscribe();
this.subscriberCount.complete();
this.subscriberCount = new BehaviorSubject<number>(0);
this.stateSubject.next(FAKE_DEFAULT);
this.reinitialize = true;
}
}, this.keyDefinition.cleanupDelayMs);
}
}

View File

@@ -3,6 +3,7 @@
* @jest-environment ../shared/test.environment.ts
*/
import { anySymbol } from "jest-mock-extended";
import { firstValueFrom, of } from "rxjs";
import { Jsonify } from "type-fest";
@@ -28,9 +29,10 @@ class TestState {
}
const testStateDefinition = new StateDefinition("fake", "disk");
const cleanupDelayMs = 10;
const testKeyDefinition = new KeyDefinition<TestState>(testStateDefinition, "fake", {
deserializer: TestState.fromJSON,
cleanupDelayMs,
});
const globalKey = globalKeyBuilder(testKeyDefinition);
@@ -79,6 +81,19 @@ describe("DefaultGlobalState", () => {
expect(diskStorageService.mock.get).toHaveBeenCalledWith("global_fake_fake", undefined);
expect(state).toBeTruthy();
});
it("should not emit twice if there are two listeners", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions2 = trackEmissions(globalState.state$);
await awaitAsync();
expect(emissions).toEqual([
null, // Initial value
]);
expect(emissions2).toEqual([
null, // Initial value
]);
});
});
describe("update", () => {
@@ -133,6 +148,7 @@ describe("DefaultGlobalState", () => {
it("should not update if shouldUpdate returns false", async () => {
const emissions = trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
const result = await globalState.update(
(state) => {
@@ -198,4 +214,212 @@ describe("DefaultGlobalState", () => {
expect(emissions).toEqual(expect.arrayContaining([initialState, newState]));
});
});
describe("update races", () => {
test("subscriptions during an update should receive the current and latest data", async () => {
const oldData = { date: new Date(2019, 1, 1) };
await globalState.update(() => {
return oldData;
});
const initialData = { date: new Date(2020, 1, 1) };
await globalState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(globalState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
emissions2 = trackEmissions(globalState.state$);
await originalSave(key, obj);
});
const val = await globalState.update(() => {
return newData;
});
await awaitAsync(10);
expect(val).toEqual(newData);
expect(emissions).toEqual([initialData, newData]);
expect(emissions2).toEqual([initialData, newData]);
});
test("subscription during an aborted update should receive the last value", async () => {
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1) };
await globalState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(globalState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const val = await globalState.update(
() => {
return newData;
},
{
shouldUpdate: () => {
emissions2 = trackEmissions(globalState.state$);
return false;
},
},
);
await awaitAsync();
expect(val).toEqual(initialData);
expect(emissions).toEqual([initialData]);
expect(emissions2).toEqual([initialData]);
});
test("updates should wait until previous update is complete", async () => {
trackEmissions(globalState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async () => {
let resolved = false;
await Promise.race([
globalState.update(() => {
// deadlocks
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(false);
})
.mockImplementation(originalSave);
await globalState.update((state) => {
return newData;
});
});
test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => {
expect(globalState["stateSubject"].value).toEqual(anySymbol()); // FAKE_DEFAULT
const val = await globalState.update((state) => {
return newData;
});
expect(val).toEqual(newData);
const call = diskStorageService.mock.save.mock.calls[0];
expect(call[0]).toEqual("global_fake_fake");
expect(call[1]).toEqual(newData);
});
});
describe("cleanup", () => {
async function assertClean() {
const emissions = trackEmissions(globalState["stateSubject"]);
const initial = structuredClone(emissions);
diskStorageService.save(globalKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(emissions).toEqual(initial); // no longer listening to storage updates
}
it("should cleanup after last subscriber", async () => {
const subscription = globalState.state$.subscribe();
await awaitAsync(); // storage updates are behind a promise
subscription.unsubscribe();
expect(globalState["subscriberCount"].getValue()).toBe(0);
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
await assertClean();
});
it("should not cleanup if there are still subscribers", async () => {
const subscription1 = globalState.state$.subscribe();
const sub2Emissions: TestState[] = [];
const subscription2 = globalState.state$.subscribe((v) => sub2Emissions.push(v));
await awaitAsync(); // storage updates are behind a promise
subscription1.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(globalState["subscriberCount"].getValue()).toBe(1);
// Still be listening to storage updates
diskStorageService.save(globalKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(sub2Emissions).toEqual([null, newData]);
subscription2.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
await assertClean();
});
it("can re-initialize after cleanup", async () => {
const subscription = globalState.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
const emissions = trackEmissions(globalState.state$);
await awaitAsync();
diskStorageService.save(globalKey, newData);
await awaitAsync();
expect(emissions).toEqual([null, newData]);
});
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
const subscription = globalState.state$.subscribe();
await awaitAsync();
await diskStorageService.save(globalKey, newData);
await awaitAsync();
subscription.unsubscribe();
expect(globalState["subscriberCount"].getValue()).toBe(0);
// Do not wait long enough for cleanup
await awaitAsync(cleanupDelayMs / 2);
expect(globalState["stateSubject"].value).toEqual(newData); // digging in to check that it hasn't been cleared
expect(globalState["storageUpdateSubscription"]).not.toBeNull(); // still listening to storage updates
});
it("state$ observables are durable to cleanup", async () => {
const observable = globalState.state$;
let subscription = observable.subscribe();
await diskStorageService.save(globalKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
subscription = observable.subscribe();
await diskStorageService.save(globalKey, newData);
await awaitAsync();
expect(await firstValueFrom(observable)).toEqual(newData);
});
});
});

View File

@@ -1,12 +1,10 @@
import {
BehaviorSubject,
Observable,
defer,
Subscription,
filter,
firstValueFrom,
shareReplay,
switchMap,
tap,
timeout,
} from "rxjs";
@@ -23,54 +21,26 @@ const FAKE_DEFAULT = Symbol("fakeDefault");
export class DefaultGlobalState<T> implements GlobalState<T> {
private storageKey: string;
private updatePromise: Promise<T> | null = null;
private storageUpdateSubscription: Subscription;
private subscriberCount = new BehaviorSubject<number>(0);
private stateObservable: Observable<T>;
private reinitialize = false;
protected stateSubject: BehaviorSubject<T | typeof FAKE_DEFAULT> = new BehaviorSubject<
T | typeof FAKE_DEFAULT
>(FAKE_DEFAULT);
state$: Observable<T>;
get state$() {
this.stateObservable = this.stateObservable ?? this.initializeObservable();
return this.stateObservable;
}
constructor(
private keyDefinition: KeyDefinition<T>,
private chosenLocation: AbstractStorageService & ObservableStorageService,
) {
this.storageKey = globalKeyBuilder(this.keyDefinition);
const storageUpdates$ = this.chosenLocation.updates$.pipe(
filter((update) => update.key === this.storageKey),
switchMap(async (update) => {
if (update.updateType === "remove") {
return null;
}
return await getStoredValue(
this.storageKey,
this.chosenLocation,
this.keyDefinition.deserializer,
);
}),
shareReplay({ bufferSize: 1, refCount: false }),
);
this.state$ = defer(() => {
const storageUpdateSubscription = storageUpdates$.subscribe((value) => {
this.stateSubject.next(value);
});
this.getFromState().then((s) => {
this.stateSubject.next(s);
});
return this.stateSubject.pipe(
tap({
complete: () => {
storageUpdateSubscription.unsubscribe();
},
}),
);
}).pipe(
shareReplay({ refCount: false, bufferSize: 1 }),
filter<T>((i) => i != FAKE_DEFAULT),
);
}
async update<TCombine>(
@@ -78,7 +48,24 @@ export class DefaultGlobalState<T> implements GlobalState<T> {
options: StateUpdateOptions<T, TCombine> = {},
): Promise<T> {
options = populateOptionsWithDefault(options);
const currentState = await this.getGuaranteedState();
if (this.updatePromise != null) {
await this.updatePromise;
}
try {
this.updatePromise = this.internalUpdate(configureState, options);
const newState = await this.updatePromise;
return newState;
} finally {
this.updatePromise = null;
}
}
private async internalUpdate<TCombine>(
configureState: (state: T, dependency: TCombine) => T,
options: StateUpdateOptions<T, TCombine>,
): Promise<T> {
const currentState = await this.getStateForUpdate();
const combinedDependencies =
options.combineLatestWith != null
? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout)))
@@ -93,16 +80,94 @@ export class DefaultGlobalState<T> implements GlobalState<T> {
return newState;
}
private async getGuaranteedState() {
private initializeObservable() {
this.storageUpdateSubscription = this.chosenLocation.updates$
.pipe(
filter((update) => update.key === this.storageKey),
switchMap(async (update) => {
if (update.updateType === "remove") {
return null;
}
return await this.getFromState();
}),
)
.subscribe((v) => this.stateSubject.next(v));
this.subscriberCount.subscribe((count) => {
if (count === 0 && this.stateObservable != null) {
this.triggerCleanup();
}
});
// Intentionally un-awaited promise, we don't want to delay return of observable, but we do want to
// trigger populating it immediately.
this.getFromState().then((s) => {
this.stateSubject.next(s);
});
return new Observable<T>((subscriber) => {
this.incrementSubscribers();
// reinitialize listeners after cleanup
if (this.reinitialize) {
this.reinitialize = false;
this.initializeObservable();
}
const prevUnsubscribe = subscriber.unsubscribe.bind(subscriber);
subscriber.unsubscribe = () => {
this.decrementSubscribers();
prevUnsubscribe();
};
return this.stateSubject
.pipe(
// Filter out fake default, which is used to indicate that state is not ready to be emitted yet.
filter<T>((i) => i != FAKE_DEFAULT),
)
.subscribe(subscriber);
});
}
/** For use in update methods, does not wait for update to complete before yielding state.
* The expectation is that that await is already done
*/
private async getStateForUpdate() {
const currentValue = this.stateSubject.getValue();
return currentValue === FAKE_DEFAULT ? await this.getFromState() : currentValue;
return currentValue === FAKE_DEFAULT
? await getStoredValue(this.storageKey, this.chosenLocation, this.keyDefinition.deserializer)
: currentValue;
}
async getFromState(): Promise<T> {
if (this.updatePromise != null) {
return await this.updatePromise;
}
return await getStoredValue(
this.storageKey,
this.chosenLocation,
this.keyDefinition.deserializer,
);
}
private incrementSubscribers() {
this.subscriberCount.next(this.subscriberCount.value + 1);
}
private decrementSubscribers() {
this.subscriberCount.next(this.subscriberCount.value - 1);
}
private triggerCleanup() {
setTimeout(() => {
if (this.subscriberCount.value === 0) {
this.updatePromise = null;
this.storageUpdateSubscription.unsubscribe();
this.subscriberCount.complete();
this.subscriberCount = new BehaviorSubject<number>(0);
this.stateSubject.next(FAKE_DEFAULT);
this.reinitialize = true;
}
}, this.keyDefinition.cleanupDelayMs);
}
}

View File

@@ -3,6 +3,7 @@
* @jest-environment ../shared/test.environment.ts
*/
import { anySymbol } from "jest-mock-extended";
import { firstValueFrom, of } from "rxjs";
import { Jsonify } from "type-fest";
@@ -30,21 +31,22 @@ class TestState {
}
const testStateDefinition = new StateDefinition("fake", "disk");
const cleanupDelayMs = 10;
const testKeyDefinition = new KeyDefinition<TestState>(testStateDefinition, "fake", {
deserializer: TestState.fromJSON,
cleanupDelayMs,
});
const userId = Utils.newGuid() as UserId;
const userKey = userKeyBuilder(userId, testKeyDefinition);
describe("DefaultSingleUserState", () => {
let diskStorageService: FakeStorageService;
let globalState: DefaultSingleUserState<TestState>;
let userState: DefaultSingleUserState<TestState>;
const newData = { date: new Date() };
beforeEach(() => {
diskStorageService = new FakeStorageService();
globalState = new DefaultSingleUserState(
userState = new DefaultSingleUserState(
userId,
testKeyDefinition,
null, // Not testing anything with encrypt service
@@ -58,7 +60,7 @@ describe("DefaultSingleUserState", () => {
describe("state$", () => {
it("should emit when storage updates", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions = trackEmissions(userState.state$);
await diskStorageService.save(userKey, newData);
await awaitAsync();
@@ -69,7 +71,7 @@ describe("DefaultSingleUserState", () => {
});
it("should not emit when update key does not match", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions = trackEmissions(userState.state$);
await diskStorageService.save("wrong_key", newData);
expect(emissions).toHaveLength(0);
@@ -82,7 +84,7 @@ describe("DefaultSingleUserState", () => {
});
diskStorageService.internalUpdateStore(initialStorage);
const state = await firstValueFrom(globalState.state$);
const state = await firstValueFrom(userState.state$);
expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1);
expect(diskStorageService.mock.get).toHaveBeenCalledWith(
`user_${userId}_fake_fake`,
@@ -94,7 +96,7 @@ describe("DefaultSingleUserState", () => {
describe("update", () => {
it("should save on update", async () => {
const result = await globalState.update((state) => {
const result = await userState.update((state) => {
return newData;
});
@@ -103,10 +105,10 @@ describe("DefaultSingleUserState", () => {
});
it("should emit once per update", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
await globalState.update((state) => {
await userState.update((state) => {
return newData;
});
@@ -119,12 +121,12 @@ describe("DefaultSingleUserState", () => {
});
it("should provided combined dependencies", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const combinedDependencies = { date: new Date() };
await globalState.update(
await userState.update(
(state, dependencies) => {
expect(dependencies).toEqual(combinedDependencies);
return newData;
@@ -143,9 +145,10 @@ describe("DefaultSingleUserState", () => {
});
it("should not update if shouldUpdate returns false", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const result = await globalState.update(
const result = await userState.update(
(state) => {
return newData;
},
@@ -160,18 +163,18 @@ describe("DefaultSingleUserState", () => {
});
it("should provide the update callback with the current State", async () => {
const emissions = trackEmissions(globalState.state$);
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1) };
await globalState.update((state, dependencies) => {
await userState.update((state, dependencies) => {
return initialData;
});
await awaitAsync();
await globalState.update((state) => {
await userState.update((state) => {
expect(state).toEqual(initialData);
return newData;
});
@@ -193,14 +196,14 @@ describe("DefaultSingleUserState", () => {
initialStorage[userKey] = initialState;
diskStorageService.internalUpdateStore(initialStorage);
const emissions = trackEmissions(globalState.state$);
const emissions = trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const newState = {
...initialState,
date: new Date(initialState.date.getFullYear(), initialState.date.getMonth() + 1),
};
const actual = await globalState.update((existingState) => newState);
const actual = await userState.update((existingState) => newState);
await awaitAsync();
@@ -209,4 +212,212 @@ describe("DefaultSingleUserState", () => {
expect(emissions).toEqual(expect.arrayContaining([initialState, newState]));
});
});
describe("update races", () => {
test("subscriptions during an update should receive the current and latest data", async () => {
const oldData = { date: new Date(2019, 1, 1) };
await userState.update(() => {
return oldData;
});
const initialData = { date: new Date(2020, 1, 1) };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => {
emissions2 = trackEmissions(userState.state$);
await originalSave(key, obj);
});
const val = await userState.update(() => {
return newData;
});
await awaitAsync(10);
expect(val).toEqual(newData);
expect(emissions).toEqual([initialData, newData]);
expect(emissions2).toEqual([initialData, newData]);
});
test("subscription during an aborted update should receive the last value", async () => {
// Seed with interesting data
const initialData = { date: new Date(2020, 1, 1) };
await userState.update(() => {
return initialData;
});
await awaitAsync();
const emissions = trackEmissions(userState.state$);
await awaitAsync();
expect(emissions).toEqual([initialData]);
let emissions2: TestState[];
const val = await userState.update(
(state) => {
return newData;
},
{
shouldUpdate: () => {
emissions2 = trackEmissions(userState.state$);
return false;
},
},
);
await awaitAsync();
expect(val).toEqual(initialData);
expect(emissions).toEqual([initialData]);
expect(emissions2).toEqual([initialData]);
});
test("updates should wait until previous update is complete", async () => {
trackEmissions(userState.state$);
await awaitAsync(); // storage updates are behind a promise
const originalSave = diskStorageService.save.bind(diskStorageService);
diskStorageService.save = jest
.fn()
.mockImplementationOnce(async () => {
let resolved = false;
await Promise.race([
userState.update(() => {
// deadlocks
resolved = true;
return newData;
}),
awaitAsync(100), // limit test to 100ms
]);
expect(resolved).toBe(false);
})
.mockImplementation(originalSave);
await userState.update((state) => {
return newData;
});
});
test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => {
expect(userState["stateSubject"].value).toEqual(anySymbol()); // FAKE_DEFAULT
const val = await userState.update((state) => {
return newData;
});
expect(val).toEqual(newData);
const call = diskStorageService.mock.save.mock.calls[0];
expect(call[0]).toEqual(`user_${userId}_fake_fake`);
expect(call[1]).toEqual(newData);
});
});
describe("cleanup", () => {
async function assertClean() {
const emissions = trackEmissions(userState["stateSubject"]);
const initial = structuredClone(emissions);
diskStorageService.save(userKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(emissions).toEqual(initial); // no longer listening to storage updates
}
it("should cleanup after last subscriber", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync(); // storage updates are behind a promise
subscription.unsubscribe();
expect(userState["subscriberCount"].getValue()).toBe(0);
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
await assertClean();
});
it("should not cleanup if there are still subscribers", async () => {
const subscription1 = userState.state$.subscribe();
const sub2Emissions: TestState[] = [];
const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v));
await awaitAsync(); // storage updates are behind a promise
subscription1.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
expect(userState["subscriberCount"].getValue()).toBe(1);
// Still be listening to storage updates
diskStorageService.save(userKey, newData);
await awaitAsync(); // storage updates are behind a promise
expect(sub2Emissions).toEqual([null, newData]);
subscription2.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
await assertClean();
});
it("can re-initialize after cleanup", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
const emissions = trackEmissions(userState.state$);
await awaitAsync();
diskStorageService.save(userKey, newData);
await awaitAsync();
expect(emissions).toEqual([null, newData]);
});
it("should not cleanup if a subscriber joins during the cleanup delay", async () => {
const subscription = userState.state$.subscribe();
await awaitAsync();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
expect(userState["subscriberCount"].getValue()).toBe(0);
// Do not wait long enough for cleanup
await awaitAsync(cleanupDelayMs / 2);
expect(userState["stateSubject"].value).toEqual(newData); // digging in to check that it hasn't been cleared
expect(userState["storageUpdateSubscription"]).not.toBeNull(); // still listening to storage updates
});
it("state$ observables are durable to cleanup", async () => {
const observable = userState.state$;
let subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
subscription.unsubscribe();
// Wait for cleanup
await awaitAsync(cleanupDelayMs * 2);
subscription = observable.subscribe();
await diskStorageService.save(userKey, newData);
await awaitAsync();
expect(await firstValueFrom(observable)).toEqual(newData);
});
});
});

View File

@@ -1,12 +1,10 @@
import {
BehaviorSubject,
Observable,
defer,
Subscription,
filter,
firstValueFrom,
shareReplay,
switchMap,
tap,
timeout,
} from "rxjs";
@@ -23,16 +21,25 @@ import { Converter, SingleUserState } from "../user-state";
import { DefaultDerivedUserState } from "./default-derived-state";
import { getStoredValue } from "./util";
const FAKE_DEFAULT = Symbol("fakeDefault");
export class DefaultSingleUserState<T> implements SingleUserState<T> {
private storageKey: string;
private updatePromise: Promise<T> | null = null;
private storageUpdateSubscription: Subscription;
private subscriberCount = new BehaviorSubject<number>(0);
private stateObservable: Observable<T>;
private reinitialize = false;
protected stateSubject: BehaviorSubject<T | typeof FAKE_DEFAULT> = new BehaviorSubject<
T | typeof FAKE_DEFAULT
>(FAKE_DEFAULT);
state$: Observable<T>;
get state$() {
this.stateObservable = this.stateObservable ?? this.initializeObservable();
return this.stateObservable;
}
constructor(
readonly userId: UserId,
@@ -41,42 +48,6 @@ export class DefaultSingleUserState<T> implements SingleUserState<T> {
private chosenLocation: AbstractStorageService & ObservableStorageService,
) {
this.storageKey = userKeyBuilder(this.userId, this.keyDefinition);
const storageUpdates$ = this.chosenLocation.updates$.pipe(
filter((update) => update.key === this.storageKey),
switchMap(async (update) => {
if (update.updateType === "remove") {
return null;
}
return await getStoredValue(
this.storageKey,
this.chosenLocation,
this.keyDefinition.deserializer,
);
}),
shareReplay({ bufferSize: 1, refCount: false }),
);
this.state$ = defer(() => {
const storageUpdateSubscription = storageUpdates$.subscribe((value) => {
this.stateSubject.next(value);
});
this.getFromState().then((s) => {
this.stateSubject.next(s);
});
return this.stateSubject.pipe(
tap({
complete: () => {
storageUpdateSubscription.unsubscribe();
},
}),
);
}).pipe(
shareReplay({ refCount: false, bufferSize: 1 }),
filter<T>((i) => i != FAKE_DEFAULT),
);
}
async update<TCombine>(
@@ -84,7 +55,28 @@ export class DefaultSingleUserState<T> implements SingleUserState<T> {
options: StateUpdateOptions<T, TCombine> = {},
): Promise<T> {
options = populateOptionsWithDefault(options);
const currentState = await this.getGuaranteedState();
if (this.updatePromise != null) {
await this.updatePromise;
}
try {
this.updatePromise = this.internalUpdate(configureState, options);
const newState = await this.updatePromise;
return newState;
} finally {
this.updatePromise = null;
}
}
createDerived<TTo>(converter: Converter<T, TTo>): DerivedUserState<TTo> {
return new DefaultDerivedUserState<T, TTo>(converter, this.encryptService, this);
}
private async internalUpdate<TCombine>(
configureState: (state: T, dependency: TCombine) => T,
options: StateUpdateOptions<T, TCombine>,
): Promise<T> {
const currentState = await this.getStateForUpdate();
const combinedDependencies =
options.combineLatestWith != null
? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout)))
@@ -99,20 +91,94 @@ export class DefaultSingleUserState<T> implements SingleUserState<T> {
return newState;
}
createDerived<TTo>(converter: Converter<T, TTo>): DerivedUserState<TTo> {
return new DefaultDerivedUserState<T, TTo>(converter, this.encryptService, this);
private initializeObservable() {
this.storageUpdateSubscription = this.chosenLocation.updates$
.pipe(
filter((update) => update.key === this.storageKey),
switchMap(async (update) => {
if (update.updateType === "remove") {
return null;
}
return await this.getFromState();
}),
)
.subscribe((v) => this.stateSubject.next(v));
this.subscriberCount.subscribe((count) => {
if (count === 0 && this.stateObservable != null) {
this.triggerCleanup();
}
});
// Intentionally un-awaited promise, we don't want to delay return of observable, but we do want to
// trigger populating it immediately.
this.getFromState().then((s) => {
this.stateSubject.next(s);
});
return new Observable<T>((subscriber) => {
this.incrementSubscribers();
// reinitialize listeners after cleanup
if (this.reinitialize) {
this.reinitialize = false;
this.initializeObservable();
}
private async getGuaranteedState() {
const prevUnsubscribe = subscriber.unsubscribe.bind(subscriber);
subscriber.unsubscribe = () => {
this.decrementSubscribers();
prevUnsubscribe();
};
return this.stateSubject
.pipe(
// Filter out fake default, which is used to indicate that state is not ready to be emitted yet.
filter<T>((i) => i != FAKE_DEFAULT),
)
.subscribe(subscriber);
});
}
/** For use in update methods, does not wait for update to complete before yielding state.
* The expectation is that that await is already done
*/
private async getStateForUpdate() {
const currentValue = this.stateSubject.getValue();
return currentValue === FAKE_DEFAULT ? await this.getFromState() : currentValue;
return currentValue === FAKE_DEFAULT
? await getStoredValue(this.storageKey, this.chosenLocation, this.keyDefinition.deserializer)
: currentValue;
}
async getFromState(): Promise<T> {
if (this.updatePromise != null) {
return await this.updatePromise;
}
return await getStoredValue(
this.storageKey,
this.chosenLocation,
this.keyDefinition.deserializer,
);
}
private incrementSubscribers() {
this.subscriberCount.next(this.subscriberCount.value + 1);
}
private decrementSubscribers() {
this.subscriberCount.next(this.subscriberCount.value - 1);
}
private triggerCleanup() {
setTimeout(() => {
if (this.subscriberCount.value === 0) {
this.updatePromise = null;
this.storageUpdateSubscription.unsubscribe();
this.subscriberCount.complete();
this.subscriberCount = new BehaviorSubject<number>(0);
this.stateSubject.next(FAKE_DEFAULT);
this.reinitialize = true;
}
}, this.keyDefinition.cleanupDelayMs);
}
}

View File

@@ -18,6 +18,37 @@ describe("KeyDefinition", () => {
});
});
describe("cleanupDelayMs", () => {
it("defaults to 1000ms", () => {
const keyDefinition = new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
deserializer: (value) => value,
});
expect(keyDefinition).toBeTruthy();
expect(keyDefinition.cleanupDelayMs).toBe(1000);
});
it("can be overridden", () => {
const keyDefinition = new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
deserializer: (value) => value,
cleanupDelayMs: 500,
});
expect(keyDefinition).toBeTruthy();
expect(keyDefinition.cleanupDelayMs).toBe(500);
});
it.each([0, -1])("throws on 0 or negative (%s)", (testValue: number) => {
expect(
() =>
new KeyDefinition<boolean>(fakeStateDefinition, "fake", {
deserializer: (value) => value,
cleanupDelayMs: testValue,
}),
).toThrow();
});
});
describe("record", () => {
it("runs custom deserializer for each record value", () => {
const recordDefinition = KeyDefinition.record<boolean>(fakeStateDefinition, "fake", {

View File

@@ -19,6 +19,11 @@ type KeyDefinitionOptions<T> = {
* @returns The fully typed version of your state.
*/
readonly deserializer: (jsonValue: Jsonify<T>) => T;
/**
* The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
* Defaults to 1000ms.
*/
readonly cleanupDelayMs?: number;
};
/**
@@ -42,8 +47,12 @@ export class KeyDefinition<T> {
private readonly options: KeyDefinitionOptions<T>,
) {
if (options.deserializer == null) {
throw new Error(`'deserializer' is a required property on key ${this.errorKeyName}`);
}
if (options.cleanupDelayMs <= 0) {
throw new Error(
`'deserializer' is a required property on key ${stateDefinition.name} > ${key}`,
`'cleanupDelayMs' must be greater than 0. Value of ${options.cleanupDelayMs} passed to key ${this.errorKeyName} `,
);
}
}
@@ -55,6 +64,13 @@ export class KeyDefinition<T> {
return this.options.deserializer;
}
/**
* Gets the number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed.
*/
get cleanupDelayMs() {
return this.options.cleanupDelayMs < 0 ? 0 : this.options.cleanupDelayMs ?? 1000;
}
/**
* Creates a {@link KeyDefinition} for state that is an array.
* @param stateDefinition The state definition to be added to the KeyDefinition
@@ -137,6 +153,10 @@ export class KeyDefinition<T> {
? `${scope}_${userId}_${this.stateDefinition.name}_${this.key}`
: `${scope}_${this.stateDefinition.name}_${this.key}`;
}
private get errorKeyName() {
return `${this.stateDefinition.name} > ${this.key}`;
}
}
export type StorageKey = Opaque<string, "StorageKey">;

View File

@@ -3,6 +3,7 @@ import { ImportResult } from "../models/import-result";
import { BaseImporter } from "./base-importer";
import { Importer } from "./importer";
/** This is the importer for the xml format from pwsafe.org */
export class PasswordSafeXmlImporter extends BaseImporter implements Importer {
parse(data: string): Promise<ImportResult> {
const result = new ImportResult();

View File

@@ -29,7 +29,7 @@ export const regularImportOptions = [
{ id: "enpassjson", name: "Enpass (json)" },
{ id: "protonpass", name: "ProtonPass (zip/json)" },
{ id: "safeincloudxml", name: "SafeInCloud (xml)" },
{ id: "pwsafexml", name: "Password Safe (xml)" },
{ id: "pwsafexml", name: "Password Safe - pwsafe.org (xml)" },
{ id: "stickypasswordxml", name: "Sticky Password (xml)" },
{ id: "msecurecsv", name: "mSecure (csv)" },
{ id: "truekeycsv", name: "True Key (csv)" },

34
package-lock.json generated
View File

@@ -33,7 +33,7 @@
"argon2-browser": "1.18.0",
"big-integer": "1.6.51",
"bootstrap": "4.6.0",
"braintree-web-drop-in": "1.40.0",
"braintree-web-drop-in": "1.41.0",
"bufferutil": "4.0.8",
"chalk": "4.1.2",
"commander": "7.2.0",
@@ -4638,9 +4638,9 @@
}
},
"node_modules/@braintree/browser-detection": {
"version": "1.14.0",
"resolved": "https://registry.npmjs.org/@braintree/browser-detection/-/browser-detection-1.14.0.tgz",
"integrity": "sha512-OsqU+28RhNvSw8Y5JEiUHUrAyn4OpYazFkjSJe8ZVZfkAaRXQc6hsV38MMEpIlkPMig+A68buk/diY+0O8/dMQ=="
"version": "1.17.1",
"resolved": "https://registry.npmjs.org/@braintree/browser-detection/-/browser-detection-1.17.1.tgz",
"integrity": "sha512-Mk7jauyp9pD14BTRS7otoy9dqIJGb3Oy0XtxKM/adGD9i9MAuCjH5uRZMyW2iVmJQTaA/PLlWdG7eSDyMWMc8Q=="
},
"node_modules/@braintree/event-emitter": {
"version": "0.4.1",
@@ -4658,9 +4658,9 @@
"integrity": "sha512-tVpr7U6u6bqeQlHreEjYMNtnHX62vLnNWziY2kQLqkWhvusPuY5DfuGEIPpWqsd+V/a1slyTQaxK6HWTlH6A/Q=="
},
"node_modules/@braintree/sanitize-url": {
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.2.tgz",
"integrity": "sha512-Tbsj02wXCbqGmzdnXNk0SOF19ChhRU70BsroIi4Pm6Ehp56in6vch94mfbdQ17DozxkL3BAVjbZ4Qc1a0HFRAg=="
"version": "6.0.4",
"resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.4.tgz",
"integrity": "sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A=="
},
"node_modules/@braintree/uuid": {
"version": "0.1.0",
@@ -17244,16 +17244,16 @@
}
},
"node_modules/braintree-web": {
"version": "3.96.1",
"resolved": "https://registry.npmjs.org/braintree-web/-/braintree-web-3.96.1.tgz",
"integrity": "sha512-e483TQkRmcO5zFH+pMGUokFWq5Q+Jyn9TmUIDofPpul1P+xOciwNqvekl/Ku6MsJBT1+kyH0cndBYMZVJpbrtg==",
"version": "3.97.4",
"resolved": "https://registry.npmjs.org/braintree-web/-/braintree-web-3.97.4.tgz",
"integrity": "sha512-w//M/ZI/MhjaxUwpICwZO50uTLF/L3WGLN4tFCPh/Xw20jDw8UBiM0Gzquq7gmwcQ1BgNnAAaYlR94HcSmt/Cg==",
"dependencies": {
"@braintree/asset-loader": "0.4.4",
"@braintree/browser-detection": "1.14.0",
"@braintree/browser-detection": "1.17.1",
"@braintree/event-emitter": "0.4.1",
"@braintree/extended-promise": "0.4.1",
"@braintree/iframer": "1.1.0",
"@braintree/sanitize-url": "6.0.2",
"@braintree/sanitize-url": "6.0.4",
"@braintree/uuid": "0.1.0",
"@braintree/wrap-promise": "2.1.0",
"card-validator": "8.1.1",
@@ -17265,16 +17265,16 @@
}
},
"node_modules/braintree-web-drop-in": {
"version": "1.40.0",
"resolved": "https://registry.npmjs.org/braintree-web-drop-in/-/braintree-web-drop-in-1.40.0.tgz",
"integrity": "sha512-YrR+KitYu0X6+qkP2hvkjAolMJw3jKyd+Yppk+LZUFk1Bbfj7YayLjTQqLTyFT5MQmzpQKLYwbI75BBClPGc+Q==",
"version": "1.41.0",
"resolved": "https://registry.npmjs.org/braintree-web-drop-in/-/braintree-web-drop-in-1.41.0.tgz",
"integrity": "sha512-cpFY13iyoPNCTIOU7dipHmOvoblUtYFuA7ADAm0DUPk6oqxFz4EIr94R0Yg2rCabvjeauINDf01Y2d7/E1IaXg==",
"dependencies": {
"@braintree/asset-loader": "0.4.4",
"@braintree/browser-detection": "1.14.0",
"@braintree/browser-detection": "1.17.1",
"@braintree/event-emitter": "0.4.1",
"@braintree/uuid": "0.1.0",
"@braintree/wrap-promise": "2.1.0",
"braintree-web": "3.96.1"
"braintree-web": "3.97.4"
}
},
"node_modules/braintree-web/node_modules/promise-polyfill": {

View File

@@ -165,7 +165,7 @@
"argon2-browser": "1.18.0",
"big-integer": "1.6.51",
"bootstrap": "4.6.0",
"braintree-web-drop-in": "1.40.0",
"braintree-web-drop-in": "1.41.0",
"bufferutil": "4.0.8",
"chalk": "4.1.2",
"commander": "7.2.0",