diff --git a/.github/workflows/version-bump.yml b/.github/workflows/version-bump.yml index 8e126e1da60..680441a8eca 100644 --- a/.github/workflows/version-bump.yml +++ b/.github/workflows/version-bump.yml @@ -71,7 +71,7 @@ jobs: uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0 with: repository: bitwarden/clients - ref: master + ref: main token: ${{ steps.retrieve-secrets.outputs.github-pat-bitwarden-devops-bot-repo-scope }} - name: Import GPG key diff --git a/apps/browser/src/autofill/background/overlay.background.ts b/apps/browser/src/autofill/background/overlay.background.ts index 3d6f00ec10f..f760ee22e62 100644 --- a/apps/browser/src/autofill/background/overlay.background.ts +++ b/apps/browser/src/autofill/background/overlay.background.ts @@ -673,7 +673,7 @@ class OverlayBackground implements OverlayBackgroundInterface { */ private setupExtensionMessageListeners() { BrowserApi.messageListener("overlay.background", this.handleExtensionMessage); - chrome.runtime.onConnect.addListener(this.handlePortOnConnect); + BrowserApi.addListener(chrome.runtime.onConnect, this.handlePortOnConnect); } /** diff --git a/apps/browser/src/autofill/background/service_factories/autofill-service.factory.ts b/apps/browser/src/autofill/background/service_factories/autofill-service.factory.ts index acd9be2a8ea..972a2421cbb 100644 --- a/apps/browser/src/autofill/background/service_factories/autofill-service.factory.ts +++ b/apps/browser/src/autofill/background/service_factories/autofill-service.factory.ts @@ -10,6 +10,10 @@ import { settingsServiceFactory, SettingsServiceInitOptions, } from "../../../background/service-factories/settings-service.factory"; +import { + configServiceFactory, + ConfigServiceInitOptions, +} from "../../../platform/background/service-factories/config-service.factory"; import { CachedServices, factory, @@ -43,7 +47,8 @@ export type AutoFillServiceInitOptions = AutoFillServiceOptions & EventCollectionServiceInitOptions & LogServiceInitOptions & SettingsServiceInitOptions & - UserVerificationServiceInitOptions; + UserVerificationServiceInitOptions & + ConfigServiceInitOptions; export function autofillServiceFactory( cache: { autofillService?: AbstractAutoFillService } & CachedServices, @@ -62,6 +67,7 @@ export function autofillServiceFactory( await logServiceFactory(cache, opts), await settingsServiceFactory(cache, opts), await userVerificationServiceFactory(cache, opts), + await configServiceFactory(cache, opts), ), ); } diff --git a/apps/browser/src/autofill/background/tabs.background.spec.ts b/apps/browser/src/autofill/background/tabs.background.spec.ts index b3de1e96ce3..304d43bd143 100644 --- a/apps/browser/src/autofill/background/tabs.background.spec.ts +++ b/apps/browser/src/autofill/background/tabs.background.spec.ts @@ -15,7 +15,7 @@ import OverlayBackground from "./overlay.background"; import TabsBackground from "./tabs.background"; describe("TabsBackground", () => { - let tabsBackgorund: TabsBackground; + let tabsBackground: TabsBackground; const mainBackground = mock({ messagingService: { send: jest.fn(), @@ -25,7 +25,7 @@ describe("TabsBackground", () => { const overlayBackground = mock(); beforeEach(() => { - tabsBackgorund = new TabsBackground(mainBackground, notificationBackground, overlayBackground); + tabsBackground = new TabsBackground(mainBackground, notificationBackground, overlayBackground); }); afterEach(() => { @@ -35,11 +35,11 @@ describe("TabsBackground", () => { describe("init", () => { it("sets up a window on focusChanged listener", () => { const handleWindowOnFocusChangedSpy = jest.spyOn( - tabsBackgorund as any, + tabsBackground as any, "handleWindowOnFocusChanged", ); - tabsBackgorund.init(); + tabsBackground.init(); expect(chrome.windows.onFocusChanged.addListener).toHaveBeenCalledWith( handleWindowOnFocusChangedSpy, @@ -49,7 +49,7 @@ describe("TabsBackground", () => { describe("tab event listeners", () => { beforeEach(() => { - tabsBackgorund.init(); + tabsBackground["setupTabEventListeners"](); }); describe("window onFocusChanged event", () => { @@ -64,7 +64,7 @@ describe("TabsBackground", () => { triggerWindowOnFocusedChangedEvent(10); await flushPromises(); - expect(tabsBackgorund["focusedWindowId"]).toBe(10); + expect(tabsBackground["focusedWindowId"]).toBe(10); }); it("updates the current tab data", async () => { @@ -144,7 +144,7 @@ describe("TabsBackground", () => { beforeEach(() => { mainBackground.onUpdatedRan = false; - tabsBackgorund["focusedWindowId"] = focusedWindowId; + tabsBackground["focusedWindowId"] = focusedWindowId; tab = mock({ windowId: focusedWindowId, active: true, diff --git a/apps/browser/src/autofill/background/tabs.background.ts b/apps/browser/src/autofill/background/tabs.background.ts index b095f99ce20..8b4cb356a7e 100644 --- a/apps/browser/src/autofill/background/tabs.background.ts +++ b/apps/browser/src/autofill/background/tabs.background.ts @@ -20,6 +20,14 @@ export default class TabsBackground { return; } + this.updateCurrentTabData(); + this.setupTabEventListeners(); + } + + /** + * Sets up the tab and window event listeners. + */ + private setupTabEventListeners() { chrome.windows.onFocusChanged.addListener(this.handleWindowOnFocusChanged); chrome.tabs.onActivated.addListener(this.handleTabOnActivated); chrome.tabs.onReplaced.addListener(this.handleTabOnReplaced); @@ -33,7 +41,7 @@ export default class TabsBackground { * @param windowId - The ID of the window that was focused. */ private handleWindowOnFocusChanged = async (windowId: number) => { - if (!windowId) { + if (windowId == null || windowId < 0) { return; } @@ -116,8 +124,10 @@ export default class TabsBackground { * for the current tab. Also updates the overlay ciphers. */ private updateCurrentTabData = async () => { - await this.main.refreshBadge(); - await this.main.refreshMenu(); - await this.overlayBackground.updateOverlayCiphers(); + await Promise.all([ + this.main.refreshBadge(), + this.main.refreshMenu(), + this.overlayBackground.updateOverlayCiphers(), + ]); }; } diff --git a/apps/browser/src/autofill/content/abstractions/autofill-init.ts b/apps/browser/src/autofill/content/abstractions/autofill-init.ts index 139099a4d58..91866ffa0bb 100644 --- a/apps/browser/src/autofill/content/abstractions/autofill-init.ts +++ b/apps/browser/src/autofill/content/abstractions/autofill-init.ts @@ -17,6 +17,7 @@ type AutofillExtensionMessage = { direction?: "previous" | "next"; isOpeningFullOverlay?: boolean; forceCloseOverlay?: boolean; + autofillOverlayVisibility?: number; }; }; @@ -34,10 +35,12 @@ type AutofillExtensionMessageHandlers = { updateIsOverlayCiphersPopulated: ({ message }: AutofillExtensionMessageParam) => void; bgUnlockPopoutOpened: () => void; bgVaultItemRepromptPopoutOpened: () => void; + updateAutofillOverlayVisibility: ({ message }: AutofillExtensionMessageParam) => void; }; interface AutofillInit { init(): void; + destroy(): void; } export { AutofillExtensionMessage, AutofillExtensionMessageHandlers, AutofillInit }; diff --git a/apps/browser/src/autofill/content/autofill-init.spec.ts b/apps/browser/src/autofill/content/autofill-init.spec.ts index 1524fdce100..ecf67740183 100644 --- a/apps/browser/src/autofill/content/autofill-init.spec.ts +++ b/apps/browser/src/autofill/content/autofill-init.spec.ts @@ -6,7 +6,7 @@ import { flushPromises, sendExtensionRuntimeMessage } from "../jest/testing-util import AutofillPageDetails from "../models/autofill-page-details"; import AutofillScript from "../models/autofill-script"; import AutofillOverlayContentService from "../services/autofill-overlay-content.service"; -import { RedirectFocusDirection } from "../utils/autofill-overlay.enum"; +import { AutofillOverlayVisibility, RedirectFocusDirection } from "../utils/autofill-overlay.enum"; import { AutofillExtensionMessage } from "./abstractions/autofill-init"; import AutofillInit from "./autofill-init"; @@ -16,6 +16,11 @@ describe("AutofillInit", () => { const autofillOverlayContentService = mock(); beforeEach(() => { + chrome.runtime.connect = jest.fn().mockReturnValue({ + onDisconnect: { + addListener: jest.fn(), + }, + }); autofillInit = new AutofillInit(autofillOverlayContentService); }); @@ -477,6 +482,57 @@ describe("AutofillInit", () => { expect(autofillInit["removeAutofillOverlay"]).toHaveBeenCalled(); }); }); + + describe("updateAutofillOverlayVisibility", () => { + beforeEach(() => { + autofillInit["autofillOverlayContentService"].autofillOverlayVisibility = + AutofillOverlayVisibility.OnButtonClick; + }); + + it("skips attempting to update the overlay visibility if the autofillOverlayVisibility data value is not present", () => { + sendExtensionRuntimeMessage({ + command: "updateAutofillOverlayVisibility", + data: {}, + }); + + expect(autofillInit["autofillOverlayContentService"].autofillOverlayVisibility).toEqual( + AutofillOverlayVisibility.OnButtonClick, + ); + }); + + it("updates the overlay visibility value", () => { + const message = { + command: "updateAutofillOverlayVisibility", + data: { + autofillOverlayVisibility: AutofillOverlayVisibility.Off, + }, + }; + + sendExtensionRuntimeMessage(message); + + expect(autofillInit["autofillOverlayContentService"].autofillOverlayVisibility).toEqual( + message.data.autofillOverlayVisibility, + ); + }); + }); + }); + }); + + describe("destroy", () => { + it("removes the extension message listeners", () => { + autofillInit.destroy(); + + expect(chrome.runtime.onMessage.removeListener).toHaveBeenCalledWith( + autofillInit["handleExtensionMessage"], + ); + }); + + it("destroys the collectAutofillContentService", () => { + jest.spyOn(autofillInit["collectAutofillContentService"], "destroy"); + + autofillInit.destroy(); + + expect(autofillInit["collectAutofillContentService"].destroy).toHaveBeenCalled(); }); }); }); diff --git a/apps/browser/src/autofill/content/autofill-init.ts b/apps/browser/src/autofill/content/autofill-init.ts index 9b23305377c..5a2ec3dd397 100644 --- a/apps/browser/src/autofill/content/autofill-init.ts +++ b/apps/browser/src/autofill/content/autofill-init.ts @@ -26,6 +26,7 @@ class AutofillInit implements AutofillInitInterface { updateIsOverlayCiphersPopulated: ({ message }) => this.updateIsOverlayCiphersPopulated(message), bgUnlockPopoutOpened: () => this.blurAndRemoveOverlay(), bgVaultItemRepromptPopoutOpened: () => this.blurAndRemoveOverlay(), + updateAutofillOverlayVisibility: ({ message }) => this.updateAutofillOverlayVisibility(message), }; /** @@ -214,6 +215,19 @@ class AutofillInit implements AutofillInitInterface { ); } + /** + * Updates the autofill overlay visibility. + * + * @param data - Contains the autoFillOverlayVisibility value + */ + private updateAutofillOverlayVisibility({ data }: AutofillExtensionMessage) { + if (!this.autofillOverlayContentService || isNaN(data?.autofillOverlayVisibility)) { + return; + } + + this.autofillOverlayContentService.autofillOverlayVisibility = data?.autofillOverlayVisibility; + } + /** * Sets up the extension message listeners for the content script. */ @@ -247,6 +261,16 @@ class AutofillInit implements AutofillInitInterface { Promise.resolve(messageResponse).then((response) => sendResponse(response)); return true; }; + + /** + * Handles destroying the autofill init content script. Removes all + * listeners, timeouts, and object instances to prevent memory leaks. + */ + destroy() { + chrome.runtime.onMessage.removeListener(this.handleExtensionMessage); + this.collectAutofillContentService.destroy(); + this.autofillOverlayContentService?.destroy(); + } } export default AutofillInit; diff --git a/apps/browser/src/autofill/content/autofiller.ts b/apps/browser/src/autofill/content/autofiller.ts index 7f58e72c7d3..c3a2f7f5793 100644 --- a/apps/browser/src/autofill/content/autofiller.ts +++ b/apps/browser/src/autofill/content/autofiller.ts @@ -1,3 +1,5 @@ +import { getFromLocalStorage, setupExtensionDisconnectAction } from "../utils"; + if (document.readyState === "loading") { document.addEventListener("DOMContentLoaded", loadAutofiller); } else { @@ -8,27 +10,30 @@ function loadAutofiller() { let pageHref: string = null; let filledThisHref = false; let delayFillTimeout: number; - - const activeUserIdKey = "activeUserId"; - let activeUserId: string; - - chrome.storage.local.get(activeUserIdKey, (obj: any) => { - if (obj == null || obj[activeUserIdKey] == null) { - return; - } - activeUserId = obj[activeUserIdKey]; - }); - - chrome.storage.local.get(activeUserId, (obj: any) => { - if (obj?.[activeUserId]?.settings?.enableAutoFillOnPageLoad === true) { - setInterval(() => doFillIfNeeded(), 500); - } - }); - chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => { - if (msg.command === "fillForm" && pageHref === msg.url) { + let doFillInterval: NodeJS.Timeout; + const handleExtensionDisconnect = () => { + clearDoFillInterval(); + clearDelayFillTimeout(); + }; + const handleExtensionMessage = (message: any) => { + if (message.command === "fillForm" && pageHref === message.url) { filledThisHref = true; } - }); + }; + + setupExtensionEventListeners(); + triggerUserFillOnLoad(); + + async function triggerUserFillOnLoad() { + const activeUserIdKey = "activeUserId"; + const userKeyStorage = await getFromLocalStorage(activeUserIdKey); + const activeUserId = userKeyStorage[activeUserIdKey]; + const activeUserStorage = await getFromLocalStorage(activeUserId); + if (activeUserStorage?.[activeUserId]?.settings?.enableAutoFillOnPageLoad === true) { + clearDoFillInterval(); + doFillInterval = setInterval(() => doFillIfNeeded(), 500); + } + } function doFillIfNeeded(force = false) { if (force || pageHref !== window.location.href) { @@ -36,9 +41,7 @@ function loadAutofiller() { // Some websites are slow and rendering all page content. Try to fill again later // if we haven't already. filledThisHref = false; - if (delayFillTimeout != null) { - window.clearTimeout(delayFillTimeout); - } + clearDelayFillTimeout(); delayFillTimeout = window.setTimeout(() => { if (!filledThisHref) { doFillIfNeeded(true); @@ -55,4 +58,21 @@ function loadAutofiller() { chrome.runtime.sendMessage(msg); } } + + function clearDoFillInterval() { + if (doFillInterval) { + window.clearInterval(doFillInterval); + } + } + + function clearDelayFillTimeout() { + if (delayFillTimeout) { + window.clearTimeout(delayFillTimeout); + } + } + + function setupExtensionEventListeners() { + setupExtensionDisconnectAction(handleExtensionDisconnect); + chrome.runtime.onMessage.addListener(handleExtensionMessage); + } } diff --git a/apps/browser/src/autofill/content/bootstrap-autofill-overlay.ts b/apps/browser/src/autofill/content/bootstrap-autofill-overlay.ts index 5bc9fb1718f..ab21e367c29 100644 --- a/apps/browser/src/autofill/content/bootstrap-autofill-overlay.ts +++ b/apps/browser/src/autofill/content/bootstrap-autofill-overlay.ts @@ -1,4 +1,5 @@ import AutofillOverlayContentService from "../services/autofill-overlay-content.service"; +import { setupAutofillInitDisconnectAction } from "../utils"; import AutofillInit from "./autofill-init"; @@ -6,6 +7,8 @@ import AutofillInit from "./autofill-init"; if (!windowContext.bitwardenAutofillInit) { const autofillOverlayContentService = new AutofillOverlayContentService(); windowContext.bitwardenAutofillInit = new AutofillInit(autofillOverlayContentService); + setupAutofillInitDisconnectAction(windowContext); + windowContext.bitwardenAutofillInit.init(); } })(window); diff --git a/apps/browser/src/autofill/content/bootstrap-autofill.ts b/apps/browser/src/autofill/content/bootstrap-autofill.ts index 3264c77ea0e..f98d4bc1d72 100644 --- a/apps/browser/src/autofill/content/bootstrap-autofill.ts +++ b/apps/browser/src/autofill/content/bootstrap-autofill.ts @@ -1,8 +1,12 @@ +import { setupAutofillInitDisconnectAction } from "../utils"; + import AutofillInit from "./autofill-init"; (function (windowContext) { if (!windowContext.bitwardenAutofillInit) { windowContext.bitwardenAutofillInit = new AutofillInit(); + setupAutofillInitDisconnectAction(windowContext); + windowContext.bitwardenAutofillInit.init(); } })(window); diff --git a/apps/browser/src/autofill/content/message_handler.ts b/apps/browser/src/autofill/content/message_handler.ts index 9bf48e3b17d..7b52aeb3556 100644 --- a/apps/browser/src/autofill/content/message_handler.ts +++ b/apps/browser/src/autofill/content/message_handler.ts @@ -1,31 +1,4 @@ -window.addEventListener( - "message", - (event) => { - if (event.source !== window) { - return; - } - - if (event.data.command && event.data.command === "authResult") { - chrome.runtime.sendMessage({ - command: event.data.command, - code: event.data.code, - state: event.data.state, - lastpass: event.data.lastpass, - referrer: event.source.location.hostname, - }); - } - - if (event.data.command && event.data.command === "webAuthnResult") { - chrome.runtime.sendMessage({ - command: event.data.command, - data: event.data.data, - remember: event.data.remember, - referrer: event.source.location.hostname, - }); - } - }, - false, -); +import { setupExtensionDisconnectAction } from "../utils"; const forwardCommands = [ "bgUnlockPopoutOpened", @@ -34,8 +7,59 @@ const forwardCommands = [ "addedCipher", ]; -chrome.runtime.onMessage.addListener((event) => { - if (forwardCommands.includes(event.command)) { - chrome.runtime.sendMessage(event); +/** + * Handles sending extension messages to the background + * script based on window messages from the page. + * + * @param event - Window message event + */ +const handleWindowMessage = (event: MessageEvent) => { + if (event.source !== window) { + return; } -}); + + if (event.data.command && event.data.command === "authResult") { + chrome.runtime.sendMessage({ + command: event.data.command, + code: event.data.code, + state: event.data.state, + lastpass: event.data.lastpass, + referrer: event.source.location.hostname, + }); + } + + if (event.data.command && event.data.command === "webAuthnResult") { + chrome.runtime.sendMessage({ + command: event.data.command, + data: event.data.data, + remember: event.data.remember, + referrer: event.source.location.hostname, + }); + } +}; + +/** + * Handles forwarding any commands that need to trigger + * an action from one service of the extension background + * to another. + * + * @param message - Message from the extension + */ +const handleExtensionMessage = (message: any) => { + if (forwardCommands.includes(message.command)) { + chrome.runtime.sendMessage(message); + } +}; + +/** + * Handles cleaning up any event listeners that were + * added to the window or extension. + */ +const handleExtensionDisconnect = () => { + window.removeEventListener("message", handleWindowMessage); + chrome.runtime.onMessage.removeListener(handleExtensionMessage); +}; + +window.addEventListener("message", handleWindowMessage, false); +chrome.runtime.onMessage.addListener(handleExtensionMessage); +setupExtensionDisconnectAction(handleExtensionDisconnect); diff --git a/apps/browser/src/autofill/content/notification-bar.ts b/apps/browser/src/autofill/content/notification-bar.ts index 92e8f599385..6c3f3561e5b 100644 --- a/apps/browser/src/autofill/content/notification-bar.ts +++ b/apps/browser/src/autofill/content/notification-bar.ts @@ -4,6 +4,7 @@ import AddLoginRuntimeMessage from "../notification/models/add-login-runtime-mes import ChangePasswordRuntimeMessage from "../notification/models/change-password-runtime-message"; import { FormData } from "../services/abstractions/autofill.service"; import { GlobalSettings, UserSettings } from "../types"; +import { getFromLocalStorage, setupExtensionDisconnectAction } from "../utils"; interface HTMLElementWithFormOpId extends HTMLElement { formOpId: string; @@ -122,6 +123,8 @@ async function loadNotificationBar() { } } + setupExtensionDisconnectAction(handleExtensionDisconnection); + if (!showNotificationBar) { return; } @@ -999,11 +1002,23 @@ async function loadNotificationBar() { return theEl === document; } + function handleExtensionDisconnection(port: chrome.runtime.Port) { + closeBar(false); + clearTimeout(domObservationCollectTimeoutId); + clearTimeout(collectPageDetailsTimeoutId); + clearTimeout(handlePageChangeTimeoutId); + observer?.disconnect(); + observer = null; + watchedForms.forEach((wf: WatchedForm) => { + const form = wf.formEl; + form.removeEventListener("submit", formSubmitted, false); + const submitButton = getSubmitButton( + form, + unionSets(logInButtonNames, changePasswordButtonNames), + ); + submitButton?.removeEventListener("click", formSubmitted, false); + }); + } + // End Helper Functions } - -async function getFromLocalStorage(keys: string | string[]): Promise> { - return new Promise((resolve) => { - chrome.storage.local.get(keys, (storage: Record) => resolve(storage)); - }); -} diff --git a/apps/browser/src/autofill/enums/autofill-port.enums.ts b/apps/browser/src/autofill/enums/autofill-port.enums.ts new file mode 100644 index 00000000000..e5b8f17aad5 --- /dev/null +++ b/apps/browser/src/autofill/enums/autofill-port.enums.ts @@ -0,0 +1,5 @@ +const AutofillPort = { + InjectedScript: "autofill-injected-script-port", +} as const; + +export { AutofillPort }; diff --git a/apps/browser/src/autofill/overlay/iframe-content/autofill-overlay-iframe.service.ts b/apps/browser/src/autofill/overlay/iframe-content/autofill-overlay-iframe.service.ts index 20f5aa830fc..c878f961f1c 100644 --- a/apps/browser/src/autofill/overlay/iframe-content/autofill-overlay-iframe.service.ts +++ b/apps/browser/src/autofill/overlay/iframe-content/autofill-overlay-iframe.service.ts @@ -1,5 +1,5 @@ import { EVENTS } from "../../constants"; -import { setElementStyles } from "../../utils/utils"; +import { setElementStyles } from "../../utils"; import { BackgroundPortMessageHandlers, AutofillOverlayIframeService as AutofillOverlayIframeServiceInterface, @@ -166,9 +166,10 @@ class AutofillOverlayIframeService implements AutofillOverlayIframeServiceInterf this.updateElementStyles(this.iframe, { opacity: "0", height: "0px", display: "block" }); globalThis.removeEventListener("message", this.handleWindowMessage); - this.port.onMessage.removeListener(this.handlePortMessage); - this.port.onDisconnect.removeListener(this.handlePortDisconnect); - this.port.disconnect(); + this.unobserveIframe(); + this.port?.onMessage.removeListener(this.handlePortMessage); + this.port?.onDisconnect.removeListener(this.handlePortDisconnect); + this.port?.disconnect(); this.port = null; }; @@ -369,7 +370,7 @@ class AutofillOverlayIframeService implements AutofillOverlayIframeServiceInterf * Unobserves the iframe element for mutations to its style attribute. */ private unobserveIframe() { - this.iframeMutationObserver.disconnect(); + this.iframeMutationObserver?.disconnect(); } /** diff --git a/apps/browser/src/autofill/overlay/pages/button/autofill-overlay-button.ts b/apps/browser/src/autofill/overlay/pages/button/autofill-overlay-button.ts index 94c0772fd2b..bfb57087452 100644 --- a/apps/browser/src/autofill/overlay/pages/button/autofill-overlay-button.ts +++ b/apps/browser/src/autofill/overlay/pages/button/autofill-overlay-button.ts @@ -3,8 +3,8 @@ import "lit/polyfill-support.js"; import { AuthenticationStatus } from "@bitwarden/common/auth/enums/authentication-status"; import { EVENTS } from "../../../constants"; +import { buildSvgDomElement } from "../../../utils"; import { logoIcon, logoLockedIcon } from "../../../utils/svg-icons"; -import { buildSvgDomElement } from "../../../utils/utils"; import { InitAutofillOverlayButtonMessage, OverlayButtonWindowMessageHandlers, diff --git a/apps/browser/src/autofill/overlay/pages/list/autofill-overlay-list.ts b/apps/browser/src/autofill/overlay/pages/list/autofill-overlay-list.ts index 053fddb9c13..3f13061a0ed 100644 --- a/apps/browser/src/autofill/overlay/pages/list/autofill-overlay-list.ts +++ b/apps/browser/src/autofill/overlay/pages/list/autofill-overlay-list.ts @@ -4,8 +4,8 @@ import { AuthenticationStatus } from "@bitwarden/common/auth/enums/authenticatio import { OverlayCipherData } from "../../../background/abstractions/overlay.background"; import { EVENTS } from "../../../constants"; +import { buildSvgDomElement } from "../../../utils"; import { globeIcon, lockIcon, plusIcon, viewCipherIcon } from "../../../utils/svg-icons"; -import { buildSvgDomElement } from "../../../utils/utils"; import { InitAutofillOverlayListMessage, OverlayListWindowMessageHandlers, diff --git a/apps/browser/src/autofill/popup/settings/autofill.component.ts b/apps/browser/src/autofill/popup/settings/autofill.component.ts index d9038b0eb27..728b74bc90a 100644 --- a/apps/browser/src/autofill/popup/settings/autofill.component.ts +++ b/apps/browser/src/autofill/popup/settings/autofill.component.ts @@ -10,6 +10,7 @@ import { UriMatchType } from "@bitwarden/common/vault/enums"; import { BrowserApi } from "../../../platform/browser/browser-api"; import { flagEnabled } from "../../../platform/flags"; +import { AutofillService } from "../../services/abstractions/autofill.service"; import { AutofillOverlayVisibility } from "../../utils/autofill-overlay.enum"; @Component({ @@ -35,6 +36,7 @@ export class AutofillComponent implements OnInit { private platformUtilsService: PlatformUtilsService, private configService: ConfigServiceAbstraction, private settingsService: SettingsService, + private autofillService: AutofillService, ) { this.autoFillOverlayVisibilityOptions = [ { @@ -86,7 +88,10 @@ export class AutofillComponent implements OnInit { } async updateAutoFillOverlayVisibility() { + const previousAutoFillOverlayVisibility = + await this.settingsService.getAutoFillOverlayVisibility(); await this.settingsService.setAutoFillOverlayVisibility(this.autoFillOverlayVisibility); + await this.handleUpdatingAutofillOverlayContentScripts(previousAutoFillOverlayVisibility); } async updateAutoFillOnPageLoad() { @@ -144,4 +149,25 @@ export class AutofillComponent implements OnInit { event.preventDefault(); BrowserApi.createNewTab(this.disablePasswordManagerLink); } + + private async handleUpdatingAutofillOverlayContentScripts( + previousAutoFillOverlayVisibility: number, + ) { + const autofillOverlayPreviouslyDisabled = + previousAutoFillOverlayVisibility === AutofillOverlayVisibility.Off; + const autofillOverlayCurrentlyDisabled = + this.autoFillOverlayVisibility === AutofillOverlayVisibility.Off; + + if (!autofillOverlayPreviouslyDisabled && !autofillOverlayCurrentlyDisabled) { + const tabs = await BrowserApi.tabsQuery({}); + tabs.forEach((tab) => + BrowserApi.tabSendMessageData(tab, "updateAutofillOverlayVisibility", { + autofillOverlayVisibility: this.autoFillOverlayVisibility, + }), + ); + return; + } + + await this.autofillService.reloadAutofillScripts(); + } } diff --git a/apps/browser/src/autofill/services/abstractions/autofill-overlay-content.service.ts b/apps/browser/src/autofill/services/abstractions/autofill-overlay-content.service.ts index ac7d55a54d4..ec594ac829f 100644 --- a/apps/browser/src/autofill/services/abstractions/autofill-overlay-content.service.ts +++ b/apps/browser/src/autofill/services/abstractions/autofill-overlay-content.service.ts @@ -14,6 +14,7 @@ interface AutofillOverlayContentService { isCurrentlyFilling: boolean; isOverlayCiphersPopulated: boolean; pageDetailsUpdateRequired: boolean; + autofillOverlayVisibility: number; init(): void; setupAutofillOverlayListenerOnField( autofillFieldElement: ElementWithOpId, @@ -27,6 +28,7 @@ interface AutofillOverlayContentService { redirectOverlayFocusOut(direction: "previous" | "next"): void; focusMostRecentOverlayField(): void; blurMostRecentOverlayField(): void; + destroy(): void; } export { OpenAutofillOverlayOptions, AutofillOverlayContentService }; diff --git a/apps/browser/src/autofill/services/abstractions/autofill.service.ts b/apps/browser/src/autofill/services/abstractions/autofill.service.ts index a0959db72cb..c44e3adf7c9 100644 --- a/apps/browser/src/autofill/services/abstractions/autofill.service.ts +++ b/apps/browser/src/autofill/services/abstractions/autofill.service.ts @@ -44,10 +44,12 @@ export interface GenerateFillScriptOptions { } export abstract class AutofillService { + loadAutofillScriptsOnInstall: () => Promise; + reloadAutofillScripts: () => Promise; injectAutofillScripts: ( - sender: chrome.runtime.MessageSender, - autofillV2?: boolean, - autofillOverlay?: boolean, + tab: chrome.tabs.Tab, + frameId?: number, + triggeringOnPageLoad?: boolean, ) => Promise; getFormsWithPasswordFields: (pageDetails: AutofillPageDetails) => FormData[]; doAutoFill: (options: AutoFillOptions) => Promise; diff --git a/apps/browser/src/autofill/services/abstractions/collect-autofill-content.service.ts b/apps/browser/src/autofill/services/abstractions/collect-autofill-content.service.ts index 78befa7bc61..46ad615059a 100644 --- a/apps/browser/src/autofill/services/abstractions/collect-autofill-content.service.ts +++ b/apps/browser/src/autofill/services/abstractions/collect-autofill-content.service.ts @@ -22,6 +22,7 @@ interface CollectAutofillContentService { filterCallback: CallableFunction, isObservingShadowRoot?: boolean, ): Node[]; + destroy(): void; } export { diff --git a/apps/browser/src/autofill/services/autofill-overlay-content.service.spec.ts b/apps/browser/src/autofill/services/autofill-overlay-content.service.spec.ts index 7753a4b2672..f3aa77258fe 100644 --- a/apps/browser/src/autofill/services/autofill-overlay-content.service.spec.ts +++ b/apps/browser/src/autofill/services/autofill-overlay-content.service.spec.ts @@ -1609,4 +1609,100 @@ describe("AutofillOverlayContentService", () => { expect(autofillOverlayContentService["removeAutofillOverlay"]).toHaveBeenCalled(); }); }); + + describe("destroy", () => { + let autofillFieldElement: ElementWithOpId; + let autofillFieldData: AutofillField; + + beforeEach(() => { + document.body.innerHTML = ` +
+ + +
+ `; + + autofillFieldElement = document.getElementById( + "username-field", + ) as ElementWithOpId; + autofillFieldElement.opid = "op-1"; + autofillFieldData = createAutofillFieldMock({ + opid: "username-field", + form: "validFormId", + placeholder: "username", + elementNumber: 1, + }); + autofillOverlayContentService.setupAutofillOverlayListenerOnField( + autofillFieldElement, + autofillFieldData, + ); + autofillOverlayContentService["mostRecentlyFocusedField"] = autofillFieldElement; + }); + + it("disconnects all mutation observers", () => { + autofillOverlayContentService["setupMutationObserver"](); + jest.spyOn(autofillOverlayContentService["bodyElementMutationObserver"], "disconnect"); + jest.spyOn(autofillOverlayContentService["documentElementMutationObserver"], "disconnect"); + + autofillOverlayContentService.destroy(); + + expect( + autofillOverlayContentService["documentElementMutationObserver"].disconnect, + ).toHaveBeenCalled(); + expect( + autofillOverlayContentService["bodyElementMutationObserver"].disconnect, + ).toHaveBeenCalled(); + }); + + it("clears the user interaction event timeout", () => { + jest.spyOn(autofillOverlayContentService as any, "clearUserInteractionEventTimeout"); + + autofillOverlayContentService.destroy(); + + expect(autofillOverlayContentService["clearUserInteractionEventTimeout"]).toHaveBeenCalled(); + }); + + it("de-registers all global event listeners", () => { + jest.spyOn(globalThis.document, "removeEventListener"); + jest.spyOn(globalThis, "removeEventListener"); + jest.spyOn(autofillOverlayContentService as any, "removeOverlayRepositionEventListeners"); + + autofillOverlayContentService.destroy(); + + expect(globalThis.document.removeEventListener).toHaveBeenCalledWith( + EVENTS.VISIBILITYCHANGE, + autofillOverlayContentService["handleVisibilityChangeEvent"], + ); + expect(globalThis.removeEventListener).toHaveBeenCalledWith( + EVENTS.FOCUSOUT, + autofillOverlayContentService["handleFormFieldBlurEvent"], + ); + expect( + autofillOverlayContentService["removeOverlayRepositionEventListeners"], + ).toHaveBeenCalled(); + }); + + it("de-registers any event listeners that are attached to the form field elements", () => { + jest.spyOn(autofillOverlayContentService as any, "removeCachedFormFieldEventListeners"); + jest.spyOn(autofillFieldElement, "removeEventListener"); + jest.spyOn(autofillOverlayContentService["formFieldElements"], "delete"); + + autofillOverlayContentService.destroy(); + + expect( + autofillOverlayContentService["removeCachedFormFieldEventListeners"], + ).toHaveBeenCalledWith(autofillFieldElement); + expect(autofillFieldElement.removeEventListener).toHaveBeenCalledWith( + EVENTS.BLUR, + autofillOverlayContentService["handleFormFieldBlurEvent"], + ); + expect(autofillFieldElement.removeEventListener).toHaveBeenCalledWith( + EVENTS.KEYUP, + autofillOverlayContentService["handleFormFieldKeyupEvent"], + ); + expect(autofillOverlayContentService["formFieldElements"].delete).toHaveBeenCalledWith( + autofillFieldElement, + ); + }); + }); }); diff --git a/apps/browser/src/autofill/services/autofill-overlay-content.service.ts b/apps/browser/src/autofill/services/autofill-overlay-content.service.ts index 9e5acae887c..c713c6ea411 100644 --- a/apps/browser/src/autofill/services/autofill-overlay-content.service.ts +++ b/apps/browser/src/autofill/services/autofill-overlay-content.service.ts @@ -10,16 +10,12 @@ import AutofillField from "../models/autofill-field"; import AutofillOverlayButtonIframe from "../overlay/iframe-content/autofill-overlay-button-iframe"; import AutofillOverlayListIframe from "../overlay/iframe-content/autofill-overlay-list-iframe"; import { ElementWithOpId, FillableFormFieldElement, FormFieldElement } from "../types"; +import { generateRandomCustomElementName, sendExtensionMessage, setElementStyles } from "../utils"; import { AutofillOverlayElement, RedirectFocusDirection, AutofillOverlayVisibility, } from "../utils/autofill-overlay.enum"; -import { - generateRandomCustomElementName, - sendExtensionMessage, - setElementStyles, -} from "../utils/utils"; import { AutofillOverlayContentService as AutofillOverlayContentServiceInterface, @@ -32,9 +28,10 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte isCurrentlyFilling = false; isOverlayCiphersPopulated = false; pageDetailsUpdateRequired = false; + autofillOverlayVisibility: number; private readonly findTabs = tabbable; private readonly sendExtensionMessage = sendExtensionMessage; - private autofillOverlayVisibility: number; + private formFieldElements: Set> = new Set([]); private userFilledFields: Record = {}; private authStatus: AuthenticationStatus; private focusableElements: FocusableElement[] = []; @@ -47,6 +44,7 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte private userInteractionEventTimeout: NodeJS.Timeout; private overlayElementsMutationObserver: MutationObserver; private bodyElementMutationObserver: MutationObserver; + private documentElementMutationObserver: MutationObserver; private mutationObserverIterations = 0; private mutationObserverIterationsResetTimeout: NodeJS.Timeout; private autofillFieldKeywordsMap: WeakMap = new WeakMap(); @@ -86,6 +84,8 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte return; } + this.formFieldElements.add(formFieldElement); + if (!this.autofillOverlayVisibility) { await this.getAutofillOverlayVisibility(); } @@ -901,10 +901,10 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte this.handleBodyElementMutationObserverUpdate, ); - const documentElementMutationObserver = new MutationObserver( + this.documentElementMutationObserver = new MutationObserver( this.handleDocumentElementMutationObserverUpdate, ); - documentElementMutationObserver.observe(globalThis.document.documentElement, { + this.documentElementMutationObserver.observe(globalThis.document.documentElement, { childList: true, }); }; @@ -1117,6 +1117,28 @@ class AutofillOverlayContentService implements AutofillOverlayContentServiceInte const documentRoot = element.getRootNode() as ShadowRoot | Document; return documentRoot?.activeElement; } + + /** + * Destroys the autofill overlay content service. This method will + * disconnect the mutation observers and remove all event listeners. + */ + destroy() { + this.documentElementMutationObserver?.disconnect(); + this.clearUserInteractionEventTimeout(); + this.formFieldElements.forEach((formFieldElement) => { + this.removeCachedFormFieldEventListeners(formFieldElement); + formFieldElement.removeEventListener(EVENTS.BLUR, this.handleFormFieldBlurEvent); + formFieldElement.removeEventListener(EVENTS.KEYUP, this.handleFormFieldKeyupEvent); + this.formFieldElements.delete(formFieldElement); + }); + globalThis.document.removeEventListener( + EVENTS.VISIBILITYCHANGE, + this.handleVisibilityChangeEvent, + ); + globalThis.removeEventListener(EVENTS.FOCUSOUT, this.handleFormFieldBlurEvent); + this.removeAutofillOverlay(); + this.removeOverlayRepositionEventListeners(); + } } export default AutofillOverlayContentService; diff --git a/apps/browser/src/autofill/services/autofill.service.spec.ts b/apps/browser/src/autofill/services/autofill.service.spec.ts index aa9232c791f..5f9c1db68c3 100644 --- a/apps/browser/src/autofill/services/autofill.service.spec.ts +++ b/apps/browser/src/autofill/services/autofill.service.spec.ts @@ -2,7 +2,9 @@ import { mock, mockReset } from "jest-mock-extended"; import { UserVerificationService } from "@bitwarden/common/auth/services/user-verification/user-verification.service"; import { EventType } from "@bitwarden/common/enums"; +import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum"; import { LogService } from "@bitwarden/common/platform/abstractions/log.service"; +import { ConfigService } from "@bitwarden/common/platform/services/config/config.service"; import { EventCollectionService } from "@bitwarden/common/services/event/event-collection.service"; import { SettingsService } from "@bitwarden/common/services/settings.service"; import { @@ -24,6 +26,7 @@ import { TotpService } from "@bitwarden/common/vault/services/totp.service"; import { BrowserApi } from "../../platform/browser/browser-api"; import { BrowserStateService } from "../../platform/services/browser-state.service"; +import { AutofillPort } from "../enums/autofill-port.enums"; import { createAutofillFieldMock, createAutofillPageDetailsMock, @@ -54,6 +57,7 @@ describe("AutofillService", () => { const logService = mock(); const settingsService = mock(); const userVerificationService = mock(); + const configService = mock(); beforeEach(() => { autofillService = new AutofillService( @@ -64,6 +68,7 @@ describe("AutofillService", () => { logService, settingsService, userVerificationService, + configService, ); }); @@ -72,6 +77,72 @@ describe("AutofillService", () => { mockReset(cipherService); }); + describe("loadAutofillScriptsOnInstall", () => { + let tab1: chrome.tabs.Tab; + let tab2: chrome.tabs.Tab; + let tab3: chrome.tabs.Tab; + + beforeEach(() => { + tab1 = createChromeTabMock({ id: 1, url: "https://some-url.com" }); + tab2 = createChromeTabMock({ id: 2, url: "http://some-url.com" }); + tab3 = createChromeTabMock({ id: 3, url: "chrome-extension://some-extension-route" }); + jest.spyOn(BrowserApi, "tabsQuery").mockResolvedValueOnce([tab1, tab2]); + }); + + it("queries all browser tabs and injects the autofill scripts into them", async () => { + jest.spyOn(autofillService, "injectAutofillScripts"); + + await autofillService.loadAutofillScriptsOnInstall(); + + expect(BrowserApi.tabsQuery).toHaveBeenCalledWith({}); + expect(autofillService.injectAutofillScripts).toHaveBeenCalledWith(tab1, 0, false); + expect(autofillService.injectAutofillScripts).toHaveBeenCalledWith(tab2, 0, false); + }); + + it("skips injecting scripts into tabs that do not have an http(s) protocol", async () => { + jest.spyOn(autofillService, "injectAutofillScripts"); + + await autofillService.loadAutofillScriptsOnInstall(); + + expect(BrowserApi.tabsQuery).toHaveBeenCalledWith({}); + expect(autofillService.injectAutofillScripts).not.toHaveBeenCalledWith(tab3); + }); + + it("sets up an extension runtime onConnect listener", async () => { + await autofillService.loadAutofillScriptsOnInstall(); + + // eslint-disable-next-line no-restricted-syntax + expect(chrome.runtime.onConnect.addListener).toHaveBeenCalledWith(expect.any(Function)); + }); + }); + + describe("reloadAutofillScripts", () => { + it("disconnects and removes all autofill script ports", () => { + const port1 = mock({ + disconnect: jest.fn(), + }); + const port2 = mock({ + disconnect: jest.fn(), + }); + autofillService["autofillScriptPortsSet"] = new Set([port1, port2]); + + autofillService.reloadAutofillScripts(); + + expect(port1.disconnect).toHaveBeenCalled(); + expect(port2.disconnect).toHaveBeenCalled(); + expect(autofillService["autofillScriptPortsSet"].size).toBe(0); + }); + + it("re-injects the autofill scripts in all tabs", () => { + autofillService["autofillScriptPortsSet"] = new Set([mock()]); + jest.spyOn(autofillService as any, "injectAutofillScriptsInAllTabs"); + + autofillService.reloadAutofillScripts(); + + expect(autofillService["injectAutofillScriptsInAllTabs"]).toHaveBeenCalled(); + }); + }); + describe("injectAutofillScripts", () => { const autofillV1Script = "autofill.js"; const autofillV2BootstrapScript = "bootstrap-autofill.js"; @@ -83,12 +154,12 @@ describe("AutofillService", () => { beforeEach(() => { tabMock = createChromeTabMock(); - sender = { tab: tabMock }; + sender = { tab: tabMock, frameId: 1 }; jest.spyOn(BrowserApi, "executeScriptInTab").mockImplementation(); }); it("accepts an extension message sender and injects the autofill scripts into the tab of the sender", async () => { - await autofillService.injectAutofillScripts(sender); + await autofillService.injectAutofillScripts(sender.tab, sender.frameId, true); [autofillV1Script, ...defaultAutofillScripts].forEach((scriptName) => { expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, { @@ -105,7 +176,11 @@ describe("AutofillService", () => { }); it("will inject the bootstrap-autofill script if the enableAutofillV2 flag is set", async () => { - await autofillService.injectAutofillScripts(sender, true); + jest + .spyOn(configService, "getFeatureFlag") + .mockImplementation((flag) => Promise.resolve(flag === FeatureFlag.AutofillV2)); + + await autofillService.injectAutofillScripts(sender.tab, sender.frameId); expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, { file: `content/${autofillV2BootstrapScript}`, @@ -120,11 +195,16 @@ describe("AutofillService", () => { }); it("will inject the bootstrap-autofill-overlay script if the enableAutofillOverlay flag is set and the user has the autofill overlay enabled", async () => { + jest + .spyOn(configService, "getFeatureFlag") + .mockImplementation((flag) => + Promise.resolve(flag === FeatureFlag.AutofillOverlay || flag === FeatureFlag.AutofillV2), + ); jest .spyOn(autofillService["settingsService"], "getAutoFillOverlayVisibility") .mockResolvedValue(AutofillOverlayVisibility.OnFieldFocus); - await autofillService.injectAutofillScripts(sender, true, true); + await autofillService.injectAutofillScripts(sender.tab, sender.frameId); expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, { file: `content/${autofillOverlayBootstrapScript}`, @@ -144,18 +224,25 @@ describe("AutofillService", () => { }); it("will inject the bootstrap-autofill script if the enableAutofillOverlay flag is set but the user does not have the autofill overlay enabled", async () => { + jest + .spyOn(configService, "getFeatureFlag") + .mockImplementation((flag) => + Promise.resolve(flag === FeatureFlag.AutofillOverlay || flag === FeatureFlag.AutofillV2), + ); jest .spyOn(autofillService["settingsService"], "getAutoFillOverlayVisibility") .mockResolvedValue(AutofillOverlayVisibility.Off); - await autofillService.injectAutofillScripts(sender, true, true); + await autofillService.injectAutofillScripts(sender.tab, sender.frameId); expect(BrowserApi.executeScriptInTab).toHaveBeenCalledWith(tabMock.id, { file: `content/${autofillV2BootstrapScript}`, + frameId: sender.frameId, ...defaultExecuteScriptOptions, }); expect(BrowserApi.executeScriptInTab).not.toHaveBeenCalledWith(tabMock.id, { file: `content/${autofillV1Script}`, + frameId: sender.frameId, ...defaultExecuteScriptOptions, }); }); @@ -4436,4 +4523,58 @@ describe("AutofillService", () => { expect(autofillService["currentlyOpeningPasswordRepromptPopout"]).toBe(false); }); }); + + describe("handleInjectedScriptPortConnection", () => { + it("ignores port connections that do not have the correct port name", () => { + const port = mock({ + name: "some-invalid-port-name", + onDisconnect: { addListener: jest.fn() }, + }) as any; + + autofillService["handleInjectedScriptPortConnection"](port); + + expect(port.onDisconnect.addListener).not.toHaveBeenCalled(); + expect(autofillService["autofillScriptPortsSet"].size).toBe(0); + }); + + it("adds the connect port to the set of injected script ports and sets up an onDisconnect listener", () => { + const port = mock({ + name: AutofillPort.InjectedScript, + onDisconnect: { addListener: jest.fn() }, + }) as any; + jest.spyOn(autofillService as any, "handleInjectScriptPortOnDisconnect"); + + autofillService["handleInjectedScriptPortConnection"](port); + + expect(port.onDisconnect.addListener).toHaveBeenCalledWith( + autofillService["handleInjectScriptPortOnDisconnect"], + ); + expect(autofillService["autofillScriptPortsSet"].size).toBe(1); + }); + }); + + describe("handleInjectScriptPortOnDisconnect", () => { + it("ignores port disconnections that do not have the correct port name", () => { + autofillService["autofillScriptPortsSet"].add(mock()); + + autofillService["handleInjectScriptPortOnDisconnect"]( + mock({ + name: "some-invalid-port-name", + }), + ); + + expect(autofillService["autofillScriptPortsSet"].size).toBe(1); + }); + + it("removes the port from the set of injected script ports", () => { + const port = mock({ + name: AutofillPort.InjectedScript, + }) as any; + autofillService["autofillScriptPortsSet"].add(port); + + autofillService["handleInjectScriptPortOnDisconnect"](port); + + expect(autofillService["autofillScriptPortsSet"].size).toBe(0); + }); + }); }); diff --git a/apps/browser/src/autofill/services/autofill.service.ts b/apps/browser/src/autofill/services/autofill.service.ts index a1ef5a47a1c..eddc8f93a9e 100644 --- a/apps/browser/src/autofill/services/autofill.service.ts +++ b/apps/browser/src/autofill/services/autofill.service.ts @@ -2,6 +2,8 @@ import { EventCollectionService } from "@bitwarden/common/abstractions/event/eve import { SettingsService } from "@bitwarden/common/abstractions/settings.service"; import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction"; import { EventType } from "@bitwarden/common/enums"; +import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum"; +import { ConfigServiceAbstraction } from "@bitwarden/common/platform/abstractions/config/config.service.abstraction"; import { LogService } from "@bitwarden/common/platform/abstractions/log.service"; import { CipherService } from "@bitwarden/common/vault/abstractions/cipher.service"; import { TotpService } from "@bitwarden/common/vault/abstractions/totp.service"; @@ -13,6 +15,7 @@ import { FieldView } from "@bitwarden/common/vault/models/view/field.view"; import { BrowserApi } from "../../platform/browser/browser-api"; import { BrowserStateService } from "../../platform/services/abstractions/browser-state.service"; import { openVaultItemPasswordRepromptPopout } from "../../vault/popup/utils/vault-popout-window"; +import { AutofillPort } from "../enums/autofill-port.enums"; import AutofillField from "../models/autofill-field"; import AutofillPageDetails from "../models/autofill-page-details"; import AutofillScript from "../models/autofill-script"; @@ -35,6 +38,7 @@ export default class AutofillService implements AutofillServiceInterface { private openVaultItemPasswordRepromptPopout = openVaultItemPasswordRepromptPopout; private openPasswordRepromptPopoutDebounce: NodeJS.Timeout; private currentlyOpeningPasswordRepromptPopout = false; + private autofillScriptPortsSet = new Set(); constructor( private cipherService: CipherService, @@ -44,23 +48,54 @@ export default class AutofillService implements AutofillServiceInterface { private logService: LogService, private settingsService: SettingsService, private userVerificationService: UserVerificationService, + private configService: ConfigServiceAbstraction, ) {} + /** + * Triggers on installation of the extension Handles injecting + * content scripts into all tabs that are currently open, and + * sets up a listener to ensure content scripts can identify + * if the extension context has been disconnected. + */ + async loadAutofillScriptsOnInstall() { + BrowserApi.addListener(chrome.runtime.onConnect, this.handleInjectedScriptPortConnection); + + this.injectAutofillScriptsInAllTabs(); + } + + /** + * Triggers a complete reload of all autofill scripts on tabs open within + * the user's browsing session. This is done by first disconnecting all + * existing autofill content script ports, which cleans up existing object + * instances, and then re-injecting the autofill scripts into all tabs. + */ + async reloadAutofillScripts() { + this.autofillScriptPortsSet.forEach((port) => { + port.disconnect(); + this.autofillScriptPortsSet.delete(port); + }); + + this.injectAutofillScriptsInAllTabs(); + } + /** * Injects the autofill scripts into the current tab and all frames * found within the tab. Temporarily, will conditionally inject * the refactor of the core autofill script if the feature flag * is enabled. - * @param {chrome.runtime.MessageSender} sender - * @param {boolean} autofillV2 - * @param {boolean} autofillOverlay - * @returns {Promise} + * @param {chrome.tabs.Tab} tab + * @param {number} frameId + * @param {boolean} triggeringOnPageLoad */ async injectAutofillScripts( - sender: chrome.runtime.MessageSender, - autofillV2 = false, - autofillOverlay = false, - ) { + tab: chrome.tabs.Tab, + frameId = 0, + triggeringOnPageLoad = true, + ): Promise { + const autofillV2 = await this.configService.getFeatureFlag(FeatureFlag.AutofillV2); + const autofillOverlay = await this.configService.getFeatureFlag( + FeatureFlag.AutofillOverlay, + ); let mainAutofillScript = "autofill.js"; const isUsingAutofillOverlay = @@ -73,20 +108,24 @@ export default class AutofillService implements AutofillServiceInterface { : "bootstrap-autofill.js"; } - const injectedScripts = [ - mainAutofillScript, - "autofiller.js", - "notificationBar.js", - "contextMenuHandler.js", - ]; + const injectedScripts = [mainAutofillScript]; + if (triggeringOnPageLoad) { + injectedScripts.push("autofiller.js"); + } + injectedScripts.push("notificationBar.js", "contextMenuHandler.js"); for (const injectedScript of injectedScripts) { - await BrowserApi.executeScriptInTab(sender.tab.id, { + await BrowserApi.executeScriptInTab(tab.id, { file: `content/${injectedScript}`, - frameId: sender.frameId, + frameId, runAt: "document_start", }); } + + await BrowserApi.executeScriptInTab(tab.id, { + file: "content/message_handler.js", + runAt: "document_start", + }); } /** @@ -1877,4 +1916,47 @@ export default class AutofillService implements AutofillServiceInterface { return false; } + + /** + * Handles incoming long-lived connections from injected autofill scripts. + * Stores the port in a set to facilitate disconnecting ports if the extension + * needs to re-inject the autofill scripts. + * + * @param port - The port that was connected + */ + private handleInjectedScriptPortConnection = (port: chrome.runtime.Port) => { + if (port.name !== AutofillPort.InjectedScript) { + return; + } + + this.autofillScriptPortsSet.add(port); + port.onDisconnect.addListener(this.handleInjectScriptPortOnDisconnect); + }; + + /** + * Handles disconnecting ports that relate to injected autofill scripts. + + * @param port - The port that was disconnected + */ + private handleInjectScriptPortOnDisconnect = (port: chrome.runtime.Port) => { + if (port.name !== AutofillPort.InjectedScript) { + return; + } + + this.autofillScriptPortsSet.delete(port); + }; + + /** + * Queries all open tabs in the user's browsing session + * and injects the autofill scripts into the page. + */ + private async injectAutofillScriptsInAllTabs() { + const tabs = await BrowserApi.tabsQuery({}); + for (let index = 0; index < tabs.length; index++) { + const tab = tabs[index]; + if (tab.url?.startsWith("http")) { + this.injectAutofillScripts(tab, 0, false); + } + } + } } diff --git a/apps/browser/src/autofill/services/collect-autofill-content.service.ts b/apps/browser/src/autofill/services/collect-autofill-content.service.ts index d675a37921d..ebddc201417 100644 --- a/apps/browser/src/autofill/services/collect-autofill-content.service.ts +++ b/apps/browser/src/autofill/services/collect-autofill-content.service.ts @@ -1249,6 +1249,17 @@ class CollectAutofillContentService implements CollectAutofillContentServiceInte return attributeValue; } + + /** + * Destroys the CollectAutofillContentService. Clears all + * timeouts and disconnects the mutation observer. + */ + destroy() { + if (this.updateAutofillElementsAfterMutationTimeout) { + clearTimeout(this.updateAutofillElementsAfterMutationTimeout); + } + this.mutationObserver?.disconnect(); + } } export default CollectAutofillContentService; diff --git a/apps/browser/src/autofill/utils/utils.spec.ts b/apps/browser/src/autofill/utils/index.spec.ts similarity index 51% rename from apps/browser/src/autofill/utils/utils.spec.ts rename to apps/browser/src/autofill/utils/index.spec.ts index 1da83fef242..4024d5839a8 100644 --- a/apps/browser/src/autofill/utils/utils.spec.ts +++ b/apps/browser/src/autofill/utils/index.spec.ts @@ -1,10 +1,17 @@ +import { AutofillPort } from "../enums/autofill-port.enums"; +import { triggerPortOnDisconnectEvent } from "../jest/testing-utils"; + import { logoIcon, logoLockedIcon } from "./svg-icons"; + import { buildSvgDomElement, generateRandomCustomElementName, sendExtensionMessage, setElementStyles, -} from "./utils"; + getFromLocalStorage, + setupExtensionDisconnectAction, + setupAutofillInitDisconnectAction, +} from "./index"; describe("buildSvgDomElement", () => { it("returns an SVG DOM element", () => { @@ -116,3 +123,107 @@ describe("setElementStyles", () => { expect(testDiv.style.cssText).toEqual(expectedCSSRuleString); }); }); + +describe("getFromLocalStorage", () => { + it("returns a promise with the storage object pulled from the extension storage api", async () => { + const localStorage: Record = { + testValue: "test", + another: "another", + }; + jest.spyOn(chrome.storage.local, "get").mockImplementation((keys, callback) => { + const localStorageObject: Record = {}; + + if (typeof keys === "string") { + localStorageObject[keys] = localStorage[keys]; + } else if (Array.isArray(keys)) { + for (const key of keys) { + localStorageObject[key] = localStorage[key]; + } + } + + callback(localStorageObject); + }); + + const returnValue = await getFromLocalStorage("testValue"); + + expect(chrome.storage.local.get).toHaveBeenCalled(); + expect(returnValue).toEqual({ testValue: "test" }); + }); +}); + +describe("setupExtensionDisconnectAction", () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it("connects a port to the extension background and sets up an onDisconnect listener", () => { + const onDisconnectCallback = jest.fn(); + let port: chrome.runtime.Port; + jest.spyOn(chrome.runtime, "connect").mockImplementation(() => { + port = { + onDisconnect: { + addListener: onDisconnectCallback, + removeListener: jest.fn(), + }, + } as unknown as chrome.runtime.Port; + + return port; + }); + + setupExtensionDisconnectAction(onDisconnectCallback); + + expect(chrome.runtime.connect).toHaveBeenCalledWith({ + name: AutofillPort.InjectedScript, + }); + expect(port.onDisconnect.addListener).toHaveBeenCalledWith(expect.any(Function)); + }); +}); + +describe("setupAutofillInitDisconnectAction", () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it("skips setting up the extension disconnect action if the bitwardenAutofillInit object is not populated", () => { + const onDisconnectCallback = jest.fn(); + window.bitwardenAutofillInit = undefined; + const portConnectSpy = jest.spyOn(chrome.runtime, "connect").mockImplementation(() => { + return { + onDisconnect: { + addListener: onDisconnectCallback, + removeListener: jest.fn(), + }, + } as unknown as chrome.runtime.Port; + }); + + setupAutofillInitDisconnectAction(window); + + expect(portConnectSpy).not.toHaveBeenCalled(); + }); + + it("destroys the autofill init instance when the port is disconnected", () => { + let port: chrome.runtime.Port; + const autofillInitDestroy: CallableFunction = jest.fn(); + window.bitwardenAutofillInit = { + destroy: autofillInitDestroy, + } as any; + jest.spyOn(chrome.runtime, "connect").mockImplementation(() => { + port = { + onDisconnect: { + addListener: jest.fn(), + removeListener: jest.fn(), + }, + } as unknown as chrome.runtime.Port; + + return port; + }); + + setupAutofillInitDisconnectAction(window); + triggerPortOnDisconnectEvent(port as chrome.runtime.Port); + + expect(chrome.runtime.connect).toHaveBeenCalled(); + expect(port.onDisconnect.addListener).toHaveBeenCalled(); + expect(autofillInitDestroy).toHaveBeenCalled(); + expect(window.bitwardenAutofillInit).toBeUndefined(); + }); +}); diff --git a/apps/browser/src/autofill/utils/utils.ts b/apps/browser/src/autofill/utils/index.ts similarity index 65% rename from apps/browser/src/autofill/utils/utils.ts rename to apps/browser/src/autofill/utils/index.ts index 73e133da32c..a2ce51c8cc1 100644 --- a/apps/browser/src/autofill/utils/utils.ts +++ b/apps/browser/src/autofill/utils/index.ts @@ -1,3 +1,5 @@ +import { AutofillPort } from "../enums/autofill-port.enums"; + /** * Generates a random string of characters that formatted as a custom element name. */ @@ -103,9 +105,57 @@ function setElementStyles( } } +/** + * Get data from local storage based on the keys provided. + * + * @param keys - String or array of strings of keys to get from local storage + */ +async function getFromLocalStorage(keys: string | string[]): Promise> { + return new Promise((resolve) => { + chrome.storage.local.get(keys, (storage: Record) => resolve(storage)); + }); +} + +/** + * Sets up a long-lived connection with the extension background + * and triggers an onDisconnect event if the extension context + * is invalidated. + * + * @param callback - Callback function to run when the extension disconnects + */ +function setupExtensionDisconnectAction(callback: (port: chrome.runtime.Port) => void) { + const port = chrome.runtime.connect({ name: AutofillPort.InjectedScript }); + const onDisconnectCallback = (disconnectedPort: chrome.runtime.Port) => { + callback(disconnectedPort); + port.onDisconnect.removeListener(onDisconnectCallback); + }; + port.onDisconnect.addListener(onDisconnectCallback); +} + +/** + * Handles setup of the extension disconnect action for the autofill init class + * in both instances where the overlay might or might not be initialized. + * + * @param windowContext - The global window context + */ +function setupAutofillInitDisconnectAction(windowContext: Window) { + if (!windowContext.bitwardenAutofillInit) { + return; + } + + const onDisconnectCallback = () => { + windowContext.bitwardenAutofillInit.destroy(); + delete windowContext.bitwardenAutofillInit; + }; + setupExtensionDisconnectAction(onDisconnectCallback); +} + export { generateRandomCustomElementName, buildSvgDomElement, sendExtensionMessage, setElementStyles, + getFromLocalStorage, + setupExtensionDisconnectAction, + setupAutofillInitDisconnectAction, }; diff --git a/apps/browser/src/background/idle.background.ts b/apps/browser/src/background/idle.background.ts index ccc76883b11..7ef6f7090bd 100644 --- a/apps/browser/src/background/idle.background.ts +++ b/apps/browser/src/background/idle.background.ts @@ -1,5 +1,8 @@ +import { firstValueFrom } from "rxjs"; + import { NotificationsService } from "@bitwarden/common/abstractions/notifications.service"; import { VaultTimeoutService } from "@bitwarden/common/abstractions/vault-timeout/vault-timeout.service"; +import { AccountService } from "@bitwarden/common/auth/abstractions/account.service"; import { VaultTimeoutAction } from "@bitwarden/common/enums/vault-timeout-action.enum"; import { BrowserStateService } from "../platform/services/abstractions/browser-state.service"; @@ -7,7 +10,7 @@ import { BrowserStateService } from "../platform/services/abstractions/browser-s const IdleInterval = 60 * 5; // 5 minutes export default class IdleBackground { - private idle: any; + private idle: typeof chrome.idle | typeof browser.idle | null; private idleTimer: number = null; private idleState = "active"; @@ -15,6 +18,7 @@ export default class IdleBackground { private vaultTimeoutService: VaultTimeoutService, private stateService: BrowserStateService, private notificationsService: NotificationsService, + private accountService: AccountService, ) { this.idle = chrome.idle || (browser != null ? browser.idle : null); } @@ -39,21 +43,27 @@ export default class IdleBackground { } if (this.idle.onStateChanged) { - this.idle.onStateChanged.addListener(async (newState: string) => { - if (newState === "locked") { - // If the screen is locked or the screensaver activates - const timeout = await this.stateService.getVaultTimeout(); - if (timeout === -2) { - // On System Lock vault timeout option - const action = await this.stateService.getVaultTimeoutAction(); - if (action === VaultTimeoutAction.LogOut) { - await this.vaultTimeoutService.logOut(); - } else { - await this.vaultTimeoutService.lock(); + this.idle.onStateChanged.addListener( + async (newState: chrome.idle.IdleState | browser.idle.IdleState) => { + if (newState === "locked") { + // Need to check if any of the current users have their timeout set to `onLocked` + const allUsers = await firstValueFrom(this.accountService.accounts$); + for (const userId in allUsers) { + // If the screen is locked or the screensaver activates + const timeout = await this.stateService.getVaultTimeout({ userId: userId }); + if (timeout === -2) { + // On System Lock vault timeout option + const action = await this.stateService.getVaultTimeoutAction({ userId: userId }); + if (action === VaultTimeoutAction.LogOut) { + await this.vaultTimeoutService.logOut(userId); + } else { + await this.vaultTimeoutService.lock(userId); + } + } } } - } - }); + }, + ); } } diff --git a/apps/browser/src/background/main.background.ts b/apps/browser/src/background/main.background.ts index 2c1573d4e9b..fa77c498953 100644 --- a/apps/browser/src/background/main.background.ts +++ b/apps/browser/src/background/main.background.ts @@ -566,6 +566,7 @@ export default class MainBackground { this.logService, this.settingsService, this.userVerificationService, + this.configService, ); this.auditService = new AuditService(this.cryptoFunctionService, this.apiService); @@ -729,6 +730,7 @@ export default class MainBackground { this.vaultTimeoutService, this.stateService, this.notificationsService, + this.accountService, ); this.webRequestBackground = new WebRequestBackground( this.platformUtilsService, diff --git a/apps/browser/src/background/runtime.background.ts b/apps/browser/src/background/runtime.background.ts index fcaefc7c5e2..de43e585603 100644 --- a/apps/browser/src/background/runtime.background.ts +++ b/apps/browser/src/background/runtime.background.ts @@ -1,5 +1,4 @@ import { NotificationsService } from "@bitwarden/common/abstractions/notifications.service"; -import { FeatureFlag } from "@bitwarden/common/enums/feature-flag.enum"; import { ConfigServiceAbstraction } from "@bitwarden/common/platform/abstractions/config/config.service.abstraction"; import { I18nService } from "@bitwarden/common/platform/abstractions/i18n.service"; import { LogService } from "@bitwarden/common/platform/abstractions/log.service"; @@ -97,9 +96,9 @@ export default class RuntimeBackground { await closeUnlockPopout(); } + await this.notificationsService.updateConnection(msg.command === "loggedIn"); await this.main.refreshBadge(); await this.main.refreshMenu(false); - this.notificationsService.updateConnection(msg.command === "unlocked"); this.systemService.cancelProcessReload(); if (item) { @@ -133,11 +132,7 @@ export default class RuntimeBackground { await this.main.openPopup(); break; case "triggerAutofillScriptInjection": - await this.autofillService.injectAutofillScripts( - sender, - await this.configService.getFeatureFlag(FeatureFlag.AutofillV2), - await this.configService.getFeatureFlag(FeatureFlag.AutofillOverlay), - ); + await this.autofillService.injectAutofillScripts(sender.tab, sender.frameId); break; case "bgCollectPageDetails": await this.main.collectPageDetailsForContentScript(sender.tab, msg.sender, sender.frameId); @@ -325,6 +320,8 @@ export default class RuntimeBackground { private async checkOnInstalled() { setTimeout(async () => { + this.autofillService.loadAutofillScriptsOnInstall(); + if (this.onInstalledReason != null) { if (this.onInstalledReason === "install") { BrowserApi.createNewTab("https://bitwarden.com/browser-start/"); diff --git a/apps/browser/src/manifest.json b/apps/browser/src/manifest.json index aa71d648e8e..09f12c56f18 100644 --- a/apps/browser/src/manifest.json +++ b/apps/browser/src/manifest.json @@ -24,12 +24,6 @@ "matches": ["http://*/*", "https://*/*", "file:///*"], "run_at": "document_start" }, - { - "all_frames": false, - "js": ["content/message_handler.js"], - "matches": ["http://*/*", "https://*/*", "file:///*"], - "run_at": "document_start" - }, { "all_frames": true, "css": ["content/autofill.css"], diff --git a/apps/browser/test.setup.ts b/apps/browser/test.setup.ts index 2da36f0a5a5..647e4cfdfb3 100644 --- a/apps/browser/test.setup.ts +++ b/apps/browser/test.setup.ts @@ -23,11 +23,12 @@ const runtime = { removeListener: jest.fn(), }, sendMessage: jest.fn(), - getManifest: jest.fn(), + getManifest: jest.fn(() => ({ version: 2 })), getURL: jest.fn((path) => `chrome-extension://id/${path}`), connect: jest.fn(), onConnect: { addListener: jest.fn(), + removeListener: jest.fn(), }, }; diff --git a/apps/desktop/src/auth/lock.component.spec.ts b/apps/desktop/src/auth/lock.component.spec.ts new file mode 100644 index 00000000000..fa947319571 --- /dev/null +++ b/apps/desktop/src/auth/lock.component.spec.ts @@ -0,0 +1,394 @@ +import { NO_ERRORS_SCHEMA } from "@angular/core"; +import { ComponentFixture, TestBed, fakeAsync, tick } from "@angular/core/testing"; +import { ActivatedRoute } from "@angular/router"; +import { MockProxy, mock } from "jest-mock-extended"; +import { of } from "rxjs"; + +import { LockComponent as BaseLockComponent } from "@bitwarden/angular/auth/components/lock.component"; +import { I18nPipe } from "@bitwarden/angular/platform/pipes/i18n.pipe"; +import { ApiService } from "@bitwarden/common/abstractions/api.service"; +import { VaultTimeoutSettingsService } from "@bitwarden/common/abstractions/vault-timeout/vault-timeout-settings.service"; +import { VaultTimeoutService } from "@bitwarden/common/abstractions/vault-timeout/vault-timeout.service"; +import { PolicyApiServiceAbstraction } from "@bitwarden/common/admin-console/abstractions/policy/policy-api.service.abstraction"; +import { InternalPolicyService } from "@bitwarden/common/admin-console/abstractions/policy/policy.service.abstraction"; +import { DeviceTrustCryptoServiceAbstraction } from "@bitwarden/common/auth/abstractions/device-trust-crypto.service.abstraction"; +import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction"; +import { BroadcasterService } from "@bitwarden/common/platform/abstractions/broadcaster.service"; +import { CryptoService } from "@bitwarden/common/platform/abstractions/crypto.service"; +import { EnvironmentService } from "@bitwarden/common/platform/abstractions/environment.service"; +import { I18nService } from "@bitwarden/common/platform/abstractions/i18n.service"; +import { LogService } from "@bitwarden/common/platform/abstractions/log.service"; +import { MessagingService } from "@bitwarden/common/platform/abstractions/messaging.service"; +import { PlatformUtilsService } from "@bitwarden/common/platform/abstractions/platform-utils.service"; +import { PasswordStrengthServiceAbstraction } from "@bitwarden/common/tools/password-strength"; +import { DialogService } from "@bitwarden/components"; + +import { ElectronStateService } from "../platform/services/electron-state.service.abstraction"; + +import { LockComponent } from "./lock.component"; + +// ipc mock global +const isWindowVisibleMock = jest.fn(); +(global as any).ipc = { + platform: { + biometric: { + enabled: jest.fn(), + }, + isWindowVisible: isWindowVisibleMock, + }, +}; + +describe("LockComponent", () => { + let component: LockComponent; + let fixture: ComponentFixture; + let stateServiceMock: MockProxy; + let messagingServiceMock: MockProxy; + let broadcasterServiceMock: MockProxy; + let platformUtilsServiceMock: MockProxy; + let activatedRouteMock: MockProxy; + + beforeEach(() => { + stateServiceMock = mock(); + stateServiceMock.activeAccount$ = of(null); + + messagingServiceMock = mock(); + broadcasterServiceMock = mock(); + platformUtilsServiceMock = mock(); + + activatedRouteMock = mock(); + activatedRouteMock.queryParams = mock(); + + TestBed.configureTestingModule({ + declarations: [LockComponent, I18nPipe], + providers: [ + { + provide: I18nService, + useValue: mock(), + }, + { + provide: PlatformUtilsService, + useValue: platformUtilsServiceMock, + }, + { + provide: MessagingService, + useValue: messagingServiceMock, + }, + { + provide: CryptoService, + useValue: mock(), + }, + { + provide: VaultTimeoutService, + useValue: mock(), + }, + { + provide: VaultTimeoutSettingsService, + useValue: mock(), + }, + { + provide: EnvironmentService, + useValue: mock(), + }, + { + provide: ElectronStateService, + useValue: stateServiceMock, + }, + { + provide: ApiService, + useValue: mock(), + }, + { + provide: ActivatedRoute, + useValue: activatedRouteMock, + }, + { + provide: BroadcasterService, + useValue: broadcasterServiceMock, + }, + { + provide: PolicyApiServiceAbstraction, + useValue: mock(), + }, + { + provide: InternalPolicyService, + useValue: mock(), + }, + { + provide: PasswordStrengthServiceAbstraction, + useValue: mock(), + }, + { + provide: LogService, + useValue: mock(), + }, + { + provide: DialogService, + useValue: mock(), + }, + { + provide: DeviceTrustCryptoServiceAbstraction, + useValue: mock(), + }, + { + provide: UserVerificationService, + useValue: mock(), + }, + ], + schemas: [NO_ERRORS_SCHEMA], + }).compileComponents(); + }); + + beforeEach(() => { + fixture = TestBed.createComponent(LockComponent); + component = fixture.componentInstance; + fixture.detectChanges(); + jest.clearAllMocks(); + }); + + describe("ngOnInit", () => { + it("should call super.ngOnInit() once", async () => { + const superNgOnInitSpy = jest.spyOn(BaseLockComponent.prototype, "ngOnInit"); + await component.ngOnInit(); + expect(superNgOnInitSpy).toHaveBeenCalledTimes(1); + }); + + it('should set "autoPromptBiometric" to true if "stateService.getDisableAutoBiometricsPrompt()" resolves to false', async () => { + stateServiceMock.getDisableAutoBiometricsPrompt.mockResolvedValue(false); + + await component.ngOnInit(); + expect(component["autoPromptBiometric"]).toBe(true); + }); + + it('should set "autoPromptBiometric" to false if "stateService.getDisableAutoBiometricsPrompt()" resolves to true', async () => { + stateServiceMock.getDisableAutoBiometricsPrompt.mockResolvedValue(true); + + await component.ngOnInit(); + expect(component["autoPromptBiometric"]).toBe(false); + }); + + it('should set "biometricReady" to true if "stateService.getBiometricReady()" resolves to true', async () => { + component["canUseBiometric"] = jest.fn().mockResolvedValue(true); + + await component.ngOnInit(); + expect(component["biometricReady"]).toBe(true); + }); + + it('should set "biometricReady" to false if "stateService.getBiometricReady()" resolves to false', async () => { + component["canUseBiometric"] = jest.fn().mockResolvedValue(false); + + await component.ngOnInit(); + expect(component["biometricReady"]).toBe(false); + }); + + it("should call displayBiometricUpdateWarning", async () => { + component["displayBiometricUpdateWarning"] = jest.fn(); + await component.ngOnInit(); + expect(component["displayBiometricUpdateWarning"]).toHaveBeenCalledTimes(1); + }); + + it("should call delayedAskForBiometric", async () => { + component["delayedAskForBiometric"] = jest.fn(); + await component.ngOnInit(); + expect(component["delayedAskForBiometric"]).toHaveBeenCalledTimes(1); + expect(component["delayedAskForBiometric"]).toHaveBeenCalledWith(500); + }); + + it("should call delayedAskForBiometric when queryParams change", async () => { + activatedRouteMock.queryParams = of({ promptBiometric: true }); + component["delayedAskForBiometric"] = jest.fn(); + await component.ngOnInit(); + + expect(component["delayedAskForBiometric"]).toHaveBeenCalledTimes(1); + expect(component["delayedAskForBiometric"]).toHaveBeenCalledWith(500); + }); + + it("should call messagingService.send", async () => { + await component.ngOnInit(); + expect(messagingServiceMock.send).toHaveBeenCalledWith("getWindowIsFocused"); + }); + + describe("broadcasterService.subscribe", () => { + it('should call onWindowHidden() when "broadcasterService.subscribe" is called with "windowHidden"', async () => { + component["onWindowHidden"] = jest.fn(); + await component.ngOnInit(); + broadcasterServiceMock.subscribe.mock.calls[0][1]({ command: "windowHidden" }); + expect(component["onWindowHidden"]).toHaveBeenCalledTimes(1); + }); + + it('should call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is true and deferFocus is false', async () => { + component["focusInput"] = jest.fn(); + component["deferFocus"] = null; + await component.ngOnInit(); + broadcasterServiceMock.subscribe.mock.calls[0][1]({ + command: "windowIsFocused", + windowIsFocused: true, + } as any); + expect(component["deferFocus"]).toBe(false); + expect(component["focusInput"]).toHaveBeenCalledTimes(1); + }); + + it('should not call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is true and deferFocus is true', async () => { + component["focusInput"] = jest.fn(); + component["deferFocus"] = null; + await component.ngOnInit(); + broadcasterServiceMock.subscribe.mock.calls[0][1]({ + command: "windowIsFocused", + windowIsFocused: false, + } as any); + expect(component["deferFocus"]).toBe(true); + expect(component["focusInput"]).toHaveBeenCalledTimes(0); + }); + + it('should call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is true and deferFocus is true', async () => { + component["focusInput"] = jest.fn(); + component["deferFocus"] = true; + await component.ngOnInit(); + broadcasterServiceMock.subscribe.mock.calls[0][1]({ + command: "windowIsFocused", + windowIsFocused: true, + } as any); + expect(component["deferFocus"]).toBe(false); + expect(component["focusInput"]).toHaveBeenCalledTimes(1); + }); + + it('should not call focusInput() when "broadcasterService.subscribe" is called with "windowIsFocused" is false and deferFocus is true', async () => { + component["focusInput"] = jest.fn(); + component["deferFocus"] = true; + await component.ngOnInit(); + broadcasterServiceMock.subscribe.mock.calls[0][1]({ + command: "windowIsFocused", + windowIsFocused: false, + } as any); + expect(component["deferFocus"]).toBe(true); + expect(component["focusInput"]).toHaveBeenCalledTimes(0); + }); + }); + }); + + describe("ngOnDestroy", () => { + it("should call super.ngOnDestroy()", () => { + const superNgOnDestroySpy = jest.spyOn(BaseLockComponent.prototype, "ngOnDestroy"); + component.ngOnDestroy(); + expect(superNgOnDestroySpy).toHaveBeenCalledTimes(1); + }); + + it("should call broadcasterService.unsubscribe()", () => { + component.ngOnDestroy(); + expect(broadcasterServiceMock.unsubscribe).toHaveBeenCalledTimes(1); + }); + }); + + describe("focusInput", () => { + it('should call "focus" on #pin input if pinEnabled is true', () => { + component["pinEnabled"] = true; + global.document.getElementById = jest.fn().mockReturnValue({ focus: jest.fn() }); + component["focusInput"](); + expect(global.document.getElementById).toHaveBeenCalledWith("pin"); + }); + + it('should call "focus" on #masterPassword input if pinEnabled is false', () => { + component["pinEnabled"] = false; + global.document.getElementById = jest.fn().mockReturnValue({ focus: jest.fn() }); + component["focusInput"](); + expect(global.document.getElementById).toHaveBeenCalledWith("masterPassword"); + }); + }); + + describe("delayedAskForBiometric", () => { + beforeEach(() => { + component["supportsBiometric"] = true; + component["autoPromptBiometric"] = true; + }); + + it('should wait for "delay" milliseconds', fakeAsync(async () => { + const delaySpy = jest.spyOn(global, "setTimeout"); + component["delayedAskForBiometric"](5000); + + tick(4000); + component["biometricAsked"] = false; + + tick(1000); + component["biometricAsked"] = true; + + expect(delaySpy).toHaveBeenCalledWith(expect.any(Function), 5000); + })); + + it('should return; if "params" is defined and "params.promptBiometric" is false', fakeAsync(async () => { + component["delayedAskForBiometric"](5000, { promptBiometric: false }); + tick(5000); + expect(component["biometricAsked"]).toBe(false); + })); + + it('should not return; if "params" is defined and "params.promptBiometric" is true', fakeAsync(async () => { + component["delayedAskForBiometric"](5000, { promptBiometric: true }); + tick(5000); + expect(component["biometricAsked"]).toBe(true); + })); + + it('should not return; if "params" is undefined', fakeAsync(async () => { + component["delayedAskForBiometric"](5000); + tick(5000); + expect(component["biometricAsked"]).toBe(true); + })); + + it('should return; if "supportsBiometric" is false', fakeAsync(async () => { + component["supportsBiometric"] = false; + component["delayedAskForBiometric"](5000); + tick(5000); + expect(component["biometricAsked"]).toBe(false); + })); + + it('should return; if "autoPromptBiometric" is false', fakeAsync(async () => { + component["autoPromptBiometric"] = false; + component["delayedAskForBiometric"](5000); + tick(5000); + expect(component["biometricAsked"]).toBe(false); + })); + + it("should call unlockBiometric() if biometricAsked is false and window is visible", fakeAsync(async () => { + isWindowVisibleMock.mockResolvedValue(true); + component["unlockBiometric"] = jest.fn(); + component["biometricAsked"] = false; + component["delayedAskForBiometric"](5000); + tick(5000); + + expect(component["unlockBiometric"]).toHaveBeenCalledTimes(1); + })); + + it("should not call unlockBiometric() if biometricAsked is false and window is not visible", fakeAsync(async () => { + isWindowVisibleMock.mockResolvedValue(false); + component["unlockBiometric"] = jest.fn(); + component["biometricAsked"] = false; + component["delayedAskForBiometric"](5000); + tick(5000); + + expect(component["unlockBiometric"]).toHaveBeenCalledTimes(0); + })); + + it("should not call unlockBiometric() if biometricAsked is true", fakeAsync(async () => { + isWindowVisibleMock.mockResolvedValue(true); + component["unlockBiometric"] = jest.fn(); + component["biometricAsked"] = true; + + component["delayedAskForBiometric"](5000); + tick(5000); + + expect(component["unlockBiometric"]).toHaveBeenCalledTimes(0); + })); + }); + + describe("canUseBiometric", () => { + it("should call getUserId() on stateService", async () => { + stateServiceMock.getUserId.mockResolvedValue("userId"); + await component["canUseBiometric"](); + + expect(ipc.platform.biometric.enabled).toHaveBeenCalledWith("userId"); + }); + }); + + it('onWindowHidden() should set "showPassword" to false', () => { + component["showPassword"] = true; + component["onWindowHidden"](); + expect(component["showPassword"]).toBe(false); + }); +}); diff --git a/apps/desktop/src/auth/lock.component.ts b/apps/desktop/src/auth/lock.component.ts index 91a60557fb9..3f62df7dd1c 100644 --- a/apps/desktop/src/auth/lock.component.ts +++ b/apps/desktop/src/auth/lock.component.ts @@ -1,5 +1,6 @@ import { Component, NgZone } from "@angular/core"; import { ActivatedRoute, Router } from "@angular/router"; +import { switchMap } from "rxjs"; import { LockComponent as BaseLockComponent } from "@bitwarden/angular/auth/components/lock.component"; import { ApiService } from "@bitwarden/common/abstractions/api.service"; @@ -31,6 +32,8 @@ const BroadcasterSubscriptionId = "LockComponent"; export class LockComponent extends BaseLockComponent { private deferFocus: boolean = null; protected biometricReady = false; + private biometricAsked = false; + private autoPromptBiometric = false; constructor( router: Router, @@ -78,23 +81,14 @@ export class LockComponent extends BaseLockComponent { async ngOnInit() { await super.ngOnInit(); - const autoPromptBiometric = !(await this.stateService.getDisableAutoBiometricsPrompt()); + this.autoPromptBiometric = !(await this.stateService.getDisableAutoBiometricsPrompt()); this.biometricReady = await this.canUseBiometric(); await this.displayBiometricUpdateWarning(); - // eslint-disable-next-line rxjs-angular/prefer-takeuntil - this.route.queryParams.subscribe((params) => { - setTimeout(async () => { - if (!params.promptBiometric || !this.supportsBiometric || !autoPromptBiometric) { - return; - } + this.delayedAskForBiometric(500); + this.route.queryParams.pipe(switchMap((params) => this.delayedAskForBiometric(500, params))); - if (await ipc.platform.isWindowVisible()) { - this.unlockBiometric(); - } - }, 1000); - }); this.broadcasterService.subscribe(BroadcasterSubscriptionId, async (message: any) => { this.ngZone.run(() => { switch (message.command) { @@ -128,6 +122,23 @@ export class LockComponent extends BaseLockComponent { this.showPassword = false; } + private async delayedAskForBiometric(delay: number, params?: any) { + await new Promise((resolve) => setTimeout(resolve, delay)); + + if (params && !params.promptBiometric) { + return; + } + + if (!this.supportsBiometric || !this.autoPromptBiometric || this.biometricAsked) { + return; + } + + this.biometricAsked = true; + if (await ipc.platform.isWindowVisible()) { + this.unlockBiometric(); + } + } + private async canUseBiometric() { const userId = await this.stateService.getUserId(); return await ipc.platform.biometric.enabled(userId); diff --git a/apps/web/src/app/auth/core/services/webauthn-login/request/enable-credential-encryption.request.ts b/apps/web/src/app/auth/core/services/webauthn-login/request/enable-credential-encryption.request.ts new file mode 100644 index 00000000000..6dc08728ad1 --- /dev/null +++ b/apps/web/src/app/auth/core/services/webauthn-login/request/enable-credential-encryption.request.ts @@ -0,0 +1,25 @@ +import { WebAuthnLoginAssertionResponseRequest } from "@bitwarden/common/auth/services/webauthn-login/request/webauthn-login-assertion-response.request"; + +/** + * Request sent to the server to save a newly created prf key set for a credential. + */ +export class EnableCredentialEncryptionRequest { + /** + * The response received from the authenticator. + */ + deviceResponse: WebAuthnLoginAssertionResponseRequest; + + /** + * An encrypted token containing information the server needs to verify the credential. + */ + token: string; + + /** Used for vault encryption. See {@link RotateableKeySet.encryptedUserKey } */ + encryptedUserKey?: string; + + /** Used for vault encryption. See {@link RotateableKeySet.encryptedPublicKey } */ + encryptedPublicKey?: string; + + /** Used for vault encryption. See {@link RotateableKeySet.encryptedPrivateKey } */ + encryptedPrivateKey?: string; +} diff --git a/apps/web/src/app/auth/core/services/webauthn-login/request/webauthn-login-attestation-response.request.ts b/apps/web/src/app/auth/core/services/webauthn-login/request/webauthn-login-attestation-response.request.ts index 249b2ebffa9..ef3d657f2f9 100644 --- a/apps/web/src/app/auth/core/services/webauthn-login/request/webauthn-login-attestation-response.request.ts +++ b/apps/web/src/app/auth/core/services/webauthn-login/request/webauthn-login-attestation-response.request.ts @@ -3,7 +3,7 @@ import { Utils } from "@bitwarden/common/platform/misc/utils"; import { WebauthnLoginAuthenticatorResponseRequest } from "./webauthn-login-authenticator-response.request"; /** - * The response received from an authentiator after a successful attestation. + * The response received from an authenticator after a successful attestation. * This request is used to save newly created webauthn login credentials to the server. */ export class WebauthnLoginAttestationResponseRequest extends WebauthnLoginAuthenticatorResponseRequest { diff --git a/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin-api.service.ts b/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin-api.service.ts index 8b99396b0fb..efa32d0c6fe 100644 --- a/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin-api.service.ts +++ b/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin-api.service.ts @@ -2,8 +2,10 @@ import { Injectable } from "@angular/core"; import { ApiService } from "@bitwarden/common/abstractions/api.service"; import { SecretVerificationRequest } from "@bitwarden/common/auth/models/request/secret-verification.request"; +import { CredentialAssertionOptionsResponse } from "@bitwarden/common/auth/services/webauthn-login/response/credential-assertion-options.response"; import { ListResponse } from "@bitwarden/common/models/response/list.response"; +import { EnableCredentialEncryptionRequest } from "./request/enable-credential-encryption.request"; import { SaveCredentialRequest } from "./request/save-credential.request"; import { WebauthnLoginCredentialCreateOptionsResponse } from "./response/webauthn-login-credential-create-options.response"; import { WebauthnLoginCredentialResponse } from "./response/webauthn-login-credential.response"; @@ -15,10 +17,29 @@ export class WebAuthnLoginAdminApiService { async getCredentialCreateOptions( request: SecretVerificationRequest, ): Promise { - const response = await this.apiService.send("POST", "/webauthn/options", request, true, true); + const response = await this.apiService.send( + "POST", + "/webauthn/attestation-options", + request, + true, + true, + ); return new WebauthnLoginCredentialCreateOptionsResponse(response); } + async getCredentialAssertionOptions( + request: SecretVerificationRequest, + ): Promise { + const response = await this.apiService.send( + "POST", + "/webauthn/assertion-options", + request, + true, + true, + ); + return new CredentialAssertionOptionsResponse(response); + } + async saveCredential(request: SaveCredentialRequest): Promise { await this.apiService.send("POST", "/webauthn", request, true, true); return true; @@ -31,4 +52,8 @@ export class WebAuthnLoginAdminApiService { async deleteCredential(credentialId: string, request: SecretVerificationRequest): Promise { await this.apiService.send("POST", `/webauthn/${credentialId}/delete`, request, true, true); } + + async updateCredential(request: EnableCredentialEncryptionRequest): Promise { + await this.apiService.send("PUT", `/webauthn`, request, true, true); + } } diff --git a/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.spec.ts b/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.spec.ts index 49c1f89052b..bc92114e878 100644 --- a/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.spec.ts +++ b/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.spec.ts @@ -1,12 +1,21 @@ +import { randomBytes } from "crypto"; + import { mock, MockProxy } from "jest-mock-extended"; +import { RotateableKeySet } from "@bitwarden/auth"; import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction"; import { WebAuthnLoginPrfCryptoServiceAbstraction } from "@bitwarden/common/auth/abstractions/webauthn/webauthn-login-prf-crypto.service.abstraction"; +import { WebAuthnLoginCredentialAssertionView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion.view"; +import { WebAuthnLoginAssertionResponseRequest } from "@bitwarden/common/auth/services/webauthn-login/request/webauthn-login-assertion-response.request"; +import { Utils } from "@bitwarden/common/platform/misc/utils"; +import { EncString } from "@bitwarden/common/platform/models/domain/enc-string"; +import { PrfKey } from "@bitwarden/common/platform/models/domain/symmetric-crypto-key"; import { CredentialCreateOptionsView } from "../../views/credential-create-options.view"; import { PendingWebauthnLoginCredentialView } from "../../views/pending-webauthn-login-credential.view"; import { RotateableKeySetService } from "../rotateable-key-set.service"; +import { EnableCredentialEncryptionRequest } from "./request/enable-credential-encryption.request"; import { WebAuthnLoginAdminApiService } from "./webauthn-login-admin-api.service"; import { WebauthnLoginAdminService } from "./webauthn-login-admin.service"; @@ -18,10 +27,13 @@ describe("WebauthnAdminService", () => { let credentials: MockProxy; let service!: WebauthnLoginAdminService; + let originalAuthenticatorAssertionResponse!: AuthenticatorAssertionResponse | any; + beforeAll(() => { // Polyfill missing class window.PublicKeyCredential = class {} as any; window.AuthenticatorAttestationResponse = class {} as any; + window.AuthenticatorAssertionResponse = class {} as any; apiService = mock(); userVerificationService = mock(); rotateableKeySetService = mock(); @@ -34,6 +46,20 @@ describe("WebauthnAdminService", () => { webAuthnLoginPrfCryptoService, credentials, ); + + // Save original global class + originalAuthenticatorAssertionResponse = global.AuthenticatorAssertionResponse; + // Mock the global AuthenticatorAssertionResponse class b/c the class is only available in secure contexts + global.AuthenticatorAssertionResponse = MockAuthenticatorAssertionResponse; + }); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + afterAll(() => { + // Restore global after all tests are done + global.AuthenticatorAssertionResponse = originalAuthenticatorAssertionResponse; }); describe("createCredential", () => { @@ -70,6 +96,94 @@ describe("WebauthnAdminService", () => { expect(result.supportsPrf).toBe(true); }); }); + + describe("enableCredentialEncryption", () => { + it("should call the necessary methods to update the credential", async () => { + // Arrange + const response = new MockPublicKeyCredential(); + const prfKeySet = new RotateableKeySet( + new EncString("test_encryptedUserKey"), + new EncString("test_encryptedPublicKey"), + new EncString("test_encryptedPrivateKey"), + ); + + const assertionOptions: WebAuthnLoginCredentialAssertionView = + new WebAuthnLoginCredentialAssertionView( + "enable_credential_encryption_test_token", + new WebAuthnLoginAssertionResponseRequest(response), + {} as PrfKey, + ); + + const request = new EnableCredentialEncryptionRequest(); + request.token = assertionOptions.token; + request.deviceResponse = assertionOptions.deviceResponse; + request.encryptedUserKey = prfKeySet.encryptedUserKey.encryptedString; + request.encryptedPublicKey = prfKeySet.encryptedPublicKey.encryptedString; + request.encryptedPrivateKey = prfKeySet.encryptedPrivateKey.encryptedString; + + // Mock the necessary methods and services + const createKeySetMock = jest + .spyOn(rotateableKeySetService, "createKeySet") + .mockResolvedValue(prfKeySet); + const updateCredentialMock = jest.spyOn(apiService, "updateCredential").mockResolvedValue(); + + // Act + await service.enableCredentialEncryption(assertionOptions); + + // Assert + expect(createKeySetMock).toHaveBeenCalledWith(assertionOptions.prfKey); + expect(updateCredentialMock).toHaveBeenCalledWith(request); + }); + + it("should throw error when PRF Key is undefined", async () => { + // Arrange + const response = new MockPublicKeyCredential(); + + const assertionOptions: WebAuthnLoginCredentialAssertionView = + new WebAuthnLoginCredentialAssertionView( + "enable_credential_encryption_test_token", + new WebAuthnLoginAssertionResponseRequest(response), + undefined, + ); + + // Mock the necessary methods and services + const createKeySetMock = jest + .spyOn(rotateableKeySetService, "createKeySet") + .mockResolvedValue(null); + const updateCredentialMock = jest.spyOn(apiService, "updateCredential").mockResolvedValue(); + + // Act + try { + await service.enableCredentialEncryption(assertionOptions); + } catch (error) { + // Assert + expect(error).toEqual(new Error("invalid credential")); + expect(createKeySetMock).not.toHaveBeenCalled(); + expect(updateCredentialMock).not.toHaveBeenCalled(); + } + }); + + it("should throw error when WehAuthnLoginCredentialAssertionView is undefined", async () => { + // Arrange + const assertionOptions: WebAuthnLoginCredentialAssertionView = undefined; + + // Mock the necessary methods and services + const createKeySetMock = jest + .spyOn(rotateableKeySetService, "createKeySet") + .mockResolvedValue(null); + const updateCredentialMock = jest.spyOn(apiService, "updateCredential").mockResolvedValue(); + + // Act + try { + await service.enableCredentialEncryption(assertionOptions); + } catch (error) { + // Assert + expect(error).toEqual(new Error("invalid credential")); + expect(createKeySetMock).not.toHaveBeenCalled(); + expect(updateCredentialMock).not.toHaveBeenCalled(); + } + }); + }); }); function createCredentialCreateOptions(): CredentialCreateOptionsView { @@ -115,3 +229,58 @@ function createDeviceResponse({ prf = false }: { prf?: boolean } = {}): PublicKe return credential; } + +/** + * Mocks for the PublicKeyCredential and AuthenticatorAssertionResponse classes copied from webauthn-login.service.spec.ts + */ + +// AuthenticatorAssertionResponse && PublicKeyCredential are only available in secure contexts +// so we need to mock them and assign them to the global object to make them available +// for the tests +class MockAuthenticatorAssertionResponse implements AuthenticatorAssertionResponse { + clientDataJSON: ArrayBuffer = randomBytes(32).buffer; + authenticatorData: ArrayBuffer = randomBytes(196).buffer; + signature: ArrayBuffer = randomBytes(72).buffer; + userHandle: ArrayBuffer = randomBytes(16).buffer; + + clientDataJSONB64Str = Utils.fromBufferToUrlB64(this.clientDataJSON); + authenticatorDataB64Str = Utils.fromBufferToUrlB64(this.authenticatorData); + signatureB64Str = Utils.fromBufferToUrlB64(this.signature); + userHandleB64Str = Utils.fromBufferToUrlB64(this.userHandle); +} + +class MockPublicKeyCredential implements PublicKeyCredential { + authenticatorAttachment = "cross-platform"; + id = "mockCredentialId"; + type = "public-key"; + rawId: ArrayBuffer = randomBytes(32).buffer; + rawIdB64Str = Utils.fromBufferToUrlB64(this.rawId); + + response: MockAuthenticatorAssertionResponse = new MockAuthenticatorAssertionResponse(); + + // Use random 64 character hex string (32 bytes - matters for symmetric key creation) + // to represent the prf key binary data and convert to ArrayBuffer + // Creating the array buffer from a known hex value allows us to + // assert on the value in tests + private prfKeyArrayBuffer: ArrayBuffer = Utils.hexStringToArrayBuffer( + "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + ); + + getClientExtensionResults(): any { + return { + prf: { + results: { + first: this.prfKeyArrayBuffer, + }, + }, + }; + } + + static isConditionalMediationAvailable(): Promise { + return Promise.resolve(false); + } + + static isUserVerifyingPlatformAuthenticatorAvailable(): Promise { + return Promise.resolve(false); + } +} diff --git a/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.ts b/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.ts index 1fee82c74be..a59b2395e11 100644 --- a/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.ts +++ b/apps/web/src/app/auth/core/services/webauthn-login/webauthn-login-admin.service.ts @@ -4,6 +4,8 @@ import { BehaviorSubject, filter, from, map, Observable, shareReplay, switchMap, import { PrfKeySet } from "@bitwarden/auth"; import { UserVerificationService } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction"; import { WebAuthnLoginPrfCryptoServiceAbstraction } from "@bitwarden/common/auth/abstractions/webauthn/webauthn-login-prf-crypto.service.abstraction"; +import { WebAuthnLoginCredentialAssertionOptionsView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion-options.view"; +import { WebAuthnLoginCredentialAssertionView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion.view"; import { Verification } from "@bitwarden/common/auth/types/verification"; import { LogService } from "@bitwarden/common/platform/abstractions/log.service"; @@ -12,6 +14,7 @@ import { PendingWebauthnLoginCredentialView } from "../../views/pending-webauthn import { WebauthnLoginCredentialView } from "../../views/webauthn-login-credential.view"; import { RotateableKeySetService } from "../rotateable-key-set.service"; +import { EnableCredentialEncryptionRequest } from "./request/enable-credential-encryption.request"; import { SaveCredentialRequest } from "./request/save-credential.request"; import { WebauthnLoginAttestationResponseRequest } from "./request/webauthn-login-attestation-response.request"; import { WebAuthnLoginAdminApiService } from "./webauthn-login-admin-api.service"; @@ -52,14 +55,31 @@ export class WebauthnLoginAdminService { } /** - * Get the credential attestation options needed for initiating the WebAuthnLogin credentail creation process. + * Get the credential assertion options needed for initiating the WebAuthnLogin credential update process. + * The options contains assertion options and other data for the authenticator. + * This method requires user verification. + * + * @param verification User verification data to be used for the request. + * @returns The credential assertion options and a token to be used for the credential update request. + */ + async getCredentialAssertOptions( + verification: Verification, + ): Promise { + const request = await this.userVerificationService.buildRequest(verification); + const response = await this.apiService.getCredentialAssertionOptions(request); + return new WebAuthnLoginCredentialAssertionOptionsView(response.options, response.token); + } + + /** + * Get the credential attestation options needed for initiating the WebAuthnLogin credential creation process. * The options contains a challenge and other data for the authenticator. * This method requires user verification. * * @param verification User verification data to be used for the request. * @returns The credential attestation options and a token to be used for the credential creation request. */ - async getCredentialCreateOptions( + + async getCredentialAttestationOptions( verification: Verification, ): Promise { const request = await this.userVerificationService.buildRequest(verification); @@ -169,6 +189,36 @@ export class WebauthnLoginAdminService { this.refresh(); } + /** + * Enable encryption for a credential that has already been saved to the server. + * This will update the KeySet associated with the credential in the database. + * We short circuit the process here incase the WebAuthnLoginCredential doesn't support PRF or + * if there was a problem with the Credential Assertion. + * + * @param assertionOptions Options received from the server using `getCredentialAssertOptions`. + * @returns void + */ + async enableCredentialEncryption( + assertionOptions: WebAuthnLoginCredentialAssertionView, + ): Promise { + if (assertionOptions === undefined || assertionOptions?.prfKey === undefined) { + throw new Error("invalid credential"); + } + + const prfKeySet: PrfKeySet = await this.rotateableKeySetService.createKeySet( + assertionOptions.prfKey, + ); + + const request = new EnableCredentialEncryptionRequest(); + request.token = assertionOptions.token; + request.deviceResponse = assertionOptions.deviceResponse; + request.encryptedUserKey = prfKeySet.encryptedUserKey.encryptedString; + request.encryptedPublicKey = prfKeySet.encryptedPublicKey.encryptedString; + request.encryptedPrivateKey = prfKeySet.encryptedPrivateKey.encryptedString; + await this.apiService.updateCredential(request); + this.refresh(); + } + /** * List of webauthn credentials saved on the server. * diff --git a/apps/web/src/app/auth/settings/webauthn-login-settings/create-credential-dialog/create-credential-dialog.component.ts b/apps/web/src/app/auth/settings/webauthn-login-settings/create-credential-dialog/create-credential-dialog.component.ts index fcb0e995f4a..4c5198ea132 100644 --- a/apps/web/src/app/auth/settings/webauthn-login-settings/create-credential-dialog/create-credential-dialog.component.ts +++ b/apps/web/src/app/auth/settings/webauthn-login-settings/create-credential-dialog/create-credential-dialog.component.ts @@ -94,7 +94,7 @@ export class CreateCredentialDialogComponent implements OnInit { } try { - this.credentialOptions = await this.webauthnService.getCredentialCreateOptions( + this.credentialOptions = await this.webauthnService.getCredentialAttestationOptions( this.formGroup.value.userVerification.secret, ); } catch (error) { diff --git a/apps/web/src/app/auth/settings/webauthn-login-settings/enable-encryption-dialog/enable-encryption-dialog.component.html b/apps/web/src/app/auth/settings/webauthn-login-settings/enable-encryption-dialog/enable-encryption-dialog.component.html new file mode 100644 index 00000000000..3fe6f43a052 --- /dev/null +++ b/apps/web/src/app/auth/settings/webauthn-login-settings/enable-encryption-dialog/enable-encryption-dialog.component.html @@ -0,0 +1,34 @@ +
+ + {{ "enablePasskeyEncryption" | i18n }} + {{ + credential.name + }} + + + + + + + +

{{ "useForVaultEncryptionInfo" | i18n }}

+ + + + +
+
+ + + + +
+
diff --git a/apps/web/src/app/auth/settings/webauthn-login-settings/enable-encryption-dialog/enable-encryption-dialog.component.ts b/apps/web/src/app/auth/settings/webauthn-login-settings/enable-encryption-dialog/enable-encryption-dialog.component.ts new file mode 100644 index 00000000000..741b71abcf2 --- /dev/null +++ b/apps/web/src/app/auth/settings/webauthn-login-settings/enable-encryption-dialog/enable-encryption-dialog.component.ts @@ -0,0 +1,91 @@ +import { DIALOG_DATA, DialogConfig, DialogRef } from "@angular/cdk/dialog"; +import { Component, Inject, OnDestroy, OnInit } from "@angular/core"; +import { FormBuilder, Validators } from "@angular/forms"; +import { Subject } from "rxjs"; +import { takeUntil } from "rxjs/operators"; + +import { WebAuthnLoginServiceAbstraction } from "@bitwarden/common/auth/abstractions/webauthn/webauthn-login.service.abstraction"; +import { WebAuthnLoginCredentialAssertionOptionsView } from "@bitwarden/common/auth/models/view/webauthn-login/webauthn-login-credential-assertion-options.view"; +import { Verification } from "@bitwarden/common/auth/types/verification"; +import { ErrorResponse } from "@bitwarden/common/models/response/error.response"; +import { DialogService } from "@bitwarden/components/src/dialog/dialog.service"; + +import { WebauthnLoginAdminService } from "../../../core/services/webauthn-login/webauthn-login-admin.service"; +import { WebauthnLoginCredentialView } from "../../../core/views/webauthn-login-credential.view"; + +export interface EnableEncryptionDialogParams { + credentialId: string; +} + +@Component({ + templateUrl: "enable-encryption-dialog.component.html", +}) +export class EnableEncryptionDialogComponent implements OnInit, OnDestroy { + private destroy$ = new Subject(); + + protected invalidSecret = false; + protected formGroup = this.formBuilder.group({ + userVerification: this.formBuilder.group({ + secret: [null as Verification | null, Validators.required], + }), + }); + + protected credential?: WebauthnLoginCredentialView; + protected credentialOptions?: WebAuthnLoginCredentialAssertionOptionsView; + protected loading$ = this.webauthnService.loading$; + + constructor( + @Inject(DIALOG_DATA) private params: EnableEncryptionDialogParams, + private formBuilder: FormBuilder, + private dialogRef: DialogRef, + private webauthnService: WebauthnLoginAdminService, + private webauthnLoginService: WebAuthnLoginServiceAbstraction, + ) {} + + ngOnInit(): void { + this.webauthnService + .getCredential$(this.params.credentialId) + .pipe(takeUntil(this.destroy$)) + .subscribe((credential: any) => (this.credential = credential)); + } + + submit = async () => { + if (this.credential === undefined) { + return; + } + + this.dialogRef.disableClose = true; + try { + this.credentialOptions = await this.webauthnService.getCredentialAssertOptions( + this.formGroup.value.userVerification.secret, + ); + await this.webauthnService.enableCredentialEncryption( + await this.webauthnLoginService.assertCredential(this.credentialOptions), + ); + } catch (error) { + if (error instanceof ErrorResponse && error.statusCode === 400) { + this.invalidSecret = true; + } + throw error; + } + + this.dialogRef.close(); + }; + + ngOnDestroy(): void { + this.destroy$.next(); + this.destroy$.complete(); + } +} + +/** + * Strongly typed helper to open a EnableEncryptionDialogComponent + * @param dialogService Instance of the dialog service that will be used to open the dialog + * @param config Configuration for the dialog + */ +export const openEnableCredentialDialogComponent = ( + dialogService: DialogService, + config: DialogConfig, +) => { + return dialogService.open(EnableEncryptionDialogComponent, config); +}; diff --git a/apps/web/src/app/auth/settings/webauthn-login-settings/webauthn-login-settings.component.html b/apps/web/src/app/auth/settings/webauthn-login-settings/webauthn-login-settings.component.html index dc55be99f1b..968b8565a6f 100644 --- a/apps/web/src/app/auth/settings/webauthn-login-settings/webauthn-login-settings.component.html +++ b/apps/web/src/app/auth/settings/webauthn-login-settings/webauthn-login-settings.component.html @@ -39,8 +39,16 @@ {{ "usedForEncryption" | i18n }} - - {{ "encryptionNotEnabled" | i18n }} + (key: string, obj: T, options?: StorageOptions): Promise { - this.mock.save(key, options); + this.mock.save(key, obj, options); this.store[key] = obj; this.updatesSubject.next({ key: key, updateType: "save" }); return Promise.resolve(); diff --git a/libs/common/spec/utils.ts b/libs/common/spec/utils.ts index 5053a71c874..ad5907f61d3 100644 --- a/libs/common/spec/utils.ts +++ b/libs/common/spec/utils.ts @@ -69,6 +69,10 @@ export function trackEmissions(observable: Observable): T[] { case "boolean": emissions.push(value); break; + case "symbol": + // Cheating types to make symbols work at all + emissions.push(value.toString() as T); + break; default: { emissions.push(clone(value)); } @@ -85,7 +89,7 @@ function clone(value: any): any { } } -export async function awaitAsync(ms = 0) { +export async function awaitAsync(ms = 1) { if (ms < 1) { await Promise.resolve(); } else { diff --git a/libs/common/src/platform/state/implementations/default-active-user-state.spec.ts b/libs/common/src/platform/state/implementations/default-active-user-state.spec.ts index 065f7a8e959..64c1d1b233f 100644 --- a/libs/common/src/platform/state/implementations/default-active-user-state.spec.ts +++ b/libs/common/src/platform/state/implementations/default-active-user-state.spec.ts @@ -2,7 +2,7 @@ * need to update test environment so trackEmissions works appropriately * @jest-environment ../shared/test.environment.ts */ -import { any, mock } from "jest-mock-extended"; +import { any, anySymbol, mock } from "jest-mock-extended"; import { BehaviorSubject, firstValueFrom, of, timeout } from "rxjs"; import { Jsonify } from "type-fest"; @@ -11,7 +11,7 @@ import { FakeStorageService } from "../../../../spec/fake-storage.service"; import { AccountInfo, AccountService } from "../../../auth/abstractions/account.service"; import { AuthenticationStatus } from "../../../auth/enums/authentication-status"; import { UserId } from "../../../types/guid"; -import { KeyDefinition } from "../key-definition"; +import { KeyDefinition, userKeyBuilder } from "../key-definition"; import { StateDefinition } from "../state-definition"; import { DefaultActiveUserState } from "./default-active-user-state"; @@ -32,9 +32,10 @@ class TestState { } const testStateDefinition = new StateDefinition("fake", "disk"); - +const cleanupDelayMs = 10; const testKeyDefinition = new KeyDefinition(testStateDefinition, "fake", { deserializer: TestState.fromJSON, + cleanupDelayMs, }); describe("DefaultActiveUserState", () => { @@ -56,10 +57,14 @@ describe("DefaultActiveUserState", () => { ); }); + const makeUserId = (id: string) => { + return id != null ? (`00000000-0000-1000-a000-00000000000${id}` as UserId) : undefined; + }; + const changeActiveUser = async (id: string) => { - const userId = id != null ? `00000000-0000-1000-a000-00000000000${id}` : undefined; + const userId = makeUserId(id); activeAccountSubject.next({ - id: userId as UserId, + id: userId, email: `test${id}@example.com`, name: `Test User ${id}`, status: AuthenticationStatus.Unlocked, @@ -90,7 +95,7 @@ describe("DefaultActiveUserState", () => { const emissions = trackEmissions(userState.state$); // User signs in - changeActiveUser("1"); + await changeActiveUser("1"); await awaitAsync(); // Service does an update @@ -111,17 +116,17 @@ describe("DefaultActiveUserState", () => { expect(diskStorageService.mock.get).toHaveBeenNthCalledWith( 1, "user_00000000-0000-1000-a000-000000000001_fake_fake", - any(), + any(), // options ); expect(diskStorageService.mock.get).toHaveBeenNthCalledWith( 2, "user_00000000-0000-1000-a000-000000000001_fake_fake", - any(), + any(), // options ); expect(diskStorageService.mock.get).toHaveBeenNthCalledWith( 3, "user_00000000-0000-1000-a000-000000000002_fake_fake", - any(), + any(), // options ); // Should only have saved data for the first user @@ -129,7 +134,8 @@ describe("DefaultActiveUserState", () => { expect(diskStorageService.mock.save).toHaveBeenNthCalledWith( 1, "user_00000000-0000-1000-a000-000000000001_fake_fake", - any(), + updatedState, + any(), // options ); }); @@ -183,15 +189,17 @@ describe("DefaultActiveUserState", () => { }); it("should not emit a previous users value if that user is no longer active", async () => { + const user1Data: Jsonify = { + date: "2020-09-21T13:14:17.648Z", + array: ["value"], + }; + const user2Data: Jsonify = { + date: "2020-09-21T13:14:17.648Z", + array: [], + }; diskStorageService.internalUpdateStore({ - "user_00000000-0000-1000-a000-000000000001_fake_fake": { - date: "2020-09-21T13:14:17.648Z", - array: ["value"], - } as Jsonify, - "user_00000000-0000-1000-a000-000000000002_fake_fake": { - date: "2020-09-21T13:14:17.648Z", - array: [], - } as Jsonify, + "user_00000000-0000-1000-a000-000000000001_fake_fake": user1Data, + "user_00000000-0000-1000-a000-000000000002_fake_fake": user2Data, }); // This starts one subscription on the observable for tracking emissions throughout @@ -203,7 +211,7 @@ describe("DefaultActiveUserState", () => { // This should always return a value right await const value = await firstValueFrom(userState.state$); - expect(value).toBeTruthy(); + expect(value).toEqual(user1Data); // Make it such that there is no active user await changeActiveUser(undefined); @@ -222,20 +230,34 @@ describe("DefaultActiveUserState", () => { rejectedError = err; }); - expect(resolvedValue).toBeFalsy(); - expect(rejectedError).toBeTruthy(); + expect(resolvedValue).toBeUndefined(); + expect(rejectedError).not.toBeUndefined(); expect(rejectedError.message).toBe("Timeout has occurred"); // We need to figure out if something should be emitted // when there becomes no active user, if we don't want that to emit // this value is correct. - expect(emissions).toHaveLength(2); + expect(emissions).toEqual([user1Data]); + }); + + it("should not emit twice if there are two listeners", async () => { + await changeActiveUser("1"); + const emissions = trackEmissions(userState.state$); + const emissions2 = trackEmissions(userState.state$); + await awaitAsync(); + + expect(emissions).toEqual([ + null, // Initial value + ]); + expect(emissions2).toEqual([ + null, // Initial value + ]); }); describe("update", () => { const newData = { date: new Date(), array: ["test"] }; beforeEach(async () => { - changeActiveUser("1"); + await changeActiveUser("1"); }); it("should save on update", async () => { @@ -315,6 +337,8 @@ describe("DefaultActiveUserState", () => { return initialData; }); + await awaitAsync(); + await userState.update((state, dependencies) => { expect(state).toEqual(initialData); return newData; @@ -329,4 +353,303 @@ describe("DefaultActiveUserState", () => { ]); }); }); + + describe("update races", () => { + const newData = { date: new Date(), array: ["test"] }; + const userId = makeUserId("1"); + + beforeEach(async () => { + await changeActiveUser("1"); + await awaitAsync(); + }); + + test("subscriptions during an update should receive the current and latest", async () => { + const oldData = { date: new Date(2019, 1, 1), array: ["oldValue1"] }; + await userState.update(() => { + return oldData; + }); + const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] }; + await userState.update(() => { + return initialData; + }); + + await awaitAsync(); + + const emissions = trackEmissions(userState.state$); + await awaitAsync(); + expect(emissions).toEqual([initialData]); + + let emissions2: TestState[]; + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => { + emissions2 = trackEmissions(userState.state$); + await originalSave(key, obj); + }); + + const val = await userState.update(() => { + return newData; + }); + + await awaitAsync(10); + + expect(val).toEqual(newData); + expect(emissions).toEqual([initialData, newData]); + expect(emissions2).toEqual([initialData, newData]); + }); + + test("subscription during an aborted update should receive the last value", async () => { + // Seed with interesting data + const initialData = { date: new Date(2020, 1, 1), array: ["value1", "value2"] }; + await userState.update(() => { + return initialData; + }); + + await awaitAsync(); + + const emissions = trackEmissions(userState.state$); + await awaitAsync(); + expect(emissions).toEqual([initialData]); + + let emissions2: TestState[]; + const val = await userState.update( + (state) => { + return newData; + }, + { + shouldUpdate: () => { + emissions2 = trackEmissions(userState.state$); + return false; + }, + }, + ); + + await awaitAsync(); + + expect(val).toEqual(initialData); + expect(emissions).toEqual([initialData]); + + expect(emissions2).toEqual([initialData]); + }); + + test("updates should wait until previous update is complete", async () => { + trackEmissions(userState.state$); + await awaitAsync(); // storage updates are behind a promise + + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest + .fn() + .mockImplementationOnce(async (key: string, obj: any) => { + let resolved = false; + await Promise.race([ + userState.update(() => { + // deadlocks + resolved = true; + return newData; + }), + awaitAsync(100), // limit test to 100ms + ]); + expect(resolved).toBe(false); + }) + .mockImplementation((...args) => { + return originalSave(...args); + }); + + await userState.update(() => { + return newData; + }); + }); + + test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => { + expect(userState["stateSubject"].value).toEqual(anySymbol()); // FAKE_DEFAULT + const val = await userState.update((state) => { + return newData; + }); + + expect(val).toEqual(newData); + const call = diskStorageService.mock.save.mock.calls[0]; + expect(call[0]).toEqual(`user_${userId}_fake_fake`); + expect(call[1]).toEqual(newData); + }); + + it("does not await updates if the active user changes", async () => { + const initialUserId = (await firstValueFrom(accountService.activeAccount$)).id; + expect(initialUserId).toBe(userId); + trackEmissions(userState.state$); + await awaitAsync(); // storage updates are behind a promise + + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest + .fn() + .mockImplementationOnce(async (key: string, obj: any) => { + let resolved = false; + await changeActiveUser("2"); + await Promise.race([ + userState.update(() => { + // should not deadlock because we updated the user + resolved = true; + return newData; + }), + awaitAsync(100), // limit test to 100ms + ]); + expect(resolved).toBe(true); + }) + .mockImplementation((...args) => { + return originalSave(...args); + }); + + await userState.update(() => { + return newData; + }); + }); + + it("stores updates for users in the correct place when active user changes mid-update", async () => { + trackEmissions(userState.state$); + await awaitAsync(); // storage updates are behind a promise + + const user2Data = { date: new Date(), array: ["user 2 data"] }; + + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest + .fn() + .mockImplementationOnce(async (key: string, obj: any) => { + let resolved = false; + await changeActiveUser("2"); + await Promise.race([ + userState.update(() => { + // should not deadlock because we updated the user + resolved = true; + return user2Data; + }), + awaitAsync(100), // limit test to 100ms + ]); + expect(resolved).toBe(true); + await originalSave(key, obj); + }) + .mockImplementation((...args) => { + return originalSave(...args); + }); + + await userState.update(() => { + return newData; + }); + await awaitAsync(); + + expect(diskStorageService.mock.save).toHaveBeenCalledTimes(2); + const innerCall = diskStorageService.mock.save.mock.calls[0]; + expect(innerCall[0]).toEqual(`user_${makeUserId("2")}_fake_fake`); + expect(innerCall[1]).toEqual(user2Data); + const outerCall = diskStorageService.mock.save.mock.calls[1]; + expect(outerCall[0]).toEqual(`user_${makeUserId("1")}_fake_fake`); + expect(outerCall[1]).toEqual(newData); + }); + }); + + describe("cleanup", () => { + const newData = { date: new Date(), array: ["test"] }; + const userId = makeUserId("1"); + let userKey: string; + + beforeEach(async () => { + await changeActiveUser("1"); + userKey = userKeyBuilder(userId, testKeyDefinition); + }); + + async function assertClean() { + const emissions = trackEmissions(userState["stateSubject"]); + const initial = structuredClone(emissions); + + diskStorageService.save(userKey, newData); + await awaitAsync(); // storage updates are behind a promise + + expect(emissions).toEqual(initial); // no longer listening to storage updates + } + + it("should cleanup after last subscriber", async () => { + const subscription = userState.state$.subscribe(); + await awaitAsync(); // storage updates are behind a promise + + subscription.unsubscribe(); + expect(userState["subscriberCount"].getValue()).toBe(0); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + await assertClean(); + }); + + it("should not cleanup if there are still subscribers", async () => { + const subscription1 = userState.state$.subscribe(); + const sub2Emissions: TestState[] = []; + const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v)); + await awaitAsync(); // storage updates are behind a promise + + subscription1.unsubscribe(); + + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + expect(userState["subscriberCount"].getValue()).toBe(1); + + // Still be listening to storage updates + diskStorageService.save(userKey, newData); + await awaitAsync(); // storage updates are behind a promise + expect(sub2Emissions).toEqual([null, newData]); + + subscription2.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + await assertClean(); + }); + + it("can re-initialize after cleanup", async () => { + const subscription = userState.state$.subscribe(); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + const emissions = trackEmissions(userState.state$); + await awaitAsync(); + + diskStorageService.save(userKey, newData); + await awaitAsync(); + + expect(emissions).toEqual([null, newData]); + }); + + it("should not cleanup if a subscriber joins during the cleanup delay", async () => { + const subscription = userState.state$.subscribe(); + await awaitAsync(); + + await diskStorageService.save(userKey, newData); + await awaitAsync(); + + subscription.unsubscribe(); + expect(userState["subscriberCount"].getValue()).toBe(0); + // Do not wait long enough for cleanup + await awaitAsync(cleanupDelayMs / 2); + + expect(userState["stateSubject"].value).toEqual(newData); // digging in to check that it hasn't been cleared + expect(userState["storageUpdateSubscription"]).not.toBeNull(); // still listening to storage updates + }); + + it("state$ observables are durable to cleanup", async () => { + const observable = userState.state$; + let subscription = observable.subscribe(); + + await diskStorageService.save(userKey, newData); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + subscription = observable.subscribe(); + await diskStorageService.save(userKey, newData); + await awaitAsync(); + + expect(await firstValueFrom(observable)).toEqual(newData); + }); + }); }); diff --git a/libs/common/src/platform/state/implementations/default-active-user-state.ts b/libs/common/src/platform/state/implementations/default-active-user-state.ts index 3d36af1d61c..02cd53cfb85 100644 --- a/libs/common/src/platform/state/implementations/default-active-user-state.ts +++ b/libs/common/src/platform/state/implementations/default-active-user-state.ts @@ -4,12 +4,12 @@ import { map, shareReplay, switchMap, - tap, - defer, firstValueFrom, combineLatestWith, filter, timeout, + Subscription, + tap, } from "rxjs"; import { AccountService } from "../../../auth/abstractions/account.service"; @@ -31,13 +31,22 @@ const FAKE_DEFAULT = Symbol("fakeDefault"); export class DefaultActiveUserState implements ActiveUserState { [activeMarker]: true; private formattedKey$: Observable; + private updatePromise: Promise | null = null; + private storageUpdateSubscription: Subscription; + private activeAccountUpdateSubscription: Subscription; + private subscriberCount = new BehaviorSubject(0); + private stateObservable: Observable; + private reinitialize = false; protected stateSubject: BehaviorSubject = new BehaviorSubject< T | typeof FAKE_DEFAULT >(FAKE_DEFAULT); private stateSubject$ = this.stateSubject.asObservable(); - state$: Observable; + get state$() { + this.stateObservable = this.stateObservable ?? this.initializeObservable(); + return this.stateObservable; + } constructor( protected keyDefinition: KeyDefinition, @@ -51,62 +60,12 @@ export class DefaultActiveUserState implements ActiveUserState { ? userKeyBuilder(account.id, this.keyDefinition) : null, ), + tap(() => { + // We have a new key, so we should forget about previous update promises + this.updatePromise = null; + }), shareReplay({ bufferSize: 1, refCount: false }), ); - - const activeAccountData$ = this.formattedKey$.pipe( - switchMap(async (key) => { - if (key == null) { - return FAKE_DEFAULT; - } - return await getStoredValue( - key, - this.chosenStorageLocation, - this.keyDefinition.deserializer, - ); - }), - // Share the execution - shareReplay({ refCount: false, bufferSize: 1 }), - ); - - const storageUpdates$ = this.chosenStorageLocation.updates$.pipe( - combineLatestWith(this.formattedKey$), - filter(([update, key]) => key !== null && update.key === key), - switchMap(async ([update, key]) => { - if (update.updateType === "remove") { - return null; - } - const data = await getStoredValue( - key, - this.chosenStorageLocation, - this.keyDefinition.deserializer, - ); - return data; - }), - ); - - // Whomever subscribes to this data, should be notified of updated data - // if someone calls my update() method, or the active user changes. - this.state$ = defer(() => { - const accountChangeSubscription = activeAccountData$.subscribe((data) => { - this.stateSubject.next(data); - }); - const storageUpdateSubscription = storageUpdates$.subscribe((data) => { - this.stateSubject.next(data); - }); - - return this.stateSubject$.pipe( - tap({ - complete: () => { - accountChangeSubscription.unsubscribe(); - storageUpdateSubscription.unsubscribe(); - }, - }), - ); - }) - // I fake the generic here because I am filtering out the other union type - // and this makes it so that typescript understands the true type - .pipe(filter((value) => value != FAKE_DEFAULT)); } async update( @@ -114,8 +73,34 @@ export class DefaultActiveUserState implements ActiveUserState { options: StateUpdateOptions = {}, ): Promise { options = populateOptionsWithDefault(options); + try { + if (this.updatePromise != null) { + await this.updatePromise; + } + this.updatePromise = this.internalUpdate(configureState, options); + const newState = await this.updatePromise; + return newState; + } finally { + this.updatePromise = null; + } + } + + // TODO: this should be removed + async getFromState(): Promise { const key = await this.createKey(); - const currentState = await this.getGuaranteedState(key); + return await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer); + } + + createDerived(converter: Converter): DerivedUserState { + return new DefaultDerivedUserState(converter, this.encryptService, this); + } + + private async internalUpdate( + configureState: (state: T, dependency: TCombine) => T, + options: StateUpdateOptions, + ) { + const key = await this.createKey(); + const currentState = await this.getStateForUpdate(key); const combinedDependencies = options.combineLatestWith != null ? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout))) @@ -130,13 +115,59 @@ export class DefaultActiveUserState implements ActiveUserState { return newState; } - async getFromState(): Promise { - const key = await this.createKey(); - return await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer); - } + private initializeObservable() { + this.storageUpdateSubscription = this.chosenStorageLocation.updates$ + .pipe( + combineLatestWith(this.formattedKey$), + filter(([update, key]) => key !== null && update.key === key), + switchMap(async ([update, key]) => { + if (update.updateType === "remove") { + return null; + } + return await this.getState(key); + }), + ) + .subscribe((v) => this.stateSubject.next(v)); - createDerived(converter: Converter): DerivedUserState { - return new DefaultDerivedUserState(converter, this.encryptService, this); + this.activeAccountUpdateSubscription = this.formattedKey$ + .pipe( + switchMap(async (key) => { + if (key == null) { + return FAKE_DEFAULT; + } + return await this.getState(key); + }), + ) + .subscribe((v) => this.stateSubject.next(v)); + + this.subscriberCount.subscribe((count) => { + if (count === 0 && this.stateObservable != null) { + this.triggerCleanup(); + } + }); + + return new Observable((subscriber) => { + this.incrementSubscribers(); + + // reinitialize listeners after cleanup + if (this.reinitialize) { + this.reinitialize = false; + this.initializeObservable(); + } + + const prevUnsubscribe = subscriber.unsubscribe.bind(subscriber); + subscriber.unsubscribe = () => { + this.decrementSubscribers(); + prevUnsubscribe(); + }; + + return this.stateSubject + .pipe( + // Filter out fake default, which is used to indicate that state is not ready to be emitted yet. + filter((i) => i !== FAKE_DEFAULT), + ) + .subscribe(subscriber); + }); } protected async createKey(): Promise { @@ -147,22 +178,47 @@ export class DefaultActiveUserState implements ActiveUserState { return formattedKey; } - protected async getGuaranteedState(key: string) { + /** For use in update methods, does not wait for update to complete before yielding state. + * The expectation is that that await is already done + */ + protected async getStateForUpdate(key: string) { const currentValue = this.stateSubject.getValue(); - return currentValue === FAKE_DEFAULT ? await this.seedInitial(key) : currentValue; + return currentValue === FAKE_DEFAULT + ? await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer) + : currentValue; } - private async seedInitial(key: string): Promise { - const value = await getStoredValue( - key, - this.chosenStorageLocation, - this.keyDefinition.deserializer, - ); - this.stateSubject.next(value); - return value; + /** To be used in observables. Awaits updates to ensure they are complete */ + private async getState(key: string): Promise { + if (this.updatePromise != null) { + await this.updatePromise; + } + return await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer); } protected saveToStorage(key: string, data: T): Promise { return this.chosenStorageLocation.save(key, data); } + + private incrementSubscribers() { + this.subscriberCount.next(this.subscriberCount.value + 1); + } + + private decrementSubscribers() { + this.subscriberCount.next(this.subscriberCount.value - 1); + } + + private triggerCleanup() { + setTimeout(() => { + if (this.subscriberCount.value === 0) { + this.updatePromise = null; + this.storageUpdateSubscription?.unsubscribe(); + this.activeAccountUpdateSubscription?.unsubscribe(); + this.subscriberCount.complete(); + this.subscriberCount = new BehaviorSubject(0); + this.stateSubject.next(FAKE_DEFAULT); + this.reinitialize = true; + } + }, this.keyDefinition.cleanupDelayMs); + } } diff --git a/libs/common/src/platform/state/implementations/default-global-state.spec.ts b/libs/common/src/platform/state/implementations/default-global-state.spec.ts index ae6cd1adbfd..35ce0fa0983 100644 --- a/libs/common/src/platform/state/implementations/default-global-state.spec.ts +++ b/libs/common/src/platform/state/implementations/default-global-state.spec.ts @@ -3,6 +3,7 @@ * @jest-environment ../shared/test.environment.ts */ +import { anySymbol } from "jest-mock-extended"; import { firstValueFrom, of } from "rxjs"; import { Jsonify } from "type-fest"; @@ -28,9 +29,10 @@ class TestState { } const testStateDefinition = new StateDefinition("fake", "disk"); - +const cleanupDelayMs = 10; const testKeyDefinition = new KeyDefinition(testStateDefinition, "fake", { deserializer: TestState.fromJSON, + cleanupDelayMs, }); const globalKey = globalKeyBuilder(testKeyDefinition); @@ -79,6 +81,19 @@ describe("DefaultGlobalState", () => { expect(diskStorageService.mock.get).toHaveBeenCalledWith("global_fake_fake", undefined); expect(state).toBeTruthy(); }); + + it("should not emit twice if there are two listeners", async () => { + const emissions = trackEmissions(globalState.state$); + const emissions2 = trackEmissions(globalState.state$); + await awaitAsync(); + + expect(emissions).toEqual([ + null, // Initial value + ]); + expect(emissions2).toEqual([ + null, // Initial value + ]); + }); }); describe("update", () => { @@ -133,6 +148,7 @@ describe("DefaultGlobalState", () => { it("should not update if shouldUpdate returns false", async () => { const emissions = trackEmissions(globalState.state$); + await awaitAsync(); // storage updates are behind a promise const result = await globalState.update( (state) => { @@ -198,4 +214,212 @@ describe("DefaultGlobalState", () => { expect(emissions).toEqual(expect.arrayContaining([initialState, newState])); }); }); + + describe("update races", () => { + test("subscriptions during an update should receive the current and latest data", async () => { + const oldData = { date: new Date(2019, 1, 1) }; + await globalState.update(() => { + return oldData; + }); + const initialData = { date: new Date(2020, 1, 1) }; + await globalState.update(() => { + return initialData; + }); + + await awaitAsync(); + + const emissions = trackEmissions(globalState.state$); + await awaitAsync(); + expect(emissions).toEqual([initialData]); + + let emissions2: TestState[]; + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => { + emissions2 = trackEmissions(globalState.state$); + await originalSave(key, obj); + }); + + const val = await globalState.update(() => { + return newData; + }); + + await awaitAsync(10); + + expect(val).toEqual(newData); + expect(emissions).toEqual([initialData, newData]); + expect(emissions2).toEqual([initialData, newData]); + }); + + test("subscription during an aborted update should receive the last value", async () => { + // Seed with interesting data + const initialData = { date: new Date(2020, 1, 1) }; + await globalState.update(() => { + return initialData; + }); + + await awaitAsync(); + + const emissions = trackEmissions(globalState.state$); + await awaitAsync(); + expect(emissions).toEqual([initialData]); + + let emissions2: TestState[]; + const val = await globalState.update( + () => { + return newData; + }, + { + shouldUpdate: () => { + emissions2 = trackEmissions(globalState.state$); + return false; + }, + }, + ); + + await awaitAsync(); + + expect(val).toEqual(initialData); + expect(emissions).toEqual([initialData]); + + expect(emissions2).toEqual([initialData]); + }); + + test("updates should wait until previous update is complete", async () => { + trackEmissions(globalState.state$); + await awaitAsync(); // storage updates are behind a promise + + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest + .fn() + .mockImplementationOnce(async () => { + let resolved = false; + await Promise.race([ + globalState.update(() => { + // deadlocks + resolved = true; + return newData; + }), + awaitAsync(100), // limit test to 100ms + ]); + expect(resolved).toBe(false); + }) + .mockImplementation(originalSave); + + await globalState.update((state) => { + return newData; + }); + }); + + test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => { + expect(globalState["stateSubject"].value).toEqual(anySymbol()); // FAKE_DEFAULT + const val = await globalState.update((state) => { + return newData; + }); + + expect(val).toEqual(newData); + const call = diskStorageService.mock.save.mock.calls[0]; + expect(call[0]).toEqual("global_fake_fake"); + expect(call[1]).toEqual(newData); + }); + }); + + describe("cleanup", () => { + async function assertClean() { + const emissions = trackEmissions(globalState["stateSubject"]); + const initial = structuredClone(emissions); + + diskStorageService.save(globalKey, newData); + await awaitAsync(); // storage updates are behind a promise + + expect(emissions).toEqual(initial); // no longer listening to storage updates + } + + it("should cleanup after last subscriber", async () => { + const subscription = globalState.state$.subscribe(); + await awaitAsync(); // storage updates are behind a promise + + subscription.unsubscribe(); + expect(globalState["subscriberCount"].getValue()).toBe(0); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + await assertClean(); + }); + + it("should not cleanup if there are still subscribers", async () => { + const subscription1 = globalState.state$.subscribe(); + const sub2Emissions: TestState[] = []; + const subscription2 = globalState.state$.subscribe((v) => sub2Emissions.push(v)); + await awaitAsync(); // storage updates are behind a promise + + subscription1.unsubscribe(); + + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + expect(globalState["subscriberCount"].getValue()).toBe(1); + + // Still be listening to storage updates + diskStorageService.save(globalKey, newData); + await awaitAsync(); // storage updates are behind a promise + expect(sub2Emissions).toEqual([null, newData]); + + subscription2.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + await assertClean(); + }); + + it("can re-initialize after cleanup", async () => { + const subscription = globalState.state$.subscribe(); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + const emissions = trackEmissions(globalState.state$); + await awaitAsync(); + + diskStorageService.save(globalKey, newData); + await awaitAsync(); + + expect(emissions).toEqual([null, newData]); + }); + + it("should not cleanup if a subscriber joins during the cleanup delay", async () => { + const subscription = globalState.state$.subscribe(); + await awaitAsync(); + + await diskStorageService.save(globalKey, newData); + await awaitAsync(); + + subscription.unsubscribe(); + expect(globalState["subscriberCount"].getValue()).toBe(0); + // Do not wait long enough for cleanup + await awaitAsync(cleanupDelayMs / 2); + + expect(globalState["stateSubject"].value).toEqual(newData); // digging in to check that it hasn't been cleared + expect(globalState["storageUpdateSubscription"]).not.toBeNull(); // still listening to storage updates + }); + + it("state$ observables are durable to cleanup", async () => { + const observable = globalState.state$; + let subscription = observable.subscribe(); + + await diskStorageService.save(globalKey, newData); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + subscription = observable.subscribe(); + await diskStorageService.save(globalKey, newData); + await awaitAsync(); + + expect(await firstValueFrom(observable)).toEqual(newData); + }); + }); }); diff --git a/libs/common/src/platform/state/implementations/default-global-state.ts b/libs/common/src/platform/state/implementations/default-global-state.ts index 8e08717f721..73a3fe5d046 100644 --- a/libs/common/src/platform/state/implementations/default-global-state.ts +++ b/libs/common/src/platform/state/implementations/default-global-state.ts @@ -1,12 +1,10 @@ import { BehaviorSubject, Observable, - defer, + Subscription, filter, firstValueFrom, - shareReplay, switchMap, - tap, timeout, } from "rxjs"; @@ -23,54 +21,26 @@ const FAKE_DEFAULT = Symbol("fakeDefault"); export class DefaultGlobalState implements GlobalState { private storageKey: string; + private updatePromise: Promise | null = null; + private storageUpdateSubscription: Subscription; + private subscriberCount = new BehaviorSubject(0); + private stateObservable: Observable; + private reinitialize = false; protected stateSubject: BehaviorSubject = new BehaviorSubject< T | typeof FAKE_DEFAULT >(FAKE_DEFAULT); - state$: Observable; + get state$() { + this.stateObservable = this.stateObservable ?? this.initializeObservable(); + return this.stateObservable; + } constructor( private keyDefinition: KeyDefinition, private chosenLocation: AbstractStorageService & ObservableStorageService, ) { this.storageKey = globalKeyBuilder(this.keyDefinition); - - const storageUpdates$ = this.chosenLocation.updates$.pipe( - filter((update) => update.key === this.storageKey), - switchMap(async (update) => { - if (update.updateType === "remove") { - return null; - } - return await getStoredValue( - this.storageKey, - this.chosenLocation, - this.keyDefinition.deserializer, - ); - }), - shareReplay({ bufferSize: 1, refCount: false }), - ); - - this.state$ = defer(() => { - const storageUpdateSubscription = storageUpdates$.subscribe((value) => { - this.stateSubject.next(value); - }); - - this.getFromState().then((s) => { - this.stateSubject.next(s); - }); - - return this.stateSubject.pipe( - tap({ - complete: () => { - storageUpdateSubscription.unsubscribe(); - }, - }), - ); - }).pipe( - shareReplay({ refCount: false, bufferSize: 1 }), - filter((i) => i != FAKE_DEFAULT), - ); } async update( @@ -78,7 +48,24 @@ export class DefaultGlobalState implements GlobalState { options: StateUpdateOptions = {}, ): Promise { options = populateOptionsWithDefault(options); - const currentState = await this.getGuaranteedState(); + if (this.updatePromise != null) { + await this.updatePromise; + } + + try { + this.updatePromise = this.internalUpdate(configureState, options); + const newState = await this.updatePromise; + return newState; + } finally { + this.updatePromise = null; + } + } + + private async internalUpdate( + configureState: (state: T, dependency: TCombine) => T, + options: StateUpdateOptions, + ): Promise { + const currentState = await this.getStateForUpdate(); const combinedDependencies = options.combineLatestWith != null ? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout))) @@ -93,16 +80,94 @@ export class DefaultGlobalState implements GlobalState { return newState; } - private async getGuaranteedState() { + private initializeObservable() { + this.storageUpdateSubscription = this.chosenLocation.updates$ + .pipe( + filter((update) => update.key === this.storageKey), + switchMap(async (update) => { + if (update.updateType === "remove") { + return null; + } + return await this.getFromState(); + }), + ) + .subscribe((v) => this.stateSubject.next(v)); + + this.subscriberCount.subscribe((count) => { + if (count === 0 && this.stateObservable != null) { + this.triggerCleanup(); + } + }); + + // Intentionally un-awaited promise, we don't want to delay return of observable, but we do want to + // trigger populating it immediately. + this.getFromState().then((s) => { + this.stateSubject.next(s); + }); + + return new Observable((subscriber) => { + this.incrementSubscribers(); + + // reinitialize listeners after cleanup + if (this.reinitialize) { + this.reinitialize = false; + this.initializeObservable(); + } + + const prevUnsubscribe = subscriber.unsubscribe.bind(subscriber); + subscriber.unsubscribe = () => { + this.decrementSubscribers(); + prevUnsubscribe(); + }; + + return this.stateSubject + .pipe( + // Filter out fake default, which is used to indicate that state is not ready to be emitted yet. + filter((i) => i != FAKE_DEFAULT), + ) + .subscribe(subscriber); + }); + } + + /** For use in update methods, does not wait for update to complete before yielding state. + * The expectation is that that await is already done + */ + private async getStateForUpdate() { const currentValue = this.stateSubject.getValue(); - return currentValue === FAKE_DEFAULT ? await this.getFromState() : currentValue; + return currentValue === FAKE_DEFAULT + ? await getStoredValue(this.storageKey, this.chosenLocation, this.keyDefinition.deserializer) + : currentValue; } async getFromState(): Promise { + if (this.updatePromise != null) { + return await this.updatePromise; + } return await getStoredValue( this.storageKey, this.chosenLocation, this.keyDefinition.deserializer, ); } + + private incrementSubscribers() { + this.subscriberCount.next(this.subscriberCount.value + 1); + } + + private decrementSubscribers() { + this.subscriberCount.next(this.subscriberCount.value - 1); + } + + private triggerCleanup() { + setTimeout(() => { + if (this.subscriberCount.value === 0) { + this.updatePromise = null; + this.storageUpdateSubscription.unsubscribe(); + this.subscriberCount.complete(); + this.subscriberCount = new BehaviorSubject(0); + this.stateSubject.next(FAKE_DEFAULT); + this.reinitialize = true; + } + }, this.keyDefinition.cleanupDelayMs); + } } diff --git a/libs/common/src/platform/state/implementations/default-single-user-state.spec.ts b/libs/common/src/platform/state/implementations/default-single-user-state.spec.ts index a25ee863e6b..715b770b2ae 100644 --- a/libs/common/src/platform/state/implementations/default-single-user-state.spec.ts +++ b/libs/common/src/platform/state/implementations/default-single-user-state.spec.ts @@ -3,6 +3,7 @@ * @jest-environment ../shared/test.environment.ts */ +import { anySymbol } from "jest-mock-extended"; import { firstValueFrom, of } from "rxjs"; import { Jsonify } from "type-fest"; @@ -30,21 +31,22 @@ class TestState { } const testStateDefinition = new StateDefinition("fake", "disk"); - +const cleanupDelayMs = 10; const testKeyDefinition = new KeyDefinition(testStateDefinition, "fake", { deserializer: TestState.fromJSON, + cleanupDelayMs, }); const userId = Utils.newGuid() as UserId; const userKey = userKeyBuilder(userId, testKeyDefinition); describe("DefaultSingleUserState", () => { let diskStorageService: FakeStorageService; - let globalState: DefaultSingleUserState; + let userState: DefaultSingleUserState; const newData = { date: new Date() }; beforeEach(() => { diskStorageService = new FakeStorageService(); - globalState = new DefaultSingleUserState( + userState = new DefaultSingleUserState( userId, testKeyDefinition, null, // Not testing anything with encrypt service @@ -58,7 +60,7 @@ describe("DefaultSingleUserState", () => { describe("state$", () => { it("should emit when storage updates", async () => { - const emissions = trackEmissions(globalState.state$); + const emissions = trackEmissions(userState.state$); await diskStorageService.save(userKey, newData); await awaitAsync(); @@ -69,7 +71,7 @@ describe("DefaultSingleUserState", () => { }); it("should not emit when update key does not match", async () => { - const emissions = trackEmissions(globalState.state$); + const emissions = trackEmissions(userState.state$); await diskStorageService.save("wrong_key", newData); expect(emissions).toHaveLength(0); @@ -82,7 +84,7 @@ describe("DefaultSingleUserState", () => { }); diskStorageService.internalUpdateStore(initialStorage); - const state = await firstValueFrom(globalState.state$); + const state = await firstValueFrom(userState.state$); expect(diskStorageService.mock.get).toHaveBeenCalledTimes(1); expect(diskStorageService.mock.get).toHaveBeenCalledWith( `user_${userId}_fake_fake`, @@ -94,7 +96,7 @@ describe("DefaultSingleUserState", () => { describe("update", () => { it("should save on update", async () => { - const result = await globalState.update((state) => { + const result = await userState.update((state) => { return newData; }); @@ -103,10 +105,10 @@ describe("DefaultSingleUserState", () => { }); it("should emit once per update", async () => { - const emissions = trackEmissions(globalState.state$); + const emissions = trackEmissions(userState.state$); await awaitAsync(); // storage updates are behind a promise - await globalState.update((state) => { + await userState.update((state) => { return newData; }); @@ -119,12 +121,12 @@ describe("DefaultSingleUserState", () => { }); it("should provided combined dependencies", async () => { - const emissions = trackEmissions(globalState.state$); + const emissions = trackEmissions(userState.state$); await awaitAsync(); // storage updates are behind a promise const combinedDependencies = { date: new Date() }; - await globalState.update( + await userState.update( (state, dependencies) => { expect(dependencies).toEqual(combinedDependencies); return newData; @@ -143,9 +145,10 @@ describe("DefaultSingleUserState", () => { }); it("should not update if shouldUpdate returns false", async () => { - const emissions = trackEmissions(globalState.state$); + const emissions = trackEmissions(userState.state$); + await awaitAsync(); // storage updates are behind a promise - const result = await globalState.update( + const result = await userState.update( (state) => { return newData; }, @@ -160,18 +163,18 @@ describe("DefaultSingleUserState", () => { }); it("should provide the update callback with the current State", async () => { - const emissions = trackEmissions(globalState.state$); + const emissions = trackEmissions(userState.state$); await awaitAsync(); // storage updates are behind a promise // Seed with interesting data const initialData = { date: new Date(2020, 1, 1) }; - await globalState.update((state, dependencies) => { + await userState.update((state, dependencies) => { return initialData; }); await awaitAsync(); - await globalState.update((state) => { + await userState.update((state) => { expect(state).toEqual(initialData); return newData; }); @@ -193,14 +196,14 @@ describe("DefaultSingleUserState", () => { initialStorage[userKey] = initialState; diskStorageService.internalUpdateStore(initialStorage); - const emissions = trackEmissions(globalState.state$); + const emissions = trackEmissions(userState.state$); await awaitAsync(); // storage updates are behind a promise const newState = { ...initialState, date: new Date(initialState.date.getFullYear(), initialState.date.getMonth() + 1), }; - const actual = await globalState.update((existingState) => newState); + const actual = await userState.update((existingState) => newState); await awaitAsync(); @@ -209,4 +212,212 @@ describe("DefaultSingleUserState", () => { expect(emissions).toEqual(expect.arrayContaining([initialState, newState])); }); }); + + describe("update races", () => { + test("subscriptions during an update should receive the current and latest data", async () => { + const oldData = { date: new Date(2019, 1, 1) }; + await userState.update(() => { + return oldData; + }); + const initialData = { date: new Date(2020, 1, 1) }; + await userState.update(() => { + return initialData; + }); + + await awaitAsync(); + + const emissions = trackEmissions(userState.state$); + await awaitAsync(); + expect(emissions).toEqual([initialData]); + + let emissions2: TestState[]; + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest.fn().mockImplementation(async (key: string, obj: any) => { + emissions2 = trackEmissions(userState.state$); + await originalSave(key, obj); + }); + + const val = await userState.update(() => { + return newData; + }); + + await awaitAsync(10); + + expect(val).toEqual(newData); + expect(emissions).toEqual([initialData, newData]); + expect(emissions2).toEqual([initialData, newData]); + }); + + test("subscription during an aborted update should receive the last value", async () => { + // Seed with interesting data + const initialData = { date: new Date(2020, 1, 1) }; + await userState.update(() => { + return initialData; + }); + + await awaitAsync(); + + const emissions = trackEmissions(userState.state$); + await awaitAsync(); + expect(emissions).toEqual([initialData]); + + let emissions2: TestState[]; + const val = await userState.update( + (state) => { + return newData; + }, + { + shouldUpdate: () => { + emissions2 = trackEmissions(userState.state$); + return false; + }, + }, + ); + + await awaitAsync(); + + expect(val).toEqual(initialData); + expect(emissions).toEqual([initialData]); + + expect(emissions2).toEqual([initialData]); + }); + + test("updates should wait until previous update is complete", async () => { + trackEmissions(userState.state$); + await awaitAsync(); // storage updates are behind a promise + + const originalSave = diskStorageService.save.bind(diskStorageService); + diskStorageService.save = jest + .fn() + .mockImplementationOnce(async () => { + let resolved = false; + await Promise.race([ + userState.update(() => { + // deadlocks + resolved = true; + return newData; + }), + awaitAsync(100), // limit test to 100ms + ]); + expect(resolved).toBe(false); + }) + .mockImplementation(originalSave); + + await userState.update((state) => { + return newData; + }); + }); + + test("updates with FAKE_DEFAULT initial value should resolve correctly", async () => { + expect(userState["stateSubject"].value).toEqual(anySymbol()); // FAKE_DEFAULT + const val = await userState.update((state) => { + return newData; + }); + + expect(val).toEqual(newData); + const call = diskStorageService.mock.save.mock.calls[0]; + expect(call[0]).toEqual(`user_${userId}_fake_fake`); + expect(call[1]).toEqual(newData); + }); + }); + + describe("cleanup", () => { + async function assertClean() { + const emissions = trackEmissions(userState["stateSubject"]); + const initial = structuredClone(emissions); + + diskStorageService.save(userKey, newData); + await awaitAsync(); // storage updates are behind a promise + + expect(emissions).toEqual(initial); // no longer listening to storage updates + } + + it("should cleanup after last subscriber", async () => { + const subscription = userState.state$.subscribe(); + await awaitAsync(); // storage updates are behind a promise + + subscription.unsubscribe(); + expect(userState["subscriberCount"].getValue()).toBe(0); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + await assertClean(); + }); + + it("should not cleanup if there are still subscribers", async () => { + const subscription1 = userState.state$.subscribe(); + const sub2Emissions: TestState[] = []; + const subscription2 = userState.state$.subscribe((v) => sub2Emissions.push(v)); + await awaitAsync(); // storage updates are behind a promise + + subscription1.unsubscribe(); + + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + expect(userState["subscriberCount"].getValue()).toBe(1); + + // Still be listening to storage updates + diskStorageService.save(userKey, newData); + await awaitAsync(); // storage updates are behind a promise + expect(sub2Emissions).toEqual([null, newData]); + + subscription2.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + await assertClean(); + }); + + it("can re-initialize after cleanup", async () => { + const subscription = userState.state$.subscribe(); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + const emissions = trackEmissions(userState.state$); + await awaitAsync(); + + diskStorageService.save(userKey, newData); + await awaitAsync(); + + expect(emissions).toEqual([null, newData]); + }); + + it("should not cleanup if a subscriber joins during the cleanup delay", async () => { + const subscription = userState.state$.subscribe(); + await awaitAsync(); + + await diskStorageService.save(userKey, newData); + await awaitAsync(); + + subscription.unsubscribe(); + expect(userState["subscriberCount"].getValue()).toBe(0); + // Do not wait long enough for cleanup + await awaitAsync(cleanupDelayMs / 2); + + expect(userState["stateSubject"].value).toEqual(newData); // digging in to check that it hasn't been cleared + expect(userState["storageUpdateSubscription"]).not.toBeNull(); // still listening to storage updates + }); + + it("state$ observables are durable to cleanup", async () => { + const observable = userState.state$; + let subscription = observable.subscribe(); + + await diskStorageService.save(userKey, newData); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + subscription = observable.subscribe(); + await diskStorageService.save(userKey, newData); + await awaitAsync(); + + expect(await firstValueFrom(observable)).toEqual(newData); + }); + }); }); diff --git a/libs/common/src/platform/state/implementations/default-single-user-state.ts b/libs/common/src/platform/state/implementations/default-single-user-state.ts index 46fa00ffb35..4c7c70d4267 100644 --- a/libs/common/src/platform/state/implementations/default-single-user-state.ts +++ b/libs/common/src/platform/state/implementations/default-single-user-state.ts @@ -1,12 +1,10 @@ import { BehaviorSubject, Observable, - defer, + Subscription, filter, firstValueFrom, - shareReplay, switchMap, - tap, timeout, } from "rxjs"; @@ -23,16 +21,25 @@ import { Converter, SingleUserState } from "../user-state"; import { DefaultDerivedUserState } from "./default-derived-state"; import { getStoredValue } from "./util"; + const FAKE_DEFAULT = Symbol("fakeDefault"); export class DefaultSingleUserState implements SingleUserState { private storageKey: string; + private updatePromise: Promise | null = null; + private storageUpdateSubscription: Subscription; + private subscriberCount = new BehaviorSubject(0); + private stateObservable: Observable; + private reinitialize = false; protected stateSubject: BehaviorSubject = new BehaviorSubject< T | typeof FAKE_DEFAULT >(FAKE_DEFAULT); - state$: Observable; + get state$() { + this.stateObservable = this.stateObservable ?? this.initializeObservable(); + return this.stateObservable; + } constructor( readonly userId: UserId, @@ -41,42 +48,6 @@ export class DefaultSingleUserState implements SingleUserState { private chosenLocation: AbstractStorageService & ObservableStorageService, ) { this.storageKey = userKeyBuilder(this.userId, this.keyDefinition); - - const storageUpdates$ = this.chosenLocation.updates$.pipe( - filter((update) => update.key === this.storageKey), - switchMap(async (update) => { - if (update.updateType === "remove") { - return null; - } - return await getStoredValue( - this.storageKey, - this.chosenLocation, - this.keyDefinition.deserializer, - ); - }), - shareReplay({ bufferSize: 1, refCount: false }), - ); - - this.state$ = defer(() => { - const storageUpdateSubscription = storageUpdates$.subscribe((value) => { - this.stateSubject.next(value); - }); - - this.getFromState().then((s) => { - this.stateSubject.next(s); - }); - - return this.stateSubject.pipe( - tap({ - complete: () => { - storageUpdateSubscription.unsubscribe(); - }, - }), - ); - }).pipe( - shareReplay({ refCount: false, bufferSize: 1 }), - filter((i) => i != FAKE_DEFAULT), - ); } async update( @@ -84,7 +55,28 @@ export class DefaultSingleUserState implements SingleUserState { options: StateUpdateOptions = {}, ): Promise { options = populateOptionsWithDefault(options); - const currentState = await this.getGuaranteedState(); + if (this.updatePromise != null) { + await this.updatePromise; + } + + try { + this.updatePromise = this.internalUpdate(configureState, options); + const newState = await this.updatePromise; + return newState; + } finally { + this.updatePromise = null; + } + } + + createDerived(converter: Converter): DerivedUserState { + return new DefaultDerivedUserState(converter, this.encryptService, this); + } + + private async internalUpdate( + configureState: (state: T, dependency: TCombine) => T, + options: StateUpdateOptions, + ): Promise { + const currentState = await this.getStateForUpdate(); const combinedDependencies = options.combineLatestWith != null ? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout))) @@ -99,20 +91,94 @@ export class DefaultSingleUserState implements SingleUserState { return newState; } - createDerived(converter: Converter): DerivedUserState { - return new DefaultDerivedUserState(converter, this.encryptService, this); + private initializeObservable() { + this.storageUpdateSubscription = this.chosenLocation.updates$ + .pipe( + filter((update) => update.key === this.storageKey), + switchMap(async (update) => { + if (update.updateType === "remove") { + return null; + } + return await this.getFromState(); + }), + ) + .subscribe((v) => this.stateSubject.next(v)); + + this.subscriberCount.subscribe((count) => { + if (count === 0 && this.stateObservable != null) { + this.triggerCleanup(); + } + }); + + // Intentionally un-awaited promise, we don't want to delay return of observable, but we do want to + // trigger populating it immediately. + this.getFromState().then((s) => { + this.stateSubject.next(s); + }); + + return new Observable((subscriber) => { + this.incrementSubscribers(); + + // reinitialize listeners after cleanup + if (this.reinitialize) { + this.reinitialize = false; + this.initializeObservable(); + } + + const prevUnsubscribe = subscriber.unsubscribe.bind(subscriber); + subscriber.unsubscribe = () => { + this.decrementSubscribers(); + prevUnsubscribe(); + }; + + return this.stateSubject + .pipe( + // Filter out fake default, which is used to indicate that state is not ready to be emitted yet. + filter((i) => i != FAKE_DEFAULT), + ) + .subscribe(subscriber); + }); } - private async getGuaranteedState() { + /** For use in update methods, does not wait for update to complete before yielding state. + * The expectation is that that await is already done + */ + private async getStateForUpdate() { const currentValue = this.stateSubject.getValue(); - return currentValue === FAKE_DEFAULT ? await this.getFromState() : currentValue; + return currentValue === FAKE_DEFAULT + ? await getStoredValue(this.storageKey, this.chosenLocation, this.keyDefinition.deserializer) + : currentValue; } async getFromState(): Promise { + if (this.updatePromise != null) { + return await this.updatePromise; + } return await getStoredValue( this.storageKey, this.chosenLocation, this.keyDefinition.deserializer, ); } + + private incrementSubscribers() { + this.subscriberCount.next(this.subscriberCount.value + 1); + } + + private decrementSubscribers() { + this.subscriberCount.next(this.subscriberCount.value - 1); + } + + private triggerCleanup() { + setTimeout(() => { + if (this.subscriberCount.value === 0) { + this.updatePromise = null; + this.storageUpdateSubscription.unsubscribe(); + this.subscriberCount.complete(); + this.subscriberCount = new BehaviorSubject(0); + this.stateSubject.next(FAKE_DEFAULT); + this.reinitialize = true; + } + }, this.keyDefinition.cleanupDelayMs); + } } diff --git a/libs/common/src/platform/state/key-definition.spec.ts b/libs/common/src/platform/state/key-definition.spec.ts index cbb1e49a9a1..ee926bccd8e 100644 --- a/libs/common/src/platform/state/key-definition.spec.ts +++ b/libs/common/src/platform/state/key-definition.spec.ts @@ -18,6 +18,37 @@ describe("KeyDefinition", () => { }); }); + describe("cleanupDelayMs", () => { + it("defaults to 1000ms", () => { + const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", { + deserializer: (value) => value, + }); + + expect(keyDefinition).toBeTruthy(); + expect(keyDefinition.cleanupDelayMs).toBe(1000); + }); + + it("can be overridden", () => { + const keyDefinition = new KeyDefinition(fakeStateDefinition, "fake", { + deserializer: (value) => value, + cleanupDelayMs: 500, + }); + + expect(keyDefinition).toBeTruthy(); + expect(keyDefinition.cleanupDelayMs).toBe(500); + }); + + it.each([0, -1])("throws on 0 or negative (%s)", (testValue: number) => { + expect( + () => + new KeyDefinition(fakeStateDefinition, "fake", { + deserializer: (value) => value, + cleanupDelayMs: testValue, + }), + ).toThrow(); + }); + }); + describe("record", () => { it("runs custom deserializer for each record value", () => { const recordDefinition = KeyDefinition.record(fakeStateDefinition, "fake", { diff --git a/libs/common/src/platform/state/key-definition.ts b/libs/common/src/platform/state/key-definition.ts index db65740388e..9989bf37a24 100644 --- a/libs/common/src/platform/state/key-definition.ts +++ b/libs/common/src/platform/state/key-definition.ts @@ -19,6 +19,11 @@ type KeyDefinitionOptions = { * @returns The fully typed version of your state. */ readonly deserializer: (jsonValue: Jsonify) => T; + /** + * The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed. + * Defaults to 1000ms. + */ + readonly cleanupDelayMs?: number; }; /** @@ -42,8 +47,12 @@ export class KeyDefinition { private readonly options: KeyDefinitionOptions, ) { if (options.deserializer == null) { + throw new Error(`'deserializer' is a required property on key ${this.errorKeyName}`); + } + + if (options.cleanupDelayMs <= 0) { throw new Error( - `'deserializer' is a required property on key ${stateDefinition.name} > ${key}`, + `'cleanupDelayMs' must be greater than 0. Value of ${options.cleanupDelayMs} passed to key ${this.errorKeyName} `, ); } } @@ -55,6 +64,13 @@ export class KeyDefinition { return this.options.deserializer; } + /** + * Gets the number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed. + */ + get cleanupDelayMs() { + return this.options.cleanupDelayMs < 0 ? 0 : this.options.cleanupDelayMs ?? 1000; + } + /** * Creates a {@link KeyDefinition} for state that is an array. * @param stateDefinition The state definition to be added to the KeyDefinition @@ -137,6 +153,10 @@ export class KeyDefinition { ? `${scope}_${userId}_${this.stateDefinition.name}_${this.key}` : `${scope}_${this.stateDefinition.name}_${this.key}`; } + + private get errorKeyName() { + return `${this.stateDefinition.name} > ${this.key}`; + } } export type StorageKey = Opaque; diff --git a/libs/importer/src/importers/passwordsafe-xml-importer.ts b/libs/importer/src/importers/passwordsafe-xml-importer.ts index feefafec15d..8a6e3b629c7 100644 --- a/libs/importer/src/importers/passwordsafe-xml-importer.ts +++ b/libs/importer/src/importers/passwordsafe-xml-importer.ts @@ -3,6 +3,7 @@ import { ImportResult } from "../models/import-result"; import { BaseImporter } from "./base-importer"; import { Importer } from "./importer"; +/** This is the importer for the xml format from pwsafe.org */ export class PasswordSafeXmlImporter extends BaseImporter implements Importer { parse(data: string): Promise { const result = new ImportResult(); diff --git a/libs/importer/src/models/import-options.ts b/libs/importer/src/models/import-options.ts index fc0b4ce2bb7..64546cc57b3 100644 --- a/libs/importer/src/models/import-options.ts +++ b/libs/importer/src/models/import-options.ts @@ -29,7 +29,7 @@ export const regularImportOptions = [ { id: "enpassjson", name: "Enpass (json)" }, { id: "protonpass", name: "ProtonPass (zip/json)" }, { id: "safeincloudxml", name: "SafeInCloud (xml)" }, - { id: "pwsafexml", name: "Password Safe (xml)" }, + { id: "pwsafexml", name: "Password Safe - pwsafe.org (xml)" }, { id: "stickypasswordxml", name: "Sticky Password (xml)" }, { id: "msecurecsv", name: "mSecure (csv)" }, { id: "truekeycsv", name: "True Key (csv)" }, diff --git a/package-lock.json b/package-lock.json index 1fdd961d5d0..ea8955f2adc 100644 --- a/package-lock.json +++ b/package-lock.json @@ -33,7 +33,7 @@ "argon2-browser": "1.18.0", "big-integer": "1.6.51", "bootstrap": "4.6.0", - "braintree-web-drop-in": "1.40.0", + "braintree-web-drop-in": "1.41.0", "bufferutil": "4.0.8", "chalk": "4.1.2", "commander": "7.2.0", @@ -4638,9 +4638,9 @@ } }, "node_modules/@braintree/browser-detection": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/@braintree/browser-detection/-/browser-detection-1.14.0.tgz", - "integrity": "sha512-OsqU+28RhNvSw8Y5JEiUHUrAyn4OpYazFkjSJe8ZVZfkAaRXQc6hsV38MMEpIlkPMig+A68buk/diY+0O8/dMQ==" + "version": "1.17.1", + "resolved": "https://registry.npmjs.org/@braintree/browser-detection/-/browser-detection-1.17.1.tgz", + "integrity": "sha512-Mk7jauyp9pD14BTRS7otoy9dqIJGb3Oy0XtxKM/adGD9i9MAuCjH5uRZMyW2iVmJQTaA/PLlWdG7eSDyMWMc8Q==" }, "node_modules/@braintree/event-emitter": { "version": "0.4.1", @@ -4658,9 +4658,9 @@ "integrity": "sha512-tVpr7U6u6bqeQlHreEjYMNtnHX62vLnNWziY2kQLqkWhvusPuY5DfuGEIPpWqsd+V/a1slyTQaxK6HWTlH6A/Q==" }, "node_modules/@braintree/sanitize-url": { - "version": "6.0.2", - "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.2.tgz", - "integrity": "sha512-Tbsj02wXCbqGmzdnXNk0SOF19ChhRU70BsroIi4Pm6Ehp56in6vch94mfbdQ17DozxkL3BAVjbZ4Qc1a0HFRAg==" + "version": "6.0.4", + "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.4.tgz", + "integrity": "sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A==" }, "node_modules/@braintree/uuid": { "version": "0.1.0", @@ -17244,16 +17244,16 @@ } }, "node_modules/braintree-web": { - "version": "3.96.1", - "resolved": "https://registry.npmjs.org/braintree-web/-/braintree-web-3.96.1.tgz", - "integrity": "sha512-e483TQkRmcO5zFH+pMGUokFWq5Q+Jyn9TmUIDofPpul1P+xOciwNqvekl/Ku6MsJBT1+kyH0cndBYMZVJpbrtg==", + "version": "3.97.4", + "resolved": "https://registry.npmjs.org/braintree-web/-/braintree-web-3.97.4.tgz", + "integrity": "sha512-w//M/ZI/MhjaxUwpICwZO50uTLF/L3WGLN4tFCPh/Xw20jDw8UBiM0Gzquq7gmwcQ1BgNnAAaYlR94HcSmt/Cg==", "dependencies": { "@braintree/asset-loader": "0.4.4", - "@braintree/browser-detection": "1.14.0", + "@braintree/browser-detection": "1.17.1", "@braintree/event-emitter": "0.4.1", "@braintree/extended-promise": "0.4.1", "@braintree/iframer": "1.1.0", - "@braintree/sanitize-url": "6.0.2", + "@braintree/sanitize-url": "6.0.4", "@braintree/uuid": "0.1.0", "@braintree/wrap-promise": "2.1.0", "card-validator": "8.1.1", @@ -17265,16 +17265,16 @@ } }, "node_modules/braintree-web-drop-in": { - "version": "1.40.0", - "resolved": "https://registry.npmjs.org/braintree-web-drop-in/-/braintree-web-drop-in-1.40.0.tgz", - "integrity": "sha512-YrR+KitYu0X6+qkP2hvkjAolMJw3jKyd+Yppk+LZUFk1Bbfj7YayLjTQqLTyFT5MQmzpQKLYwbI75BBClPGc+Q==", + "version": "1.41.0", + "resolved": "https://registry.npmjs.org/braintree-web-drop-in/-/braintree-web-drop-in-1.41.0.tgz", + "integrity": "sha512-cpFY13iyoPNCTIOU7dipHmOvoblUtYFuA7ADAm0DUPk6oqxFz4EIr94R0Yg2rCabvjeauINDf01Y2d7/E1IaXg==", "dependencies": { "@braintree/asset-loader": "0.4.4", - "@braintree/browser-detection": "1.14.0", + "@braintree/browser-detection": "1.17.1", "@braintree/event-emitter": "0.4.1", "@braintree/uuid": "0.1.0", "@braintree/wrap-promise": "2.1.0", - "braintree-web": "3.96.1" + "braintree-web": "3.97.4" } }, "node_modules/braintree-web/node_modules/promise-polyfill": { diff --git a/package.json b/package.json index 586612bcb9f..ecba301c142 100644 --- a/package.json +++ b/package.json @@ -165,7 +165,7 @@ "argon2-browser": "1.18.0", "big-integer": "1.6.51", "bootstrap": "4.6.0", - "braintree-web-drop-in": "1.40.0", + "braintree-web-drop-in": "1.41.0", "bufferutil": "4.0.8", "chalk": "4.1.2", "commander": "7.2.0",