diff --git a/apps/browser/src/background/main.background.ts b/apps/browser/src/background/main.background.ts index 75481dde8cf..35704fc289f 100644 --- a/apps/browser/src/background/main.background.ts +++ b/apps/browser/src/background/main.background.ts @@ -737,6 +737,7 @@ export default class MainBackground { this.logService, (logoutReason: LogoutReason, userId?: UserId) => this.logout(logoutReason, userId), this.vaultTimeoutSettingsService, + this.accountService, { createRequest: (url, request) => new Request(url, request) }, ); @@ -843,7 +844,7 @@ export default class MainBackground { this.tokenService, ); - this.configApiService = new ConfigApiService(this.apiService, this.tokenService); + this.configApiService = new ConfigApiService(this.apiService); this.configService = new DefaultConfigService( this.configApiService, diff --git a/apps/cli/src/platform/services/node-api.service.ts b/apps/cli/src/platform/services/node-api.service.ts index d695272364b..e6527ed3abd 100644 --- a/apps/cli/src/platform/services/node-api.service.ts +++ b/apps/cli/src/platform/services/node-api.service.ts @@ -4,6 +4,7 @@ import * as FormData from "form-data"; import { HttpsProxyAgent } from "https-proxy-agent"; import * as fe from "node-fetch"; +import { AccountService } from "@bitwarden/common/auth/abstractions/account.service"; import { TokenService } from "@bitwarden/common/auth/abstractions/token.service"; import { VaultTimeoutSettingsService } from "@bitwarden/common/key-management/vault-timeout"; import { AppIdService } from "@bitwarden/common/platform/abstractions/app-id.service"; @@ -28,6 +29,7 @@ export class NodeApiService extends ApiService { logService: LogService, logoutCallback: () => Promise, vaultTimeoutSettingsService: VaultTimeoutSettingsService, + accountService: AccountService, customUserAgent: string = null, ) { super( @@ -39,6 +41,7 @@ export class NodeApiService extends ApiService { logService, logoutCallback, vaultTimeoutSettingsService, + accountService, { createRequest: (url, request) => new Request(url, request) }, customUserAgent, ); diff --git a/apps/cli/src/service-container/service-container.ts b/apps/cli/src/service-container/service-container.ts index 508ade4650e..27fde5863de 100644 --- a/apps/cli/src/service-container/service-container.ts +++ b/apps/cli/src/service-container/service-container.ts @@ -504,12 +504,13 @@ export class ServiceContainer { this.logService, logoutCallback, this.vaultTimeoutSettingsService, + this.accountService, customUserAgent, ); this.containerService = new ContainerService(this.keyService, this.encryptService); - this.configApiService = new ConfigApiService(this.apiService, this.tokenService); + this.configApiService = new ConfigApiService(this.apiService); this.authService = new AuthService( this.accountService, diff --git a/libs/angular/src/services/jslib-services.module.ts b/libs/angular/src/services/jslib-services.module.ts index c6f8d4a3ae9..72bdd9f8b2f 100644 --- a/libs/angular/src/services/jslib-services.module.ts +++ b/libs/angular/src/services/jslib-services.module.ts @@ -752,6 +752,7 @@ const safeProviders: SafeProvider[] = [ LogService, LOGOUT_CALLBACK, VaultTimeoutSettingsService, + AccountService, HTTP_OPERATIONS, ], }), @@ -1158,7 +1159,7 @@ const safeProviders: SafeProvider[] = [ safeProvider({ provide: ConfigApiServiceAbstraction, useClass: ConfigApiService, - deps: [ApiServiceAbstraction, TokenServiceAbstraction], + deps: [ApiServiceAbstraction], }), safeProvider({ provide: AnonymousHubServiceAbstraction, diff --git a/libs/common/src/abstractions/api.service.ts b/libs/common/src/abstractions/api.service.ts index 726b04534ad..ab217c56fc4 100644 --- a/libs/common/src/abstractions/api.service.ts +++ b/libs/common/src/abstractions/api.service.ts @@ -127,11 +127,34 @@ import { OptionalCipherResponse } from "../vault/models/response/optional-cipher * of this decision please read https://contributing.bitwarden.com/architecture/adr/refactor-api-service. */ export abstract class ApiService { + /** @deprecated Use the overload accepting the user you want the request authenticated for. */ abstract send( method: "GET" | "POST" | "PUT" | "DELETE" | "PATCH", path: string, body: any, - authed: boolean, + authed: true, + hasResponse: boolean, + apiUrl?: string | null, + alterHeaders?: (header: Headers) => void, + ): Promise; + + /** Sends an unauthenticated API request. */ + abstract send( + method: "GET" | "POST" | "PUT" | "DELETE" | "PATCH", + path: string, + body: any, + authed: false, + hasResponse: boolean, + apiUrl?: string | null, + alterHeaders?: (header: Headers) => void, + ): Promise; + + /** Sends an API request authenticated with the given users ID. */ + abstract send( + method: "GET" | "POST" | "PUT" | "DELETE" | "PATCH", + path: string, + body: any, + userId: UserId, hasResponse: boolean, apiUrl?: string | null, alterHeaders?: (headers: Headers) => void, @@ -499,7 +522,7 @@ export abstract class ApiService { abstract postBitPayInvoice(request: BitPayInvoiceRequest): Promise; abstract postSetupPayment(): Promise; - abstract getActiveBearerToken(): Promise; + abstract getActiveBearerToken(userId: UserId): Promise; abstract fetch(request: Request): Promise; abstract nativeFetch(request: Request): Promise; diff --git a/libs/common/src/auth/abstractions/token.service.ts b/libs/common/src/auth/abstractions/token.service.ts index 2139f32fca2..673bc7bdf0a 100644 --- a/libs/common/src/auth/abstractions/token.service.ts +++ b/libs/common/src/auth/abstractions/token.service.ts @@ -72,14 +72,14 @@ export abstract class TokenService { * @param userId - The optional user id to get the access token for; if not provided, the active user is used. * @returns A promise that resolves with the access token or null. */ - abstract getAccessToken(userId?: UserId): Promise; + abstract getAccessToken(userId: UserId): Promise; /** * Gets the refresh token. * @param userId - The optional user id to get the refresh token for; if not provided, the active user is used. * @returns A promise that resolves with the refresh token or null. */ - abstract getRefreshToken(userId?: UserId): Promise; + abstract getRefreshToken(userId: UserId): Promise; /** * Sets the API Key Client ID for the active user id in memory or disk based on the given vaultTimeoutAction and vaultTimeout. @@ -96,10 +96,10 @@ export abstract class TokenService { ): Promise; /** - * Gets the API Key Client ID for the active user. + * Gets the API Key Client ID for the given user. * @returns A promise that resolves with the API Key Client ID or undefined */ - abstract getClientId(userId?: UserId): Promise; + abstract getClientId(userId: UserId): Promise; /** * Sets the API Key Client Secret for the active user id in memory or disk based on the given vaultTimeoutAction and vaultTimeout. @@ -116,10 +116,10 @@ export abstract class TokenService { ): Promise; /** - * Gets the API Key Client Secret for the active user. + * Gets the API Key Client Secret for the given user. * @returns A promise that resolves with the API Key Client Secret or undefined */ - abstract getClientSecret(userId?: UserId): Promise; + abstract getClientSecret(userId: UserId): Promise; /** * Sets the two factor token for the given email in global state. @@ -157,7 +157,7 @@ export abstract class TokenService { * Gets the expiration date for the access token. Returns if token can't be decoded or has no expiration * @returns A promise that resolves with the expiration date for the access token. */ - abstract getTokenExpirationDate(): Promise; + abstract getTokenExpirationDate(userId: UserId): Promise; /** * Calculates the adjusted time in seconds until the access token expires, considering an optional offset. @@ -168,14 +168,14 @@ export abstract class TokenService { * based on the actual expiration. * @returns {Promise} Promise resolving to the adjusted seconds remaining. */ - abstract tokenSecondsRemaining(offsetSeconds?: number): Promise; + abstract tokenSecondsRemaining(userId: UserId, offsetSeconds?: number): Promise; /** * Checks if the access token needs to be refreshed. * @param {number} [minutes=5] - Optional number of minutes before the access token expires to consider refreshing it. * @returns A promise that resolves with a boolean indicating if the access token needs to be refreshed. */ - abstract tokenNeedsRefresh(minutes?: number): Promise; + abstract tokenNeedsRefresh(userId: UserId, minutes?: number): Promise; /** * Gets the user id for the active user from the access token. diff --git a/libs/common/src/auth/services/token.service.spec.ts b/libs/common/src/auth/services/token.service.spec.ts index 7274954c950..f4e4ec5e204 100644 --- a/libs/common/src/auth/services/token.service.spec.ts +++ b/libs/common/src/auth/services/token.service.spec.ts @@ -409,28 +409,8 @@ describe("TokenService", () => { }); describe("getAccessToken", () => { - it("returns null when no user id is provided and there is no active user in global state", async () => { - // Act - const result = await tokenService.getAccessToken(); - // Assert - expect(result).toBeNull(); - }); - - it("returns null when no access token is found in memory, disk, or secure storage", async () => { - // Arrange - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getAccessToken(); - // Assert - expect(result).toBeNull(); - }); - describe("Memory storage tests", () => { - test.each([ - ["gets the access token from memory when a user id is provided ", userIdFromAccessToken], - ["gets the access token from memory when no user id is provided", undefined], - ])("%s", async (_, userId) => { + it("gets the access token from memory when a user id is provided ", async () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) @@ -442,12 +422,10 @@ describe("TokenService", () => { .nextState(undefined); // Need to have global active id set to the user id - if (!userId) { - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - } + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act - const result = await tokenService.getAccessToken(userId); + const result = await tokenService.getAccessToken(userIdFromAccessToken); // Assert expect(result).toEqual(accessTokenJwt); @@ -455,10 +433,7 @@ describe("TokenService", () => { }); describe("Disk storage tests (secure storage not supported on platform)", () => { - test.each([ - ["gets the access token from disk when the user id is specified", userIdFromAccessToken], - ["gets the access token from disk when no user id is specified", undefined], - ])("%s", async (_, userId) => { + it("gets the access token from disk when the user id is specified", async () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) @@ -469,12 +444,10 @@ describe("TokenService", () => { .nextState(accessTokenJwt); // Need to have global active id set to the user id - if (!userId) { - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - } + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act - const result = await tokenService.getAccessToken(userId); + const result = await tokenService.getAccessToken(userIdFromAccessToken); // Assert expect(result).toEqual(accessTokenJwt); }); @@ -486,16 +459,7 @@ describe("TokenService", () => { tokenService = createTokenService(supportsSecureStorage); }); - test.each([ - [ - "gets the encrypted access token from disk, decrypts it, and returns it when a user id is provided", - userIdFromAccessToken, - ], - [ - "gets the encrypted access token from disk, decrypts it, and returns it when no user id is provided", - undefined, - ], - ])("%s", async (_, userId) => { + it("gets the encrypted access token from disk, decrypts it, and returns it when a user id is provided", async () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) @@ -509,27 +473,17 @@ describe("TokenService", () => { encryptService.decryptString.mockResolvedValue("decryptedAccessToken"); // Need to have global active id set to the user id - if (!userId) { - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - } + + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act - const result = await tokenService.getAccessToken(userId); + const result = await tokenService.getAccessToken(userIdFromAccessToken); // Assert expect(result).toEqual("decryptedAccessToken"); }); - test.each([ - [ - "falls back and gets the unencrypted access token from disk when there isn't an access token key in secure storage and a user id is provided", - userIdFromAccessToken, - ], - [ - "falls back and gets the unencrypted access token from disk when there isn't an access token key in secure storage and no user id is provided", - undefined, - ], - ])("%s", async (_, userId) => { + it("falls back and gets the unencrypted access token from disk when there isn't an access token key in secure storage and a user id is provided", async () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) @@ -540,14 +494,12 @@ describe("TokenService", () => { .nextState(accessTokenJwt); // Need to have global active id set to the user id - if (!userId) { - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - } + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // No access token key set // Act - const result = await tokenService.getAccessToken(userId); + const result = await tokenService.getAccessToken(userIdFromAccessToken); // Assert expect(result).toEqual(accessTokenJwt); @@ -738,7 +690,7 @@ describe("TokenService", () => { // Act // note: don't await here because we want to test the error - const result = tokenService.getTokenExpirationDate(); + const result = tokenService.getTokenExpirationDate(userIdFromAccessToken); // Assert await expect(result).rejects.toThrow("Failed to decode access token: Mock error"); }); @@ -748,7 +700,7 @@ describe("TokenService", () => { tokenService.decodeAccessToken = jest.fn().mockResolvedValue(null); // Act - const result = await tokenService.getTokenExpirationDate(); + const result = await tokenService.getTokenExpirationDate(userIdFromAccessToken); // Assert expect(result).toBeNull(); @@ -763,7 +715,7 @@ describe("TokenService", () => { .mockResolvedValue(accessTokenDecodedWithoutExp); // Act - const result = await tokenService.getTokenExpirationDate(); + const result = await tokenService.getTokenExpirationDate(userIdFromAccessToken); // Assert expect(result).toBeNull(); @@ -777,7 +729,7 @@ describe("TokenService", () => { .mockResolvedValue(accessTokenDecodedWithNonNumericExp); // Act - const result = await tokenService.getTokenExpirationDate(); + const result = await tokenService.getTokenExpirationDate(userIdFromAccessToken); // Assert expect(result).toBeNull(); @@ -788,7 +740,7 @@ describe("TokenService", () => { tokenService.decodeAccessToken = jest.fn().mockResolvedValue(accessTokenDecoded); // Act - const result = await tokenService.getTokenExpirationDate(); + const result = await tokenService.getTokenExpirationDate(userIdFromAccessToken); // Assert expect(result).toEqual(new Date(accessTokenDecoded.exp * 1000)); @@ -801,7 +753,7 @@ describe("TokenService", () => { tokenService.getTokenExpirationDate = jest.fn().mockResolvedValue(null); // Act - const result = await tokenService.tokenSecondsRemaining(); + const result = await tokenService.tokenSecondsRemaining(userIdFromAccessToken); // Assert expect(result).toEqual(0); @@ -823,7 +775,7 @@ describe("TokenService", () => { tokenService.getTokenExpirationDate = jest.fn().mockResolvedValue(expirationDate); // Act - const result = await tokenService.tokenSecondsRemaining(); + const result = await tokenService.tokenSecondsRemaining(userIdFromAccessToken); // Assert expect(result).toEqual(expectedSecondsRemaining); @@ -849,7 +801,10 @@ describe("TokenService", () => { tokenService.getTokenExpirationDate = jest.fn().mockResolvedValue(expirationDate); // Act - const result = await tokenService.tokenSecondsRemaining(offsetSeconds); + const result = await tokenService.tokenSecondsRemaining( + userIdFromAccessToken, + offsetSeconds, + ); // Assert expect(result).toEqual(expectedSecondsRemaining); @@ -866,7 +821,7 @@ describe("TokenService", () => { tokenService.tokenSecondsRemaining = jest.fn().mockResolvedValue(tokenSecondsRemaining); // Act - const result = await tokenService.tokenNeedsRefresh(); + const result = await tokenService.tokenNeedsRefresh(userIdFromAccessToken); // Assert expect(result).toEqual(true); @@ -878,7 +833,7 @@ describe("TokenService", () => { tokenService.tokenSecondsRemaining = jest.fn().mockResolvedValue(tokenSecondsRemaining); // Act - const result = await tokenService.tokenNeedsRefresh(); + const result = await tokenService.tokenNeedsRefresh(userIdFromAccessToken); // Assert expect(result).toEqual(false); @@ -890,7 +845,7 @@ describe("TokenService", () => { tokenService.tokenSecondsRemaining = jest.fn().mockResolvedValue(tokenSecondsRemaining); // Act - const result = await tokenService.tokenNeedsRefresh(2); + const result = await tokenService.tokenNeedsRefresh(userIdFromAccessToken, 2); // Assert expect(result).toEqual(true); @@ -902,7 +857,7 @@ describe("TokenService", () => { tokenService.tokenSecondsRemaining = jest.fn().mockResolvedValue(tokenSecondsRemaining); // Act - const result = await tokenService.tokenNeedsRefresh(5); + const result = await tokenService.tokenNeedsRefresh(userIdFromAccessToken, 5); // Assert expect(result).toEqual(false); @@ -1565,26 +1520,6 @@ describe("TokenService", () => { }); describe("Memory storage tests", () => { - it("gets the refresh token from memory when no user id is specified (uses global active user)", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .nextState(refreshToken); - - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .nextState(undefined); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getRefreshToken(); - - // Assert - expect(result).toEqual(refreshToken); - }); - it("gets the refresh token from memory when a user id is specified", async () => { // Arrange singleUserStateProvider @@ -1603,25 +1538,6 @@ describe("TokenService", () => { }); describe("Disk storage tests (secure storage not supported on platform)", () => { - it("gets the refresh token from disk when no user id is specified", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .nextState(undefined); - - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .nextState(refreshToken); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getRefreshToken(); - // Assert - expect(result).toEqual(refreshToken); - }); - it("gets the refresh token from disk when a user id is specified", async () => { // Arrange singleUserStateProvider @@ -1645,27 +1561,6 @@ describe("TokenService", () => { tokenService = createTokenService(supportsSecureStorage); }); - it("gets the refresh token from secure storage when no user id is specified", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .nextState(undefined); - - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .nextState(undefined); - - secureStorageService.get.mockResolvedValue(refreshToken); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getRefreshToken(); - // Assert - expect(result).toEqual(refreshToken); - }); - it("gets the refresh token from secure storage when a user id is specified", async () => { // Arrange @@ -1705,29 +1600,6 @@ describe("TokenService", () => { expect(secureStorageService.get).not.toHaveBeenCalled(); }); - it("falls back and gets the refresh token from disk when no user id is specified even if the platform supports secure storage", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .nextState(undefined); - - singleUserStateProvider - .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .nextState(refreshToken); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getRefreshToken(); - - // Assert - expect(result).toEqual(refreshToken); - - // assert that secure storage was not called - expect(secureStorageService.get).not.toHaveBeenCalled(); - }); - it("returns null when the refresh token is not found in memory, on disk, or in secure storage", async () => { // Arrange secureStorageService.get.mockResolvedValue(null); @@ -1944,45 +1816,7 @@ describe("TokenService", () => { }); describe("getClientId", () => { - it("returns undefined when no user id is provided and there is no active user in global state", async () => { - // Act - const result = await tokenService.getClientId(); - // Assert - expect(result).toBeUndefined(); - }); - - it("returns null when no client id is found in memory or disk", async () => { - // Arrange - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getClientId(); - // Assert - expect(result).toBeNull(); - }); - describe("Memory storage tests", () => { - it("gets the client id from memory when no user id is specified (uses global active user)", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .nextState(clientId); - - // set disk to undefined - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .nextState(undefined); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getClientId(); - - // Assert - expect(result).toEqual(clientId); - }); - it("gets the client id from memory when given a user id", async () => { // Arrange singleUserStateProvider @@ -2002,25 +1836,6 @@ describe("TokenService", () => { }); describe("Disk storage tests", () => { - it("gets the client id from disk when no user id is specified", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .nextState(undefined); - - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .nextState(clientId); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getClientId(); - // Assert - expect(result).toEqual(clientId); - }); - it("gets the client id from disk when a user id is specified", async () => { // Arrange singleUserStateProvider @@ -2215,45 +2030,17 @@ describe("TokenService", () => { }); describe("getClientSecret", () => { - it("returns undefined when no user id is provided and there is no active user in global state", async () => { - // Act - const result = await tokenService.getClientSecret(); - // Assert - expect(result).toBeUndefined(); - }); - it("returns null when no client secret is found in memory or disk", async () => { // Arrange globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act - const result = await tokenService.getClientSecret(); + const result = await tokenService.getClientSecret(userIdFromAccessToken); // Assert expect(result).toBeNull(); }); describe("Memory storage tests", () => { - it("gets the client secret from memory when no user id is specified (uses global active user)", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .nextState(clientSecret); - - // set disk to undefined - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .nextState(undefined); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getClientSecret(); - - // Assert - expect(result).toEqual(clientSecret); - }); - it("gets the client secret from memory when a user id is specified", async () => { // Arrange singleUserStateProvider @@ -2273,25 +2060,6 @@ describe("TokenService", () => { }); describe("Disk storage tests", () => { - it("gets the client secret from disk when no user id specified", async () => { - // Arrange - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .nextState(undefined); - - singleUserStateProvider - .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .nextState(clientSecret); - - // Need to have global active id set to the user id - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); - - // Act - const result = await tokenService.getClientSecret(); - // Assert - expect(result).toEqual(clientSecret); - }); - it("gets the client secret from disk when a user id is specified", async () => { // Arrange singleUserStateProvider diff --git a/libs/common/src/auth/services/token.service.ts b/libs/common/src/auth/services/token.service.ts index 21ccd672056..0721927bd13 100644 --- a/libs/common/src/auth/services/token.service.ts +++ b/libs/common/src/auth/services/token.service.ts @@ -452,9 +452,7 @@ export class TokenService implements TokenServiceAbstraction { await this.singleUserStateProvider.get(userId, ACCESS_TOKEN_MEMORY).update((_) => null); } - async getAccessToken(userId?: UserId): Promise { - userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); - + async getAccessToken(userId: UserId): Promise { if (!userId) { return null; } @@ -631,9 +629,7 @@ export class TokenService implements TokenServiceAbstraction { } } - async getRefreshToken(userId?: UserId): Promise { - userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); - + async getRefreshToken(userId: UserId): Promise { if (!userId) { return null; } @@ -746,9 +742,7 @@ export class TokenService implements TokenServiceAbstraction { } } - async getClientId(userId?: UserId): Promise { - userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); - + async getClientId(userId: UserId): Promise { if (!userId) { return undefined; } @@ -822,9 +816,7 @@ export class TokenService implements TokenServiceAbstraction { } } - async getClientSecret(userId?: UserId): Promise { - userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); - + async getClientSecret(userId: UserId): Promise { if (!userId) { return undefined; } @@ -915,7 +907,9 @@ export class TokenService implements TokenServiceAbstraction { if (Utils.isGuid(tokenOrUserId)) { token = await this.getAccessToken(tokenOrUserId as UserId); } else { - token ??= await this.getAccessToken(); + token ??= await this.getAccessToken( + await firstValueFrom(this.activeUserIdGlobalState.state$), + ); } if (token == null) { @@ -928,10 +922,10 @@ export class TokenService implements TokenServiceAbstraction { // TODO: PM-6678- tech debt - consider consolidating the return types of all these access // token data retrieval methods to return null if something goes wrong instead of throwing an error. - async getTokenExpirationDate(): Promise { + async getTokenExpirationDate(userId: UserId): Promise { let decoded: DecodedAccessToken; try { - decoded = await this.decodeAccessToken(); + decoded = await this.decodeAccessToken(userId); } catch (error) { throw new Error("Failed to decode access token: " + error.message); } @@ -947,8 +941,8 @@ export class TokenService implements TokenServiceAbstraction { return expirationDate; } - async tokenSecondsRemaining(offsetSeconds = 0): Promise { - const date = await this.getTokenExpirationDate(); + async tokenSecondsRemaining(userId: UserId, offsetSeconds = 0): Promise { + const date = await this.getTokenExpirationDate(userId); if (date == null) { return 0; } @@ -957,8 +951,8 @@ export class TokenService implements TokenServiceAbstraction { return Math.round(msRemaining / 1000); } - async tokenNeedsRefresh(minutes = 5): Promise { - const sRemaining = await this.tokenSecondsRemaining(); + async tokenNeedsRefresh(userId: UserId, minutes = 5): Promise { + const sRemaining = await this.tokenSecondsRemaining(userId); return sRemaining < 60 * minutes; } diff --git a/libs/common/src/key-management/vault-timeout/services/vault-timeout-settings.service.ts b/libs/common/src/key-management/vault-timeout/services/vault-timeout-settings.service.ts index 7e43ee394f6..e40b896dc8c 100644 --- a/libs/common/src/key-management/vault-timeout/services/vault-timeout-settings.service.ts +++ b/libs/common/src/key-management/vault-timeout/services/vault-timeout-settings.service.ts @@ -70,17 +70,17 @@ export class VaultTimeoutSettingsService implements VaultTimeoutSettingsServiceA // We swap these tokens from being on disk for lock actions, and in memory for logout actions // Get them here to set them to their new location after changing the timeout action and clearing if needed - const accessToken = await this.tokenService.getAccessToken(); - const refreshToken = await this.tokenService.getRefreshToken(); - const clientId = await this.tokenService.getClientId(); - const clientSecret = await this.tokenService.getClientSecret(); + const accessToken = await this.tokenService.getAccessToken(userId); + const refreshToken = await this.tokenService.getRefreshToken(userId); + const clientId = await this.tokenService.getClientId(userId); + const clientSecret = await this.tokenService.getClientSecret(userId); await this.setVaultTimeout(userId, timeout); if (timeout != VaultTimeoutStringType.Never && action === VaultTimeoutAction.LogOut) { // if we have a vault timeout and the action is log out, reset tokens // as the tokens were stored on disk and now should be stored in memory - await this.tokenService.clearTokens(); + await this.tokenService.clearTokens(userId); } await this.setVaultTimeoutAction(userId, action); diff --git a/libs/common/src/platform/server-notifications/internal/signalr-connection.service.ts b/libs/common/src/platform/server-notifications/internal/signalr-connection.service.ts index 58d6311c668..5998668f138 100644 --- a/libs/common/src/platform/server-notifications/internal/signalr-connection.service.ts +++ b/libs/common/src/platform/server-notifications/internal/signalr-connection.service.ts @@ -78,7 +78,7 @@ export class SignalRConnectionService { return new Observable((subsciber) => { const connection = this.hubConnectionBuilderFactory() .withUrl(notificationsUrl + "/hub", { - accessTokenFactory: () => this.apiService.getActiveBearerToken(), + accessTokenFactory: () => this.apiService.getActiveBearerToken(userId), skipNegotiation: true, transport: HttpTransportType.WebSockets, }) diff --git a/libs/common/src/platform/server-notifications/internal/web-push-notifications-api.service.ts b/libs/common/src/platform/server-notifications/internal/web-push-notifications-api.service.ts index 891dab2c069..861835c086d 100644 --- a/libs/common/src/platform/server-notifications/internal/web-push-notifications-api.service.ts +++ b/libs/common/src/platform/server-notifications/internal/web-push-notifications-api.service.ts @@ -1,3 +1,5 @@ +import { UserId } from "@bitwarden/user-core"; + import { ApiService } from "../../../abstractions/api.service"; import { AppIdService } from "../../abstractions/app-id.service"; @@ -12,13 +14,13 @@ export class WebPushNotificationsApiService { /** * Posts a device-user association to the server and ensures it's installed for push server notifications */ - async putSubscription(pushSubscription: PushSubscriptionJSON): Promise { + async putSubscription(pushSubscription: PushSubscriptionJSON, userId: UserId): Promise { const request = WebPushRequest.from(pushSubscription); await this.apiService.send( "POST", `/devices/identifier/${await this.appIdService.getAppId()}/web-push-auth`, request, - true, + userId, false, ); } diff --git a/libs/common/src/platform/server-notifications/internal/worker-webpush-connection.service.ts b/libs/common/src/platform/server-notifications/internal/worker-webpush-connection.service.ts index d8a2c33568e..8b38ebd5b17 100644 --- a/libs/common/src/platform/server-notifications/internal/worker-webpush-connection.service.ts +++ b/libs/common/src/platform/server-notifications/internal/worker-webpush-connection.service.ts @@ -143,7 +143,7 @@ class MyWebPushConnector implements WebPushConnector { await subscriptionUsersState.update(() => subscriptionUsers); // Inform the server about the new subscription-user association - await this.webPushApiService.putSubscription(subscription.toJSON()); + await this.webPushApiService.putSubscription(subscription.toJSON(), this.userId); }), switchMap(() => this.pushEvent$), map((e) => { diff --git a/libs/common/src/platform/services/config/config-api.service.ts b/libs/common/src/platform/services/config/config-api.service.ts index b7ecb9c8712..752a0075346 100644 --- a/libs/common/src/platform/services/config/config-api.service.ts +++ b/libs/common/src/platform/services/config/config-api.service.ts @@ -1,22 +1,21 @@ import { ApiService } from "../../../abstractions/api.service"; -import { TokenService } from "../../../auth/abstractions/token.service"; import { UserId } from "../../../types/guid"; import { ConfigApiServiceAbstraction } from "../../abstractions/config/config-api.service.abstraction"; import { ServerConfigResponse } from "../../models/response/server-config.response"; export class ConfigApiService implements ConfigApiServiceAbstraction { - constructor( - private apiService: ApiService, - private tokenService: TokenService, - ) {} + constructor(private apiService: ApiService) {} async get(userId: UserId | null): Promise { // Authentication adds extra context to config responses, if the user has an access token, we want to use it // We don't particularly care about ensuring the token is valid and not expired, just that it exists - const authed: boolean = - userId == null ? false : (await this.tokenService.getAccessToken(userId)) != null; + let r: any; + if (userId == null) { + r = await this.apiService.send("GET", "/config", null, false, true); + } else { + r = await this.apiService.send("GET", "/config", null, userId, true); + } - const r = await this.apiService.send("GET", "/config", null, authed, true); return new ServerConfigResponse(r); } } diff --git a/libs/common/src/services/api.service.spec.ts b/libs/common/src/services/api.service.spec.ts index fffe0478254..144b0cc02c4 100644 --- a/libs/common/src/services/api.service.spec.ts +++ b/libs/common/src/services/api.service.spec.ts @@ -1,13 +1,19 @@ import { mock, MockProxy } from "jest-mock-extended"; -import { of } from "rxjs"; +import { ObservedValueOf, of } from "rxjs"; // This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop. // eslint-disable-next-line no-restricted-imports import { LogoutReason } from "@bitwarden/auth/common"; +import { UserId } from "@bitwarden/user-core"; +import { AccountService } from "../auth/abstractions/account.service"; import { TokenService } from "../auth/abstractions/token.service"; import { DeviceType } from "../enums"; -import { VaultTimeoutSettingsService } from "../key-management/vault-timeout"; +import { + VaultTimeoutAction, + VaultTimeoutSettingsService, + VaultTimeoutStringType, +} from "../key-management/vault-timeout"; import { ErrorResponse } from "../models/response/error.response"; import { AppIdService } from "../platform/abstractions/app-id.service"; import { Environment, EnvironmentService } from "../platform/abstractions/environment.service"; @@ -25,10 +31,14 @@ describe("ApiService", () => { let logService: MockProxy; let logoutCallback: jest.Mock, [reason: LogoutReason]>; let vaultTimeoutSettingsService: MockProxy; + let accountService: MockProxy; let httpOperations: MockProxy; let sut: ApiService; + const testActiveUser = "activeUser" as UserId; + const testInactiveUser = "inactiveUser" as UserId; + beforeEach(() => { tokenService = mock(); platformUtilsService = mock(); @@ -40,6 +50,15 @@ describe("ApiService", () => { logService = mock(); logoutCallback = jest.fn(); vaultTimeoutSettingsService = mock(); + accountService = mock(); + + accountService.activeAccount$ = of({ + id: testActiveUser, + email: "user1@example.com", + emailVerified: true, + name: "Test Name", + } satisfies ObservedValueOf); + httpOperations = mock(); sut = new ApiService( @@ -51,6 +70,7 @@ describe("ApiService", () => { logService, logoutCallback, vaultTimeoutSettingsService, + accountService, httpOperations, "custom-user-agent", ); @@ -62,6 +82,12 @@ describe("ApiService", () => { getApiUrl: () => "https://example.com", } satisfies Partial as Environment); + environmentService.getEnvironment$.mockReturnValue( + of({ + getApiUrl: () => "https://authed.example.com", + } satisfies Partial as Environment), + ); + httpOperations.createRequest.mockImplementation((url, request) => { return { url: url, @@ -96,6 +122,7 @@ describe("ApiService", () => { expect(nativeFetch).toHaveBeenCalledTimes(1); const request = nativeFetch.mock.calls[0][0]; + expect(request.url).toBe("https://authed.example.com/something"); // This should get set for users of send expect(request.cache).toBe("no-store"); // TODO: Could expect on the credentials parameter @@ -109,6 +136,185 @@ describe("ApiService", () => { // The response body expect(response).toEqual({ hello: "world" }); }); + + it("authenticates with non-active user when user is passed in", async () => { + environmentService.environment$ = of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment); + + environmentService.getEnvironment$.calledWith(testInactiveUser).mockReturnValueOnce( + of({ + getApiUrl: () => "https://inactive.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken + .calledWith(testInactiveUser) + .mockResolvedValue("inactive_access_token"); + + tokenService.tokenNeedsRefresh.calledWith(testInactiveUser).mockResolvedValue(false); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve({ hello: "world" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + const response = await sut.send( + "GET", + "/something", + null, + testInactiveUser, + true, + null, + null, + ); + + expect(nativeFetch).toHaveBeenCalledTimes(1); + const request = nativeFetch.mock.calls[0][0]; + expect(request.url).toBe("https://inactive.example.com/something"); + // This should get set for users of send + expect(request.cache).toBe("no-store"); + // TODO: Could expect on the credentials parameter + expect(request.headers.get("Device-Type")).toBe("2"); // Chrome Extension + // Custom user agent should get set + expect(request.headers.get("User-Agent")).toBe("custom-user-agent"); + // This should be set when the caller has indicated there is a response + expect(request.headers.get("Accept")).toBe("application/json"); + // If they have indicated that it's authed, then the authorization header should get set. + expect(request.headers.get("Authorization")).toBe("Bearer inactive_access_token"); + // The response body + expect(response).toEqual({ hello: "world" }); + }); + + const cases: { + name: string; + authedOrUserId: boolean | UserId; + expectedEffectiveUser: UserId; + }[] = [ + { + name: "refreshes active user when true passed in for auth", + authedOrUserId: true, + expectedEffectiveUser: testActiveUser, + }, + { + name: "refreshes acess token when the user passed in happens to be the active one", + authedOrUserId: testActiveUser, + expectedEffectiveUser: testActiveUser, + }, + { + name: "refreshes access token when the user passed in happens to be inactive", + authedOrUserId: testInactiveUser, + expectedEffectiveUser: testInactiveUser, + }, + ]; + + it.each(cases)("$name does", async ({ authedOrUserId, expectedEffectiveUser }) => { + environmentService.getEnvironment$.calledWith(expectedEffectiveUser).mockReturnValue( + of({ + getApiUrl: () => `https://${expectedEffectiveUser}.example.com`, + getIdentityUrl: () => `https://${expectedEffectiveUser}.identity.example.com`, + } satisfies Partial as Environment), + ); + + tokenService.getAccessToken + .calledWith(expectedEffectiveUser) + .mockResolvedValue(`${expectedEffectiveUser}_access_token`); + + tokenService.tokenNeedsRefresh.calledWith(expectedEffectiveUser).mockResolvedValue(true); + + tokenService.getRefreshToken + .calledWith(expectedEffectiveUser) + .mockResolvedValue(`${expectedEffectiveUser}_refresh_token`); + + tokenService.decodeAccessToken + .calledWith(expectedEffectiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith(`${expectedEffectiveUser}_new_access_token`) + .mockResolvedValue({ sub: expectedEffectiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(expectedEffectiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(expectedEffectiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + `${expectedEffectiveUser}_new_access_token`, + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + `${expectedEffectiveUser}_new_refresh_token`, + ) + .mockResolvedValue({ accessToken: `${expectedEffectiveUser}_refreshed_access_token` }); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + if (request.url.includes("identity")) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: `${expectedEffectiveUser}_new_access_token`, + refresh_token: `${expectedEffectiveUser}_new_refresh_token`, + }), + } satisfies Partial as Response); + } + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve({ hello: "world" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await sut.send("GET", "/something", null, authedOrUserId, true, null, null); + + expect(nativeFetch).toHaveBeenCalledTimes(2); + }); }); const errorData: { @@ -169,9 +375,11 @@ describe("ApiService", () => { it.each(errorData)( "throws error-like response when not ok response with $name", async ({ input, error }) => { - environmentService.environment$ = of({ - getApiUrl: () => "https://example.com", - } satisfies Partial as Environment); + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment), + ); httpOperations.createRequest.mockImplementation((url, request) => { return { diff --git a/libs/common/src/services/api.service.ts b/libs/common/src/services/api.service.ts index 6a670368b1f..bbf990122df 100644 --- a/libs/common/src/services/api.service.ts +++ b/libs/common/src/services/api.service.ts @@ -1,6 +1,6 @@ // FIXME: Update this file to be type safe and remove this and next line // @ts-strict-ignore -import { firstValueFrom } from "rxjs"; +import { firstValueFrom, map } from "rxjs"; // This import has been flagged as unallowed for this class. It may be involved in a circular dependency loop. // eslint-disable-next-line no-restricted-imports @@ -47,6 +47,7 @@ import { ProviderUserUserDetailsResponse, } from "../admin-console/models/response/provider/provider-user.response"; import { SelectionReadOnlyResponse } from "../admin-console/models/response/selection-read-only.response"; +import { AccountService } from "../auth/abstractions/account.service"; import { TokenService } from "../auth/abstractions/token.service"; import { DeviceVerificationRequest } from "../auth/models/request/device-verification.request"; import { DisableTwoFactorAuthenticatorRequest } from "../auth/models/request/disable-two-factor-authenticator.request"; @@ -121,7 +122,7 @@ import { ListResponse } from "../models/response/list.response"; import { ProfileResponse } from "../models/response/profile.response"; import { UserKeyResponse } from "../models/response/user-key.response"; import { AppIdService } from "../platform/abstractions/app-id.service"; -import { EnvironmentService } from "../platform/abstractions/environment.service"; +import { Environment, EnvironmentService } from "../platform/abstractions/environment.service"; import { LogService } from "../platform/abstractions/log.service"; import { PlatformUtilsService } from "../platform/abstractions/platform-utils.service"; import { flagEnabled } from "../platform/misc/flags"; @@ -155,7 +156,7 @@ export type HttpOperations = { export class ApiService implements ApiServiceAbstraction { private device: DeviceType; private deviceType: string; - private refreshTokenPromise: Promise | undefined; + private refreshTokenPromise: Record> = {}; /** * The message (responseJson.ErrorModel.Message) that comes back from the server when a new device verification is required. @@ -172,6 +173,7 @@ export class ApiService implements ApiServiceAbstraction { private logService: LogService, private logoutCallback: (logoutReason: LogoutReason) => Promise, private vaultTimeoutSettingsService: VaultTimeoutSettingsService, + private readonly accountService: AccountService, private readonly httpOperations: HttpOperations, private customUserAgent: string = null, ) { @@ -209,7 +211,7 @@ export class ApiService implements ApiServiceAbstraction { const response = await this.fetch( this.httpOperations.createRequest(env.getIdentityUrl() + "/connect/token", { body: this.qsStringify(identityToken), - credentials: await this.getCredentials(), + credentials: await this.getCredentials(env), cache: "no-store", headers: headers, method: "POST", @@ -241,9 +243,13 @@ export class ApiService implements ApiServiceAbstraction { return Promise.reject(new ErrorResponse(responseJson, response.status, true)); } - async refreshIdentityToken(): Promise { + async refreshIdentityToken(userId: UserId | null = null): Promise { + const normalizedUser = (userId ??= await this.getActiveUser()); + if (normalizedUser == null) { + throw new Error("No user provided and no active user, cannot refresh the identity token."); + } try { - await this.refreshToken(); + await this.refreshToken(normalizedUser); } catch (e) { this.logService.error("Error refreshing access token: ", e); throw e; @@ -1398,11 +1404,16 @@ export class ApiService implements ApiServiceAbstraction { if (this.customUserAgent != null) { headers.set("User-Agent", this.customUserAgent); } - const env = await firstValueFrom(this.environmentService.environment$); + + const env = await firstValueFrom( + userId == null + ? this.environmentService.environment$ + : this.environmentService.getEnvironment$(userId), + ); const response = await this.fetch( this.httpOperations.createRequest(env.getEventsUrl() + "/collect", { cache: "no-store", - credentials: await this.getCredentials(), + credentials: await this.getCredentials(env), method: "POST", body: JSON.stringify(request), headers: headers, @@ -1444,7 +1455,11 @@ export class ApiService implements ApiServiceAbstraction { async getMasterKeyFromKeyConnector( keyConnectorUrl: string, ): Promise { - const authHeader = await this.getActiveBearerToken(); + const activeUser = await this.getActiveUser(); + if (activeUser == null) { + throw new Error("No active user, cannot get master key from key connector."); + } + const authHeader = await this.getActiveBearerToken(activeUser); const response = await this.fetch( this.httpOperations.createRequest(keyConnectorUrl + "/user-keys", { @@ -1469,7 +1484,11 @@ export class ApiService implements ApiServiceAbstraction { keyConnectorUrl: string, request: KeyConnectorUserKeyRequest, ): Promise { - const authHeader = await this.getActiveBearerToken(); + const activeUser = await this.getActiveUser(); + if (activeUser == null) { + throw new Error("No active user, cannot post key to key connector."); + } + const authHeader = await this.getActiveBearerToken(activeUser); const response = await this.fetch( this.httpOperations.createRequest(keyConnectorUrl + "/user-keys", { @@ -1521,10 +1540,10 @@ export class ApiService implements ApiServiceAbstraction { // Helpers - async getActiveBearerToken(): Promise { - let accessToken = await this.tokenService.getAccessToken(); - if (await this.tokenService.tokenNeedsRefresh()) { - accessToken = await this.refreshToken(); + async getActiveBearerToken(userId: UserId): Promise { + let accessToken = await this.tokenService.getAccessToken(userId); + if (await this.tokenService.tokenNeedsRefresh(userId)) { + accessToken = await this.refreshToken(userId); } return accessToken; } @@ -1563,7 +1582,7 @@ export class ApiService implements ApiServiceAbstraction { const response = await this.fetch( this.httpOperations.createRequest(env.getIdentityUrl() + path, { cache: "no-store", - credentials: await this.getCredentials(), + credentials: await this.getCredentials(env), headers: headers, method: "GET", }), @@ -1646,26 +1665,27 @@ export class ApiService implements ApiServiceAbstraction { } // Keep the running refreshTokenPromise to prevent parallel calls. - protected refreshToken(): Promise { - if (this.refreshTokenPromise === undefined) { - this.refreshTokenPromise = this.internalRefreshToken(); - void this.refreshTokenPromise.finally(() => { - this.refreshTokenPromise = undefined; + protected refreshToken(userId: UserId): Promise { + if (this.refreshTokenPromise[userId] === undefined) { + // TODO: Have different promise for each user + this.refreshTokenPromise[userId] = this.internalRefreshToken(userId); + void this.refreshTokenPromise[userId].finally(() => { + delete this.refreshTokenPromise[userId]; }); } - return this.refreshTokenPromise; + return this.refreshTokenPromise[userId]; } - private async internalRefreshToken(): Promise { - const refreshToken = await this.tokenService.getRefreshToken(); + private async internalRefreshToken(userId: UserId): Promise { + const refreshToken = await this.tokenService.getRefreshToken(userId); if (refreshToken != null && refreshToken !== "") { - return this.refreshAccessToken(); + return await this.refreshAccessToken(userId); } - const clientId = await this.tokenService.getClientId(); - const clientSecret = await this.tokenService.getClientSecret(); + const clientId = await this.tokenService.getClientId(userId); + const clientSecret = await this.tokenService.getClientSecret(userId); if (!Utils.isNullOrWhitespace(clientId) && !Utils.isNullOrWhitespace(clientSecret)) { - return this.refreshApiToken(); + return await this.refreshApiToken(userId); } this.refreshAccessTokenErrorCallback(); @@ -1673,8 +1693,8 @@ export class ApiService implements ApiServiceAbstraction { throw new Error("Cannot refresh access token, no refresh token or api keys are stored."); } - protected async refreshAccessToken(): Promise { - const refreshToken = await this.tokenService.getRefreshToken(); + private async refreshAccessToken(userId: UserId): Promise { + const refreshToken = await this.tokenService.getRefreshToken(userId); if (refreshToken == null || refreshToken === "") { throw new Error(); } @@ -1687,8 +1707,8 @@ export class ApiService implements ApiServiceAbstraction { headers.set("User-Agent", this.customUserAgent); } - const env = await firstValueFrom(this.environmentService.environment$); - const decodedToken = await this.tokenService.decodeAccessToken(); + const env = await firstValueFrom(this.environmentService.getEnvironment$(userId)); + const decodedToken = await this.tokenService.decodeAccessToken(userId); const response = await this.fetch( this.httpOperations.createRequest(env.getIdentityUrl() + "/connect/token", { body: this.qsStringify({ @@ -1697,7 +1717,7 @@ export class ApiService implements ApiServiceAbstraction { refresh_token: refreshToken, }), cache: "no-store", - credentials: await this.getCredentials(), + credentials: await this.getCredentials(env), headers: headers, method: "POST", }), @@ -1732,9 +1752,9 @@ export class ApiService implements ApiServiceAbstraction { } } - protected async refreshApiToken(): Promise { - const clientId = await this.tokenService.getClientId(); - const clientSecret = await this.tokenService.getClientSecret(); + protected async refreshApiToken(userId: UserId): Promise { + const clientId = await this.tokenService.getClientId(userId); + const clientSecret = await this.tokenService.getClientSecret(userId); const appId = await this.appIdService.getAppId(); const deviceRequest = new DeviceRequest(appId, this.platformUtilsService); @@ -1751,7 +1771,12 @@ export class ApiService implements ApiServiceAbstraction { } const newDecodedAccessToken = await this.tokenService.decodeAccessToken(response.accessToken); - const userId = newDecodedAccessToken.sub; + + if (newDecodedAccessToken.sub !== userId) { + throw new Error( + `Token was supposed to be refreshed for ${userId} but the token we got back was for ${newDecodedAccessToken.sub}`, + ); + } const vaultTimeoutAction = await firstValueFrom( this.vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$(userId), @@ -1772,12 +1797,28 @@ export class ApiService implements ApiServiceAbstraction { method: "GET" | "POST" | "PUT" | "DELETE" | "PATCH", path: string, body: any, - authed: boolean, + authedOrUserId: UserId | boolean, hasResponse: boolean, apiUrl?: string | null, alterHeaders?: (headers: Headers) => void, ): Promise { - const env = await firstValueFrom(this.environmentService.environment$); + if (authedOrUserId == null) { + throw new Error("A user id was given but it was null, cannot complete API request."); + } + + let userId: UserId | null = null; + if (typeof authedOrUserId === "boolean" && authedOrUserId) { + // Backwards compatible for authenticating the active user when `true` is passed in + userId = await this.getActiveUser(); + } else if (typeof authedOrUserId === "string") { + userId = authedOrUserId; + } + + const env = await firstValueFrom( + userId == null + ? this.environmentService.environment$ + : this.environmentService.getEnvironment$(userId), + ); apiUrl = Utils.isNullOrWhitespace(apiUrl) ? env.getApiUrl() : apiUrl; // Prevent directory traversal from malicious paths @@ -1786,7 +1827,7 @@ export class ApiService implements ApiServiceAbstraction { apiUrl + Utils.normalizePath(pathParts[0]) + (pathParts.length > 1 ? `?${pathParts[1]}` : ""); const [requestHeaders, requestBody] = await this.buildHeadersAndBody( - authed, + userId, hasResponse, body, alterHeaders, @@ -1794,7 +1835,7 @@ export class ApiService implements ApiServiceAbstraction { const requestInit: RequestInit = { cache: "no-store", - credentials: await this.getCredentials(), + credentials: await this.getCredentials(env), method: method, }; requestInit.headers = requestHeaders; @@ -1810,13 +1851,13 @@ export class ApiService implements ApiServiceAbstraction { } else if (hasResponse && response.status === 200 && responseIsCsv) { return await response.text(); } else if (response.status !== 200 && response.status !== 204) { - const error = await this.handleError(response, false, authed); + const error = await this.handleError(response, false, userId != null); return Promise.reject(error); } } private async buildHeadersAndBody( - authed: boolean, + userToAuthenticate: UserId | null, hasResponse: boolean, body: any, alterHeaders: (headers: Headers) => void, @@ -1838,8 +1879,8 @@ export class ApiService implements ApiServiceAbstraction { if (alterHeaders != null) { alterHeaders(headers); } - if (authed) { - const authHeader = await this.getActiveBearerToken(); + if (userToAuthenticate != null) { + const authHeader = await this.getActiveBearerToken(userToAuthenticate); headers.set("Authorization", "Bearer " + authHeader); } else { // For unauthenticated requests, we need to tell the server what the device is for flag targeting, @@ -1901,8 +1942,11 @@ export class ApiService implements ApiServiceAbstraction { .join("&"); } - private async getCredentials(): Promise { - const env = await firstValueFrom(this.environmentService.environment$); + private async getActiveUser(): Promise { + return await firstValueFrom(this.accountService.activeAccount$.pipe(map((a) => a?.id))); + } + + private async getCredentials(env: Environment): Promise { if (this.platformUtilsService.getClientType() !== ClientType.Web || env.hasBaseUrl()) { return "include"; }