diff --git a/apps/browser/src/auth/popup/account-switching/account-switcher.component.ts b/apps/browser/src/auth/popup/account-switching/account-switcher.component.ts index 8d4777c30d8..ec035d2a3c1 100644 --- a/apps/browser/src/auth/popup/account-switching/account-switcher.component.ts +++ b/apps/browser/src/auth/popup/account-switching/account-switcher.component.ts @@ -1,13 +1,18 @@ import { Component } from "@angular/core"; import { Router } from "@angular/router"; +import { BrowserRouterService } from "../../../platform/popup/services/browser-router.service"; import { AccountSwitcherService } from "../services/account-switcher.service"; @Component({ templateUrl: "account-switcher.component.html", }) export class AccountSwitcherComponent { - constructor(private accountSwitcherService: AccountSwitcherService, private router: Router) {} + constructor( + private accountSwitcherService: AccountSwitcherService, + private router: Router, + private routerService: BrowserRouterService + ) {} get accountOptions$() { return this.accountSwitcherService.accountOptions$; @@ -15,6 +20,6 @@ export class AccountSwitcherComponent { async selectAccount(id: string) { await this.accountSwitcherService.selectAccount(id); - this.router.navigate(["/home"]); + this.router.navigate([this.routerService.getPreviousUrl() ?? "/home"]); } } diff --git a/apps/browser/src/auth/popup/account-switching/current-account.component.html b/apps/browser/src/auth/popup/account-switching/current-account.component.html index bb482347e72..189ea4c736f 100644 --- a/apps/browser/src/auth/popup/account-switching/current-account.component.html +++ b/apps/browser/src/auth/popup/account-switching/current-account.component.html @@ -1,5 +1,5 @@
- +
diff --git a/apps/browser/src/auth/popup/account-switching/current-account.component.ts b/apps/browser/src/auth/popup/account-switching/current-account.component.ts index cf50ab2798d..902c80c79e6 100644 --- a/apps/browser/src/auth/popup/account-switching/current-account.component.ts +++ b/apps/browser/src/auth/popup/account-switching/current-account.component.ts @@ -1,7 +1,9 @@ import { Component } from "@angular/core"; import { Router } from "@angular/router"; +import { map } from "rxjs"; import { AccountService } from "@bitwarden/common/auth/abstractions/account.service"; +import { Utils } from "@bitwarden/common/platform/misc/utils"; @Component({ selector: "app-current-account", @@ -14,7 +16,15 @@ export class CurrentAccountComponent { return this.accountService.activeAccount$; } - currentAccountClicked() { - this.router.navigate(["/account-switcher"]); + get currentAccountName$() { + return this.currentAccount$.pipe( + map((a) => { + return Utils.isNullOrWhitespace(a.name) ? a.email : a.name; + }) + ); + } + + async currentAccountClicked() { + await this.router.navigate(["/account-switcher"]); } } diff --git a/apps/browser/src/auth/popup/services/account-switcher.service.ts b/apps/browser/src/auth/popup/services/account-switcher.service.ts index 6614ccec1d5..f5a3cbed7a4 100644 --- a/apps/browser/src/auth/popup/services/account-switcher.service.ts +++ b/apps/browser/src/auth/popup/services/account-switcher.service.ts @@ -54,7 +54,7 @@ export class AccountSwitcherService { return; } - this.accountService.switchAccount(id as UserId); + await this.accountService.switchAccount(id as UserId); this.messagingService.send("switchAccount", { userId: id }); } } diff --git a/apps/browser/src/background/main.background.ts b/apps/browser/src/background/main.background.ts index bb61bb64658..faf399aa34f 100644 --- a/apps/browser/src/background/main.background.ts +++ b/apps/browser/src/background/main.background.ts @@ -24,6 +24,8 @@ import { TokenService as TokenServiceAbstraction } from "@bitwarden/common/auth/ import { TwoFactorService as TwoFactorServiceAbstraction } from "@bitwarden/common/auth/abstractions/two-factor.service"; import { UserVerificationApiServiceAbstraction } from "@bitwarden/common/auth/abstractions/user-verification/user-verification-api.service.abstraction"; import { UserVerificationService as UserVerificationServiceAbstraction } from "@bitwarden/common/auth/abstractions/user-verification/user-verification.service.abstraction"; +import { AuthenticationStatus } from "@bitwarden/common/auth/enums/authentication-status"; +import { ForceSetPasswordReason } from "@bitwarden/common/auth/models/domain/force-set-password-reason"; import { AccountServiceImplementation } from "@bitwarden/common/auth/services/account.service"; import { AuthRequestCryptoServiceImplementation } from "@bitwarden/common/auth/services/auth-request-crypto.service.implementation"; import { AuthService } from "@bitwarden/common/auth/services/auth.service"; @@ -87,6 +89,7 @@ import { import { SendApiService } from "@bitwarden/common/tools/send/services/send-api.service"; import { SendApiService as SendApiServiceAbstraction } from "@bitwarden/common/tools/send/services/send-api.service.abstraction"; import { InternalSendService as InternalSendServiceAbstraction } from "@bitwarden/common/tools/send/services/send.service.abstraction"; +import { UserId } from "@bitwarden/common/types/guid"; import { CipherService as CipherServiceAbstraction } from "@bitwarden/common/vault/abstractions/cipher.service"; import { CollectionService as CollectionServiceAbstraction } from "@bitwarden/common/vault/abstractions/collection.service"; import { Fido2AuthenticatorService as Fido2AuthenticatorServiceAbstraction } from "@bitwarden/common/vault/abstractions/fido2/fido2-authenticator.service.abstraction"; @@ -829,6 +832,32 @@ export default class MainBackground { } } + async switchAccount(userId: UserId) { + if (userId != null) { + await this.stateService.setActiveUser(userId); + } + + const status = await this.authService.getAuthStatus(userId); + const forcePasswordReset = + (await this.stateService.getForceSetPasswordReason({ userId: userId })) != + ForceSetPasswordReason.None; + + await this.systemService.clearPendingClipboard(); + await this.notificationsService.updateConnection(false); + + if (status === AuthenticationStatus.Locked) { + this.messagingService.send("locked", { userId: userId }); + } else if (forcePasswordReset) { + this.messagingService.send("update-temp-password", { userId: userId }); + } else { + this.messagingService.send("unlocked", { userId: userId }); + await this.refreshBadge(); + await this.refreshMenu(); + await this.syncService.fullSync(false); + this.messagingService.send("switchAccountFinish", { userId: userId }); + } + } + async logout(expired: boolean, userId?: string) { await this.eventUploadService.uploadEvents(userId); @@ -849,7 +878,14 @@ export default class MainBackground { //Needs to be checked before state is cleaned const needStorageReseed = await this.needsStorageReseed(); - await this.stateService.clean({ userId: userId }); + const newActiveUser = await this.stateService.clean({ userId: userId }); + + if (newActiveUser != null) { + // we have a new active user, do not continue tearing down application + this.switchAccount(newActiveUser as UserId); + this.messagingService.send("switchAccountFinish"); + return; + } if (userId == null || userId === (await this.stateService.getUserId())) { this.searchService.clearIndex(); diff --git a/apps/browser/src/background/runtime.background.ts b/apps/browser/src/background/runtime.background.ts index a5c101fa969..cc932a4d923 100644 --- a/apps/browser/src/background/runtime.background.ts +++ b/apps/browser/src/background/runtime.background.ts @@ -294,6 +294,9 @@ export default class RuntimeBackground { } } ); + case "switchAccount": { + await this.main.switchAccount(msg.userId); + } } } diff --git a/apps/browser/src/platform/storage/background-memory-storage.service.ts b/apps/browser/src/platform/storage/background-memory-storage.service.ts index 14dadf225ee..9fb8cb71627 100644 --- a/apps/browser/src/platform/storage/background-memory-storage.service.ts +++ b/apps/browser/src/platform/storage/background-memory-storage.service.ts @@ -53,7 +53,7 @@ export class BackgroundMemoryStorageService extends MemoryStorageService { break; } case "save": - await this.save(message.key, JSON.parse(message.data as string) as unknown); + await this.save(message.key, JSON.parse((message.data as string) ?? null) as unknown); break; case "remove": await this.remove(message.key); diff --git a/apps/browser/src/platform/storage/foreground-memory-storage.service.ts b/apps/browser/src/platform/storage/foreground-memory-storage.service.ts index ea36c322082..25c59798877 100644 --- a/apps/browser/src/platform/storage/foreground-memory-storage.service.ts +++ b/apps/browser/src/platform/storage/foreground-memory-storage.service.ts @@ -78,7 +78,7 @@ export class ForegroundMemoryStorageService extends AbstractMemoryStorageService const response = firstValueFrom( this._backgroundResponses$.pipe( filter((message) => message.id === id), - map((message) => JSON.parse(message.data as string) as T) + map((message) => JSON.parse((message.data as string) ?? null) as T) ) ); diff --git a/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts b/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts index a09d733c6d3..f9d59b8962a 100644 --- a/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts +++ b/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts @@ -1,3 +1,8 @@ +/** + * need to update test environment so structuredClone works appropriately + * @jest-environment ../../libs/shared/test.environment.ts + */ + import { trackEmissions } from "@bitwarden/common/../spec/utils"; import { BackgroundMemoryStorageService } from "./background-memory-storage.service"; diff --git a/apps/browser/src/popup/app-routing.module.ts b/apps/browser/src/popup/app-routing.module.ts index 33ac4558eca..68a7a6316da 100644 --- a/apps/browser/src/popup/app-routing.module.ts +++ b/apps/browser/src/popup/app-routing.module.ts @@ -366,7 +366,7 @@ const routes: Routes = [ { path: "account-switcher", component: AccountSwitcherComponent, - data: { state: "account-switcher" }, + data: { state: "account-switcher", doNotSaveUrl: true }, }, ]; diff --git a/apps/browser/src/popup/app.component.ts b/apps/browser/src/popup/app.component.ts index 9d91bef0ba3..0a402f7c370 100644 --- a/apps/browser/src/popup/app.component.ts +++ b/apps/browser/src/popup/app.component.ts @@ -100,7 +100,10 @@ export class AppComponent implements OnInit, OnDestroy { this.changeDetectorRef.detectChanges(); } else if (msg.command === "authBlocked") { this.router.navigate(["home"]); - } else if (msg.command === "locked" && msg.userId == null) { + } else if ( + msg.command === "locked" && + (msg.userId === null || msg.userId == this.activeUserId) + ) { this.router.navigate(["lock"]); } else if (msg.command === "showDialog") { this.showDialog(msg); @@ -123,6 +126,11 @@ export class AppComponent implements OnInit, OnDestroy { this.router.navigate(["/"]); } else if (msg.command === "convertAccountToKeyConnector") { this.router.navigate(["/remove-password"]); + } else if (msg.command === "switchAccountFinish") { + // TODO: unset loading? + // this.loading = false; + } else if (msg.command == "update-temp-password") { + this.router.navigate(["/update-temp-password"]); } else { msg.webExtSender = sender; this.broadcasterService.send(msg); diff --git a/apps/desktop/src/app/layout/account-switcher.component.ts b/apps/desktop/src/app/layout/account-switcher.component.ts index b603fde14e7..ec81b86e90c 100644 --- a/apps/desktop/src/app/layout/account-switcher.component.ts +++ b/apps/desktop/src/app/layout/account-switcher.component.ts @@ -136,7 +136,7 @@ export class AccountSwitcherComponent implements OnInit, OnDestroy { this.close(); await this.stateService.setActiveUser(null); await this.stateService.setRememberedEmail(null); - this.router.navigate(["/login"]); + await this.router.navigate(["/login"]); } private async createInactiveAccounts(baseAccounts: { diff --git a/libs/common/spec/fake-state-provider.ts b/libs/common/spec/fake-state-provider.ts new file mode 100644 index 00000000000..821210d1917 --- /dev/null +++ b/libs/common/spec/fake-state-provider.ts @@ -0,0 +1,49 @@ +import { + GlobalState, + GlobalStateProvider, + KeyDefinition, + UserState, + UserStateProvider, +} from "../src/platform/state"; + +import { FakeGlobalState, FakeUserState } from "./fake-state"; + +export class FakeGlobalStateProvider implements GlobalStateProvider { + states: Map, GlobalState> = new Map(); + get(keyDefinition: KeyDefinition): GlobalState { + let result = this.states.get(keyDefinition) as GlobalState; + + if (result == null) { + result = new FakeGlobalState(); + this.states.set(keyDefinition, result); + } + return result; + } + + getFake(keyDefinition: KeyDefinition): FakeGlobalState { + const key = Array.from(this.states.keys()).find( + (k) => k.stateDefinition === keyDefinition.stateDefinition && k.key === keyDefinition.key + ); + return this.get(key) as FakeGlobalState; + } +} + +export class FakeUserStateProvider implements UserStateProvider { + states: Map, UserState> = new Map(); + get(keyDefinition: KeyDefinition): UserState { + let result = this.states.get(keyDefinition) as UserState; + + if (result == null) { + result = new FakeUserState(); + this.states.set(keyDefinition, result); + } + return result; + } + + getFake(keyDefinition: KeyDefinition): FakeUserState { + const key = Array.from(this.states.keys()).find( + (k) => k.stateDefinition === keyDefinition.stateDefinition && k.key === keyDefinition.key + ); + return this.get(key) as FakeUserState; + } +} diff --git a/libs/common/spec/fake-state.ts b/libs/common/spec/fake-state.ts new file mode 100644 index 00000000000..24a16721444 --- /dev/null +++ b/libs/common/spec/fake-state.ts @@ -0,0 +1,99 @@ +import { ReplaySubject, firstValueFrom, timeout } from "rxjs"; + +import { DerivedUserState, GlobalState, UserState } from "../src/platform/state"; +// eslint-disable-next-line import/no-restricted-paths -- using unexposed options for clean typing in test class +import { StateUpdateOptions } from "../src/platform/state/state-update-options"; +import { UserId } from "../src/types/guid"; + +const DEFAULT_TEST_OPTIONS: StateUpdateOptions = { + shouldUpdate: () => true, + combineLatestWith: null, + msTimeout: 10, +}; + +function populateOptionsWithDefault( + options: StateUpdateOptions +): StateUpdateOptions { + return { + ...DEFAULT_TEST_OPTIONS, + ...options, + }; +} + +export class FakeGlobalState implements GlobalState { + // eslint-disable-next-line rxjs/no-exposed-subjects -- exposed for testing setup + stateSubject = new ReplaySubject(1); + + update: ( + configureState: (state: T, dependency: TCombine) => T, + options?: StateUpdateOptions + ) => Promise = jest.fn(async (configureState, options) => { + options = populateOptionsWithDefault(options); + if (this.stateSubject["_buffer"].length == 0) { + // throw a more helpful not initialized error + throw new Error( + "You must initialize the state with a value before calling update. Try calling `stateSubject.next(initialState)` before calling update" + ); + } + const current = await firstValueFrom(this.state$.pipe(timeout(100))); + const combinedDependencies = + options.combineLatestWith != null + ? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout))) + : null; + if (!options.shouldUpdate(current, combinedDependencies)) { + return; + } + const newState = configureState(current, combinedDependencies); + this.stateSubject.next(newState); + return newState; + }); + + updateMock = this.update as jest.MockedFunction; + + get state$() { + return this.stateSubject.asObservable(); + } +} + +export class FakeUserState implements UserState { + // eslint-disable-next-line rxjs/no-exposed-subjects -- exposed for testing setup + stateSubject = new ReplaySubject(1); + + update: ( + configureState: (state: T, dependency: TCombine) => T, + options?: StateUpdateOptions + ) => Promise = jest.fn(async (configureState, options) => { + options = populateOptionsWithDefault(options); + const current = await firstValueFrom(this.state$.pipe(timeout(options.msTimeout))); + const combinedDependencies = + options.combineLatestWith != null + ? await firstValueFrom(options.combineLatestWith.pipe(timeout(options.msTimeout))) + : null; + if (!options.shouldUpdate(current, combinedDependencies)) { + return; + } + const newState = configureState(current, combinedDependencies); + this.stateSubject.next(newState); + return newState; + }); + + updateMock = this.update as jest.MockedFunction; + + updateFor: ( + userId: UserId, + configureState: (state: T, dependency: TCombine) => T, + options?: StateUpdateOptions + ) => Promise = jest.fn(); + + createDerived: ( + converter: (data: T, context: any) => Promise + ) => DerivedUserState = jest.fn(); + + getFromState: () => Promise = jest.fn(async () => { + return await firstValueFrom(this.state$.pipe(timeout(10))); + }); + + get state$() { + return this.stateSubject.asObservable(); + } +} diff --git a/libs/common/src/auth/abstractions/account.service.ts b/libs/common/src/auth/abstractions/account.service.ts index 2fdbfb7830f..ca9e82335f0 100644 --- a/libs/common/src/auth/abstractions/account.service.ts +++ b/libs/common/src/auth/abstractions/account.service.ts @@ -14,7 +14,7 @@ export type AccountInfo = { }; export function accountInfoEqual(a: AccountInfo, b: AccountInfo) { - return a.status == b.status && a.email == b.email && a.name == b.name; + return a?.status === b?.status && a?.email === b?.email && a?.name === b?.name; } export abstract class AccountService { @@ -27,31 +27,31 @@ export abstract class AccountService { * @param userId * @param accountData */ - abstract addAccount(userId: UserId, accountData: AccountInfo): void; + abstract addAccount(userId: UserId, accountData: AccountInfo): Promise; /** * updates the `accounts$` observable with the new preferred name for the account. * @param userId * @param name */ - abstract setAccountName(userId: UserId, name: string): void; + abstract setAccountName(userId: UserId, name: string): Promise; /** * updates the `accounts$` observable with the new email for the account. * @param userId * @param email */ - abstract setAccountEmail(userId: UserId, email: string): void; + abstract setAccountEmail(userId: UserId, email: string): Promise; /** * Updates the `accounts$` observable with the new account status. * Also emits the `accountLock$` or `accountLogout$` observable if the status is `Locked` or `LoggedOut` respectively. * @param userId * @param status */ - abstract setAccountStatus(userId: UserId, status: AuthenticationStatus): void; + abstract setAccountStatus(userId: UserId, status: AuthenticationStatus): Promise; /** * Updates the `activeAccount$` observable with the new active account. * @param userId */ - abstract switchAccount(userId: UserId): void; + abstract switchAccount(userId: UserId): Promise; } export abstract class InternalAccountService extends AccountService { diff --git a/libs/common/src/auth/services/account.service.spec.ts b/libs/common/src/auth/services/account.service.spec.ts index d6aabef4eed..0f611f83457 100644 --- a/libs/common/src/auth/services/account.service.spec.ts +++ b/libs/common/src/auth/services/account.service.spec.ts @@ -1,30 +1,28 @@ import { MockProxy, mock } from "jest-mock-extended"; -import { BehaviorSubject, firstValueFrom } from "rxjs"; +import { firstValueFrom } from "rxjs"; +import { FakeGlobalState } from "../../../spec/fake-state"; +import { FakeGlobalStateProvider } from "../../../spec/fake-state-provider"; import { trackEmissions } from "../../../spec/utils"; import { LogService } from "../../platform/abstractions/log.service"; import { MessagingService } from "../../platform/abstractions/messaging.service"; -import { - ACCOUNT_ACCOUNTS, - ACCOUNT_ACTIVE_ACCOUNT_ID, - GlobalState, - GlobalStateProvider, -} from "../../platform/state"; import { UserId } from "../../types/guid"; import { AccountInfo } from "../abstractions/account.service"; import { AuthenticationStatus } from "../enums/authentication-status"; -import { AccountServiceImplementation } from "./account.service"; +import { + ACCOUNT_ACCOUNTS, + ACCOUNT_ACTIVE_ACCOUNT_ID, + AccountServiceImplementation, +} from "./account.service"; describe("accountService", () => { let messagingService: MockProxy; let logService: MockProxy; - let globalStateProvider: MockProxy; - let accountsState: MockProxy>>; - let accountsSubject: BehaviorSubject>; - let activeAccountIdState: MockProxy>; - let activeAccountIdSubject: BehaviorSubject; + let globalStateProvider: FakeGlobalStateProvider; let sut: AccountServiceImplementation; + let accountsState: FakeGlobalState>; + let activeAccountIdState: FakeGlobalState; const userId = "userId" as UserId; function userInfo(status: AuthenticationStatus): AccountInfo { return { status, email: "email", name: "name" }; @@ -33,27 +31,14 @@ describe("accountService", () => { beforeEach(() => { messagingService = mock(); logService = mock(); - globalStateProvider = mock(); - accountsState = mock(); - activeAccountIdState = mock(); - - accountsSubject = new BehaviorSubject>(null); - accountsState.state$ = accountsSubject.asObservable(); - activeAccountIdSubject = new BehaviorSubject(null); - activeAccountIdState.state$ = activeAccountIdSubject.asObservable(); - - globalStateProvider.get.mockImplementation((keyDefinition) => { - switch (keyDefinition) { - case ACCOUNT_ACCOUNTS: - return accountsState; - case ACCOUNT_ACTIVE_ACCOUNT_ID: - return activeAccountIdState; - default: - throw new Error("Unknown key definition"); - } - }); + globalStateProvider = new FakeGlobalStateProvider(); sut = new AccountServiceImplementation(messagingService, logService, globalStateProvider); + + accountsState = globalStateProvider.getFake(ACCOUNT_ACCOUNTS); + // initialize to empty + accountsState.stateSubject.next({}); + activeAccountIdState = globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID); }); afterEach(() => { @@ -69,20 +54,17 @@ describe("accountService", () => { it("should emit the active account and status", async () => { const emissions = trackEmissions(sut.activeAccount$); - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); - activeAccountIdSubject.next(userId); + accountsState.stateSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + activeAccountIdState.stateSubject.next(userId); - expect(emissions).toEqual([ - undefined, // initial value - { id: userId, ...userInfo(AuthenticationStatus.Unlocked) }, - ]); + expect(emissions).toEqual([{ id: userId, ...userInfo(AuthenticationStatus.Unlocked) }]); }); it("should update the status if the account status changes", async () => { - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); - activeAccountIdSubject.next(userId); + accountsState.stateSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + activeAccountIdState.stateSubject.next(userId); const emissions = trackEmissions(sut.activeAccount$); - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Locked) }); + accountsState.stateSubject.next({ [userId]: userInfo(AuthenticationStatus.Locked) }); expect(emissions).toEqual([ { id: userId, ...userInfo(AuthenticationStatus.Unlocked) }, @@ -91,8 +73,8 @@ describe("accountService", () => { }); it("should remember the last emitted value", async () => { - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); - activeAccountIdSubject.next(userId); + accountsState.stateSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + activeAccountIdState.stateSubject.next(userId); expect(await firstValueFrom(sut.activeAccount$)).toEqual({ id: userId, @@ -103,83 +85,80 @@ describe("accountService", () => { describe("accounts$", () => { it("should maintain an accounts cache", async () => { - expect(await firstValueFrom(sut.accounts$)).toEqual({}); - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + accountsState.stateSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + accountsState.stateSubject.next({ [userId]: userInfo(AuthenticationStatus.Locked) }); expect(await firstValueFrom(sut.accounts$)).toEqual({ - [userId]: userInfo(AuthenticationStatus.Unlocked), + [userId]: userInfo(AuthenticationStatus.Locked), }); }); }); describe("addAccount", () => { - it("should emit the new account", () => { - sut.addAccount(userId, userInfo(AuthenticationStatus.Unlocked)); + it("should emit the new account", async () => { + await sut.addAccount(userId, userInfo(AuthenticationStatus.Unlocked)); + const currentValue = await firstValueFrom(sut.accounts$); - expect(accountsState.update).toHaveBeenCalledTimes(1); - const callback = accountsState.update.mock.calls[0][0]; - expect(callback({}, null)).toEqual({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + expect(currentValue).toEqual({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); }); }); describe("setAccountName", () => { + const initialState = { [userId]: userInfo(AuthenticationStatus.Unlocked) }; beforeEach(() => { - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + accountsState.stateSubject.next(initialState); }); it("should update the account", async () => { - sut.setAccountName(userId, "new name"); + await sut.setAccountName(userId, "new name"); + const currentState = await firstValueFrom(accountsState.state$); - const callback = accountsState.update.mock.calls[0][0]; - - expect(callback(accountsSubject.value, null)).toEqual({ + expect(currentState).toEqual({ [userId]: { ...userInfo(AuthenticationStatus.Unlocked), name: "new name" }, }); }); it("should not update if the name is the same", async () => { - sut.setAccountName(userId, "name"); + await sut.setAccountName(userId, "name"); + const currentState = await firstValueFrom(accountsState.state$); - const callback = accountsState.update.mock.calls[0][1].shouldUpdate; - - expect(callback(accountsSubject.value, null)).toBe(false); + expect(currentState).toEqual(initialState); }); }); describe("setAccountEmail", () => { + const initialState = { [userId]: userInfo(AuthenticationStatus.Unlocked) }; beforeEach(() => { - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + accountsState.stateSubject.next(initialState); }); - it("should update the account", () => { - sut.setAccountEmail(userId, "new email"); + it("should update the account", async () => { + await sut.setAccountEmail(userId, "new email"); + const currentState = await firstValueFrom(accountsState.state$); - const callback = accountsState.update.mock.calls[0][0]; - - expect(callback(accountsSubject.value, null)).toEqual({ + expect(currentState).toEqual({ [userId]: { ...userInfo(AuthenticationStatus.Unlocked), email: "new email" }, }); }); - it("should not update if the email is the same", () => { - sut.setAccountEmail(userId, "email"); + it("should not update if the email is the same", async () => { + await sut.setAccountEmail(userId, "email"); + const currentState = await firstValueFrom(accountsState.state$); - const callback = accountsState.update.mock.calls[0][1].shouldUpdate; - - expect(callback(accountsSubject.value, null)).toBe(false); + expect(currentState).toEqual(initialState); }); }); describe("setAccountStatus", () => { + const initialState = { [userId]: userInfo(AuthenticationStatus.Unlocked) }; beforeEach(() => { - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + accountsState.stateSubject.next(initialState); }); - it("should update the account", () => { - sut.setAccountStatus(userId, AuthenticationStatus.Locked); + it("should update the account", async () => { + await sut.setAccountStatus(userId, AuthenticationStatus.Locked); + const currentState = await firstValueFrom(accountsState.state$); - const callback = accountsState.update.mock.calls[0][0]; - - expect(callback(accountsSubject.value, null)).toEqual({ + expect(currentState).toEqual({ [userId]: { ...userInfo(AuthenticationStatus.Unlocked), status: AuthenticationStatus.Locked, @@ -187,24 +166,23 @@ describe("accountService", () => { }); }); - it("should not update if the status is the same", () => { - sut.setAccountStatus(userId, AuthenticationStatus.Unlocked); + it("should not update if the status is the same", async () => { + await sut.setAccountStatus(userId, AuthenticationStatus.Unlocked); + const currentState = await firstValueFrom(accountsState.state$); - const callback = accountsState.update.mock.calls[0][1].shouldUpdate; - - expect(callback(accountsSubject.value, null)).toBe(false); + expect(currentState).toEqual(initialState); }); - it("should emit logout if the status is logged out", () => { + it("should emit logout if the status is logged out", async () => { const emissions = trackEmissions(sut.accountLogout$); - sut.setAccountStatus(userId, AuthenticationStatus.LoggedOut); + await sut.setAccountStatus(userId, AuthenticationStatus.LoggedOut); expect(emissions).toEqual([userId]); }); - it("should emit lock if the status is locked", () => { + it("should emit lock if the status is locked", async () => { const emissions = trackEmissions(sut.accountLock$); - sut.setAccountStatus(userId, AuthenticationStatus.Locked); + await sut.setAccountStatus(userId, AuthenticationStatus.Locked); expect(emissions).toEqual([userId]); }); @@ -212,19 +190,18 @@ describe("accountService", () => { describe("switchAccount", () => { beforeEach(() => { - accountsSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + accountsState.stateSubject.next({ [userId]: userInfo(AuthenticationStatus.Unlocked) }); + activeAccountIdState.stateSubject.next(userId); }); - it("should emit undefined if no account is provided", () => { - sut.switchAccount(null); - const callback = activeAccountIdState.update.mock.calls[0][0]; - expect(callback(userId, accountsSubject.value)).toBeUndefined(); + it("should emit undefined if no account is provided", async () => { + await sut.switchAccount(null); + const currentState = await firstValueFrom(sut.activeAccount$); + expect(currentState).toBeUndefined(); }); it("should throw if the account does not exist", () => { - sut.switchAccount("unknown" as UserId); - const callback = activeAccountIdState.update.mock.calls[0][0]; - expect(() => callback(userId, accountsSubject.value)).toThrowError("Account does not exist"); + expect(sut.switchAccount("unknown" as UserId)).rejects.toThrowError("Account does not exist"); }); }); }); diff --git a/libs/common/src/auth/services/account.service.ts b/libs/common/src/auth/services/account.service.ts index 5164d9cf22a..1ea9ca7bd4a 100644 --- a/libs/common/src/auth/services/account.service.ts +++ b/libs/common/src/auth/services/account.service.ts @@ -1,5 +1,4 @@ import { Subject, combineLatestWith, map, distinctUntilChanged, shareReplay } from "rxjs"; -import { Jsonify } from "type-fest"; import { AccountInfo, @@ -9,23 +8,25 @@ import { import { LogService } from "../../platform/abstractions/log.service"; import { MessagingService } from "../../platform/abstractions/messaging.service"; import { - ACCOUNT_ACCOUNTS, - ACCOUNT_ACTIVE_ACCOUNT_ID, + ACCOUNT_MEMORY, GlobalState, GlobalStateProvider, + KeyDefinition, } from "../../platform/state"; import { UserId } from "../../types/guid"; import { AuthenticationStatus } from "../enums/authentication-status"; -export function AccountsDeserializer( - accounts: Jsonify | null> -): Record { - if (accounts == null) { - return {}; +export const ACCOUNT_ACCOUNTS = KeyDefinition.record( + ACCOUNT_MEMORY, + "accounts", + { + deserializer: (accountInfo) => accountInfo, } +); - return accounts; -} +export const ACCOUNT_ACTIVE_ACCOUNT_ID = new KeyDefinition(ACCOUNT_MEMORY, "activeAccountId", { + deserializer: (id: UserId) => id, +}); export class AccountServiceImplementation implements InternalAccountService { private lock = new Subject(); @@ -52,29 +53,29 @@ export class AccountServiceImplementation implements InternalAccountService { this.activeAccount$ = this.activeAccountIdState.state$.pipe( combineLatestWith(this.accounts$), map(([id, accounts]) => (id ? { id, ...accounts[id] } : undefined)), - distinctUntilChanged(), + distinctUntilChanged((a, b) => a?.id === b?.id && accountInfoEqual(a, b)), shareReplay({ bufferSize: 1, refCount: false }) ); } - addAccount(userId: UserId, accountData: AccountInfo): void { - this.accountsState.update((accounts) => { + async addAccount(userId: UserId, accountData: AccountInfo): Promise { + await this.accountsState.update((accounts) => { accounts ||= {}; accounts[userId] = accountData; return accounts; }); } - setAccountName(userId: UserId, name: string): void { - this.setAccountInfo(userId, { name }); + async setAccountName(userId: UserId, name: string): Promise { + await this.setAccountInfo(userId, { name }); } - setAccountEmail(userId: UserId, email: string): void { - this.setAccountInfo(userId, { email }); + async setAccountEmail(userId: UserId, email: string): Promise { + await this.setAccountInfo(userId, { email }); } - setAccountStatus(userId: UserId, status: AuthenticationStatus): void { - this.setAccountInfo(userId, { status }); + async setAccountStatus(userId: UserId, status: AuthenticationStatus): Promise { + await this.setAccountInfo(userId, { status }); if (status === AuthenticationStatus.LoggedOut) { this.logout.next(userId); @@ -83,12 +84,12 @@ export class AccountServiceImplementation implements InternalAccountService { } } - switchAccount(userId: UserId) { - this.activeAccountIdState.update( + async switchAccount(userId: UserId): Promise { + await this.activeAccountIdState.update( (_, accounts) => { if (userId == null) { // indicates no account is active - return undefined; + return null; } if (accounts?.[userId] == null) { @@ -98,6 +99,10 @@ export class AccountServiceImplementation implements InternalAccountService { }, { combineLatestWith: this.accounts$, + shouldUpdate: (id) => { + // update only if userId changes + return id !== userId; + }, } ); } @@ -112,11 +117,11 @@ export class AccountServiceImplementation implements InternalAccountService { } } - private setAccountInfo(userId: UserId, update: Partial) { + private async setAccountInfo(userId: UserId, update: Partial): Promise { function newAccountInfo(oldAccountInfo: AccountInfo): AccountInfo { return { ...oldAccountInfo, ...update }; } - this.accountsState.update( + await this.accountsState.update( (accounts) => { accounts[userId] = newAccountInfo(accounts[userId]); return accounts; diff --git a/libs/common/src/platform/abstractions/state.service.ts b/libs/common/src/platform/abstractions/state.service.ts index 61e361ad6de..872daea7d83 100644 --- a/libs/common/src/platform/abstractions/state.service.ts +++ b/libs/common/src/platform/abstractions/state.service.ts @@ -17,6 +17,7 @@ import { GeneratedPasswordHistory, PasswordGeneratorOptions } from "../../tools/ import { UsernameGeneratorOptions } from "../../tools/generator/username"; import { SendData } from "../../tools/send/models/data/send.data"; import { SendView } from "../../tools/send/models/view/send.view"; +import { UserId } from "../../types/guid"; import { UriMatchType } from "../../vault/enums"; import { CipherData } from "../../vault/models/data/cipher.data"; import { CollectionData } from "../../vault/models/data/collection.data"; @@ -48,7 +49,7 @@ export abstract class StateService { addAccount: (account: T) => Promise; setActiveUser: (userId: string) => Promise; - clean: (options?: StorageOptions) => Promise; + clean: (options?: StorageOptions) => Promise; init: () => Promise; getAccessToken: (options?: StorageOptions) => Promise; diff --git a/libs/common/src/platform/services/memory-storage.service.ts b/libs/common/src/platform/services/memory-storage.service.ts index 233cb6e7cb3..7a6b9971eca 100644 --- a/libs/common/src/platform/services/memory-storage.service.ts +++ b/libs/common/src/platform/services/memory-storage.service.ts @@ -29,6 +29,9 @@ export class MemoryStorageService extends AbstractMemoryStorageService { if (obj == null) { return this.remove(key); } + // TODO: Remove once foreground/background contexts are separated in browser + // Needed to ensure ownership of all memory by the context running the storage service + obj = structuredClone(obj); this.store.set(key, obj); this.updatesSubject.next({ key, updateType: "save" }); return Promise.resolve(); diff --git a/libs/common/src/platform/services/state.service.ts b/libs/common/src/platform/services/state.service.ts index 7baaabec5f4..caea67f84e0 100644 --- a/libs/common/src/platform/services/state.service.ts +++ b/libs/common/src/platform/services/state.service.ts @@ -173,13 +173,13 @@ export class StateService< // if it's not in the accounts list. if (state.activeUserId != null && this.accountsSubject.value[state.activeUserId] == null) { const activeDiskAccount = await this.getAccountFromDisk({ userId: state.activeUserId }); - this.accountService.addAccount(state.activeUserId as UserId, { + await this.accountService.addAccount(state.activeUserId as UserId, { name: activeDiskAccount.profile.name, email: activeDiskAccount.profile.email, status: AuthenticationStatus.LoggedOut, }); } - this.accountService.switchAccount(state.activeUserId as UserId); + await this.accountService.switchAccount(state.activeUserId as UserId); // End TODO return state; @@ -198,7 +198,7 @@ export class StateService< const diskAccount = await this.getAccountFromDisk({ userId: userId }); state.accounts[userId].profile = diskAccount.profile; // TODO: Temporary update to avoid routing all account status changes through account service for now. - this.accountService.addAccount(userId as UserId, { + await this.accountService.addAccount(userId as UserId, { status: AuthenticationStatus.Locked, name: diskAccount.profile.name, email: diskAccount.profile.email, @@ -218,7 +218,7 @@ export class StateService< await this.scaffoldNewAccountStorage(account); await this.setLastActive(new Date().getTime(), { userId: account.profile.userId }); // TODO: Temporary update to avoid routing all account status changes through account service for now. - this.accountService.addAccount(account.profile.userId as UserId, { + await this.accountService.addAccount(account.profile.userId as UserId, { status: AuthenticationStatus.Locked, name: account.profile.name, email: account.profile.email, @@ -228,13 +228,13 @@ export class StateService< } async setActiveUser(userId: string): Promise { - this.clearDecryptedDataForActiveUser(); + await this.clearDecryptedDataForActiveUser(); await this.updateState(async (state) => { state.activeUserId = userId; await this.storageService.save(keys.activeUserId, userId); this.activeAccountSubject.next(state.activeUserId); // TODO: temporary update to avoid routing all account status changes through account service for now. - this.accountService.switchAccount(userId as UserId); + await this.accountService.switchAccount(userId as UserId); return state; }); @@ -242,16 +242,18 @@ export class StateService< await this.pushAccounts(); } - async clean(options?: StorageOptions): Promise { + async clean(options?: StorageOptions): Promise { options = this.reconcileOptions(options, await this.defaultInMemoryOptions()); await this.deAuthenticateAccount(options.userId); - if (options.userId === (await this.state())?.activeUserId) { - await this.dynamicallySetActiveUser(); + let currentUser = (await this.state())?.activeUserId; + if (options.userId === currentUser) { + currentUser = await this.dynamicallySetActiveUser(); } await this.removeAccountFromDisk(options?.userId); - this.removeAccountFromMemory(options?.userId); + await this.removeAccountFromMemory(options?.userId); await this.pushAccounts(); + return currentUser as UserId; } async getAccessToken(options?: StorageOptions): Promise { @@ -577,7 +579,7 @@ export class StateService< ); const nextStatus = value != null ? AuthenticationStatus.Unlocked : AuthenticationStatus.Locked; - this.accountService.setAccountStatus(options.userId as UserId, nextStatus); + await this.accountService.setAccountStatus(options.userId as UserId, nextStatus); if (options.userId == this.activeAccountSubject.getValue()) { const nextValue = value != null; @@ -613,7 +615,7 @@ export class StateService< ); const nextStatus = value != null ? AuthenticationStatus.Unlocked : AuthenticationStatus.Locked; - this.accountService.setAccountStatus(options.userId as UserId, nextStatus); + await this.accountService.setAccountStatus(options.userId as UserId, nextStatus); if (options?.userId == this.activeAccountSubject.getValue()) { const nextValue = value != null; @@ -3137,7 +3139,6 @@ export class StateService< } protected async pushAccounts(): Promise { - await this.pruneInMemoryAccounts(); await this.state().then((state) => { if (state.accounts == null || Object.keys(state.accounts).length < 1) { this.accountsSubject.next({}); @@ -3253,16 +3254,7 @@ export class StateService< return state; }); // TODO: Invert this logic, we should remove accounts based on logged out emit - this.accountService.setAccountStatus(userId as UserId, AuthenticationStatus.LoggedOut); - } - - protected async pruneInMemoryAccounts() { - // We preserve settings for logged out accounts, but we don't want to consider them when thinking about active account state - for (const userId in (await this.state())?.accounts) { - if (!(await this.getIsAuthenticated({ userId: userId }))) { - await this.removeAccountFromMemory(userId); - } - } + await this.accountService.setAccountStatus(userId as UserId, AuthenticationStatus.LoggedOut); } // settings persist even on reset, and are not affected by this method @@ -3333,18 +3325,22 @@ export class StateService< const accounts = (await this.state())?.accounts; if (accounts == null || Object.keys(accounts).length < 1) { await this.setActiveUser(null); - return; + return null; } + + let newActiveUser; for (const userId in accounts) { if (userId == null) { continue; } if (await this.getIsAuthenticated({ userId: userId })) { - await this.setActiveUser(userId); + newActiveUser = userId; break; } - await this.setActiveUser(null); + newActiveUser = null; } + await this.setActiveUser(newActiveUser); + return newActiveUser; } private async getTimeoutBasedStorageOptions(options?: StorageOptions): Promise { diff --git a/libs/common/src/platform/state/index.ts b/libs/common/src/platform/state/index.ts index bab0bec90f3..178c21e0b6b 100644 --- a/libs/common/src/platform/state/index.ts +++ b/libs/common/src/platform/state/index.ts @@ -3,5 +3,6 @@ export { GlobalState } from "./global-state"; export { GlobalStateProvider } from "./global-state.provider"; export { UserState } from "./user-state"; export { UserStateProvider } from "./user-state.provider"; +export { KeyDefinition } from "./key-definition"; -export * from "./key-definitions"; +export * from "./state-definitions"; diff --git a/libs/common/src/platform/state/key-definitions.ts b/libs/common/src/platform/state/key-definitions.ts deleted file mode 100644 index 50112137e58..00000000000 --- a/libs/common/src/platform/state/key-definitions.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { AccountInfo } from "../../auth/abstractions/account.service"; -import { AccountsDeserializer } from "../../auth/services/account.service"; -import { UserId } from "../../types/guid"; - -import { KeyDefinition } from "./key-definition"; -import { StateDefinition } from "./state-definition"; - -const ACCOUNT_MEMORY = new StateDefinition("account", "memory"); -export const ACCOUNT_ACCOUNTS = new KeyDefinition>( - ACCOUNT_MEMORY, - "accounts", - { - deserializer: (obj) => AccountsDeserializer(obj), - } -); -export const ACCOUNT_ACTIVE_ACCOUNT_ID = new KeyDefinition(ACCOUNT_MEMORY, "activeAccountId", { - deserializer: (id: UserId) => id, -}); diff --git a/libs/common/src/platform/state/state-definitions.ts b/libs/common/src/platform/state/state-definitions.ts new file mode 100644 index 00000000000..4ec1a6a87fd --- /dev/null +++ b/libs/common/src/platform/state/state-definitions.ts @@ -0,0 +1,3 @@ +import { StateDefinition } from "./state-definition"; + +export const ACCOUNT_MEMORY = new StateDefinition("account", "memory");