From 3c1e39b0fb29e9bd0312fec52f156b62e6c5dcdd Mon Sep 17 00:00:00 2001 From: Todd Martin <106564991+trmartin4@users.noreply.github.com> Date: Tue, 6 Jan 2026 15:24:03 -0500 Subject: [PATCH] feat(tokens): [BEEEP] Refresh access token on 401 API response * Update to handle 401 to refresh token. * Updated to revert changes to extract token comparison. * Fixed tests * Adjusted tests. * Removed debug logging * Test updates * Added race condition test. * Added clarified logout reason * Fixed typo Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> * Fixed tests * Fixed extra space * Removed extra logout reasons to be introduced later. * Added warning on 401 and retry --------- Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> --- .../src/common/types/logout-reason.type.ts | 14 +- libs/common/src/abstractions/api.service.ts | 7 + libs/common/src/services/api.service.spec.ts | 794 ++++++++++++++++++ libs/common/src/services/api.service.ts | 220 +++-- 4 files changed, 961 insertions(+), 74 deletions(-) diff --git a/libs/auth/src/common/types/logout-reason.type.ts b/libs/auth/src/common/types/logout-reason.type.ts index 71fff51064a..dab19ca9418 100644 --- a/libs/auth/src/common/types/logout-reason.type.ts +++ b/libs/auth/src/common/types/logout-reason.type.ts @@ -1,10 +1,10 @@ export type LogoutReason = - | "invalidGrantError" - | "vaultTimeout" - | "invalidSecurityStamp" - | "logoutNotification" - | "keyConnectorError" - | "sessionExpired" | "accessTokenUnableToBeDecrypted" + | "accountDeleted" + | "invalidAccessToken" + | "invalidSecurityStamp" + | "keyConnectorError" + | "logoutNotification" | "refreshTokenSecureStorageRetrievalFailure" - | "accountDeleted"; + | "sessionExpired" + | "vaultTimeout"; diff --git a/libs/common/src/abstractions/api.service.ts b/libs/common/src/abstractions/api.service.ts index 72a17f0fa87..7e4ff031ef2 100644 --- a/libs/common/src/abstractions/api.service.ts +++ b/libs/common/src/abstractions/api.service.ts @@ -446,6 +446,13 @@ export abstract class ApiService { abstract postBitPayInvoice(request: BitPayInvoiceRequest): Promise; abstract postSetupPayment(): Promise; + /** + * Retrieves the bearer access token for the user. + * If the access token is expired or within 5 minutes of expiration, attempts to refresh the token + * and persists the refresh token to state before returning it. + * @param userId The user for whom we're retrieving the access token + * @returns The access token, or an Error if no access token exists. + */ abstract getActiveBearerToken(userId: UserId): Promise; abstract fetch(request: Request): Promise; abstract nativeFetch(request: Request): Promise; diff --git a/libs/common/src/services/api.service.spec.ts b/libs/common/src/services/api.service.spec.ts index 9ab84ecb16b..faed9ff77a7 100644 --- a/libs/common/src/services/api.service.spec.ts +++ b/libs/common/src/services/api.service.spec.ts @@ -449,4 +449,798 @@ describe("ApiService", () => { ).rejects.toThrow(InsecureUrlNotAllowedError); expect(nativeFetch).not.toHaveBeenCalled(); }); + + describe("When a 401 Unauthorized status is received", () => { + it("retries request with refreshed token when initial request with access token returns 401", async () => { + // This test verifies the 401 retry flow: + // 1. Initial request with valid token returns 401 (token expired server-side) + // 2. After 401, buildRequest is called again, which checks tokenNeedsRefresh + // 3. tokenNeedsRefresh returns true, triggering refreshToken via getActiveBearerToken + // 4. refreshToken makes an HTTP call to /connect/token to get new tokens + // 5. setTokens is called to store the new tokens, returning the refreshed access token + // 6. Request is retried with the refreshed token and succeeds + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("access_token"); + // First call (initial request): token doesn't need refresh yet + // Subsequent calls (after 401): token needs refresh, triggering the refresh flow + tokenService.tokenNeedsRefresh + .calledWith(testActiveUser) + .mockResolvedValueOnce(false) + .mockResolvedValue(true); + + tokenService.getRefreshToken.calledWith(testActiveUser).mockResolvedValue("refresh_token"); + + tokenService.decodeAccessToken + .calledWith(testActiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith("new_access_token") + .mockResolvedValue({ sub: testActiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + "new_access_token", + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + "new_refresh_token", + ) + .mockResolvedValue({ accessToken: "new_access_token" }); + + const nativeFetch = jest.fn, [request: Request]>(); + let callCount = 0; + + nativeFetch.mockImplementation((request) => { + callCount++; + + // First call: initial request with expired token returns 401 + if (callCount === 1) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + // Second call: token refresh request + if (callCount === 2 && request.url.includes("identity")) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: "new_access_token", + token_type: "Bearer", + refresh_token: "new_refresh_token", + }), + } satisfies Partial as Response); + } + + // Third call: retry with refreshed token succeeds + if (callCount === 3) { + expect(request.headers.get("Authorization")).toBe("Bearer new_access_token"); + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve({ data: "success" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + throw new Error(`Unexpected call #${callCount}: ${request.method} ${request.url}`); + }); + + sut.nativeFetch = nativeFetch; + + const response = await sut.send("GET", "/something", null, true, true, null, null); + + expect(nativeFetch).toHaveBeenCalledTimes(3); + expect(response).toEqual({ data: "success" }); + }); + + it("does not retry when request has no access token and returns 401", async () => { + environmentService.environment$ = of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, false, true, null, null), + ).rejects.toMatchObject({ message: "Unauthorized" }); + + // Should only be called once (no retry) + expect(nativeFetch).toHaveBeenCalledTimes(1); + }); + + it("does not retry when request returns non-401 error", async () => { + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("valid_token"); + tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 400, + json: () => Promise.resolve({ message: "Bad Request" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, true, true, null, null), + ).rejects.toMatchObject({ message: "Bad Request" }); + + // Should only be called once (no retry for non-401 errors) + expect(nativeFetch).toHaveBeenCalledTimes(1); + }); + + it("does not attempt to log out unauthenticated user", async () => { + environmentService.environment$ = of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, false, true, null, null), + ).rejects.toMatchObject({ message: "Unauthorized" }); + + expect(logoutCallback).not.toHaveBeenCalled(); + }); + + it("does not retry when hasResponse is false", async () => { + environmentService.environment$ = of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment); + + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token"); + tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + // When hasResponse is false, the method should throw even though no retry happens + await expect( + async () => await sut.send("POST", "/something", null, true, false, null, null), + ).rejects.toMatchObject({ message: "Unauthorized" }); + + // Should only be called once (no retry when hasResponse is false) + expect(nativeFetch).toHaveBeenCalledTimes(1); + }); + + it("uses original user token for retry even if active user changes between requests", async () => { + // Setup: Initial request is for testActiveUser, but during the retry, the active user switches + // to testInactiveUser. The retry should still use testActiveUser's refreshed token. + + let activeUserId = testActiveUser; + + // Mock accountService to return different active users based on when it's called + accountService.activeAccount$ = of({ + id: activeUserId, + email: "user1@example.com", + emailVerified: true, + name: "Test Name", + } satisfies ObservedValueOf); + + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + environmentService.getEnvironment$.calledWith(testInactiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://inactive.example.com", + getIdentityUrl: () => "https://identity.inactive.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken + .calledWith(testActiveUser) + .mockResolvedValue("active_access_token"); + tokenService.tokenNeedsRefresh + .calledWith(testActiveUser) + .mockResolvedValueOnce(false) + .mockResolvedValue(true); + + tokenService.getRefreshToken + .calledWith(testActiveUser) + .mockResolvedValue("active_refresh_token"); + + tokenService.decodeAccessToken + .calledWith(testActiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith("active_new_access_token") + .mockResolvedValue({ sub: testActiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + "active_new_access_token", + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + "active_new_refresh_token", + ) + .mockResolvedValue({ accessToken: "active_new_access_token" }); + + // Mock tokens for inactive user (should NOT be used) + tokenService.getAccessToken + .calledWith(testInactiveUser) + .mockResolvedValue("inactive_access_token"); + + const nativeFetch = jest.fn, [request: Request]>(); + let callCount = 0; + + nativeFetch.mockImplementation((request) => { + callCount++; + + // First call: initial request with active user's token returns 401 + if (callCount === 1) { + expect(request.url).toBe("https://example.com/something"); + expect(request.headers.get("Authorization")).toBe("Bearer active_access_token"); + + // After the 401, simulate active user changing + activeUserId = testInactiveUser; + accountService.activeAccount$ = of({ + id: testInactiveUser, + email: "user2@example.com", + emailVerified: true, + name: "Inactive User", + } satisfies ObservedValueOf); + + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + // Second call: token refresh request for ORIGINAL user (testActiveUser) + if (callCount === 2 && request.url.includes("identity")) { + expect(request.url).toContain("identity.example.com"); + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: "active_new_access_token", + token_type: "Bearer", + refresh_token: "active_new_refresh_token", + }), + } satisfies Partial as Response); + } + + // Third call: retry with ORIGINAL user's refreshed token, NOT the new active user's token + if (callCount === 3) { + expect(request.url).toBe("https://example.com/something"); + expect(request.headers.get("Authorization")).toBe("Bearer active_new_access_token"); + // Verify we're NOT using the inactive user's endpoint + expect(request.url).not.toContain("inactive"); + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve({ data: "success with original user" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + throw new Error(`Unexpected call #${callCount}: ${request.method} ${request.url}`); + }); + + sut.nativeFetch = nativeFetch; + + // Explicitly pass testActiveUser to ensure the request is for that specific user + const response = await sut.send("GET", "/something", null, testActiveUser, true, null, null); + + expect(nativeFetch).toHaveBeenCalledTimes(3); + expect(response).toEqual({ data: "success with original user" }); + + // Verify that inactive user's token was never requested + expect(tokenService.getAccessToken.calledWith(testInactiveUser)).not.toHaveBeenCalled(); + }); + + it("throws error when retry also returns 401", async () => { + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("access_token"); + // First call (initial request): token doesn't need refresh yet + // Subsequent calls (after 401): token needs refresh, triggering the refresh flow + tokenService.tokenNeedsRefresh + .calledWith(testActiveUser) + .mockResolvedValueOnce(false) + .mockResolvedValue(true); + + tokenService.getRefreshToken.calledWith(testActiveUser).mockResolvedValue("refresh_token"); + + tokenService.decodeAccessToken + .calledWith(testActiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith("new_access_token") + .mockResolvedValue({ sub: testActiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + "new_access_token", + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + "new_refresh_token", + ) + .mockResolvedValue({ accessToken: "new_access_token" }); + + const nativeFetch = jest.fn, [request: Request]>(); + let callCount = 0; + + nativeFetch.mockImplementation((request) => { + callCount++; + + // First call: initial request with expired token returns 401 + if (callCount === 1) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + // Second call: token refresh request + if (callCount === 2 && request.url.includes("identity")) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: "new_access_token", + token_type: "Bearer", + refresh_token: "new_refresh_token", + }), + } satisfies Partial as Response); + } + + // Third call: retry with refreshed token still returns 401 (user no longer has permission) + if (callCount === 3) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Still Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + throw new Error("Unexpected call"); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, true, true, null, null), + ).rejects.toMatchObject({ message: "Still Unauthorized" }); + + expect(nativeFetch).toHaveBeenCalledTimes(3); + expect(logoutCallback).toHaveBeenCalledWith("invalidAccessToken"); + }); + + it("handles concurrent requests that both receive 401 and share token refresh", async () => { + // This test verifies the race condition scenario: + // 1. Request A starts with valid token + // 2. Request B starts with valid token + // 3. Request A gets 401, triggers refresh + // 4. Request B gets 401 while A is refreshing + // 5. Request B should wait for A's refresh to complete (via refreshTokenPromise cache) + // 6. Both requests retry with the new token + + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token"); + + // First two calls: token doesn't need refresh yet + // Subsequent calls: token needs refresh + tokenService.tokenNeedsRefresh + .calledWith(testActiveUser) + .mockResolvedValueOnce(false) // Request A initial + .mockResolvedValueOnce(false) // Request B initial + .mockResolvedValue(true); // Both retries after 401 + + tokenService.getRefreshToken.calledWith(testActiveUser).mockResolvedValue("refresh_token"); + + tokenService.decodeAccessToken + .calledWith(testActiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith("new_access_token") + .mockResolvedValue({ sub: testActiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + "new_access_token", + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + "new_refresh_token", + ) + .mockResolvedValue({ accessToken: "new_access_token" }); + + const nativeFetch = jest.fn, [request: Request]>(); + let apiRequestCount = 0; + let refreshRequestCount = 0; + + nativeFetch.mockImplementation((request) => { + if (request.url.includes("identity")) { + refreshRequestCount++; + // Simulate slow token refresh to expose race condition + return new Promise((resolve) => + setTimeout( + () => + resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: "new_access_token", + token_type: "Bearer", + refresh_token: "new_refresh_token", + }), + } satisfies Partial as Response), + 100, + ), + ); + } + + apiRequestCount++; + const currentCall = apiRequestCount; + + // First two calls (Request A and B initial attempts): both return 401 + if (currentCall === 1 || currentCall === 2) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + // Third and fourth calls (retries after refresh): both succeed + if (currentCall === 3 || currentCall === 4) { + expect(request.headers.get("Authorization")).toBe("Bearer new_access_token"); + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve({ data: `success-${currentCall}` }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + throw new Error(`Unexpected API call #${currentCall}: ${request.method} ${request.url}`); + }); + + sut.nativeFetch = nativeFetch; + + // Make two concurrent requests + const [responseA, responseB] = await Promise.all([ + sut.send("GET", "/endpoint-a", null, testActiveUser, true, null, null), + sut.send("GET", "/endpoint-b", null, testActiveUser, true, null, null), + ]); + + // Both requests should succeed + expect(responseA).toMatchObject({ data: expect.stringContaining("success") }); + expect(responseB).toMatchObject({ data: expect.stringContaining("success") }); + + // Verify only ONE token refresh was made (they shared the refresh) + expect(refreshRequestCount).toBe(1); + + // Verify the total number of API requests: 2 initial + 2 retries = 4 + expect(apiRequestCount).toBe(4); + + // Verify setTokens was only called once + expect(tokenService.setTokens).toHaveBeenCalledTimes(1); + }); + }); + + describe("When 403 Forbidden response is received from API request", () => { + it("logs out the authenticated user", async () => { + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("valid_token"); + tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 403, + json: () => Promise.resolve({ message: "Forbidden" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, true, true, null, null), + ).rejects.toMatchObject({ message: "Forbidden" }); + + expect(logoutCallback).toHaveBeenCalledWith("invalidAccessToken"); + }); + + it("does not attempt to log out unauthenticated user", async () => { + environmentService.environment$ = of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 403, + json: () => Promise.resolve({ message: "Forbidden" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, false, true, null, null), + ).rejects.toMatchObject({ message: "Forbidden" }); + + expect(logoutCallback).not.toHaveBeenCalled(); + }); + }); }); diff --git a/libs/common/src/services/api.service.ts b/libs/common/src/services/api.service.ts index c60f6c5e907..10f349fbec7 100644 --- a/libs/common/src/services/api.service.ts +++ b/libs/common/src/services/api.service.ts @@ -74,7 +74,7 @@ import { BillingHistoryResponse } from "../billing/models/response/billing-histo import { PaymentResponse } from "../billing/models/response/payment.response"; import { PlanResponse } from "../billing/models/response/plan.response"; import { SubscriptionResponse } from "../billing/models/response/subscription.response"; -import { ClientType, DeviceType } from "../enums"; +import { ClientType, DeviceType, HttpStatusCode } from "../enums"; import { KeyConnectorUserKeyRequest } from "../key-management/key-connector/models/key-connector-user-key.request"; import { SetKeyConnectorKeyRequest } from "../key-management/key-connector/models/set-key-connector-key.request"; import { VaultTimeoutSettingsService } from "../key-management/vault-timeout"; @@ -1252,8 +1252,8 @@ export class ApiService implements ApiServiceAbstraction { }), ); - if (response.status !== 200) { - const error = await this.handleError(response, false, true); + if (response.status !== HttpStatusCode.Ok) { + const error = await this.handleApiRequestError(response, true); return Promise.reject(error); } @@ -1283,8 +1283,8 @@ export class ApiService implements ApiServiceAbstraction { }), ); - if (response.status !== 200) { - const error = await this.handleError(response, false, true); + if (response.status !== HttpStatusCode.Ok) { + const error = await this.handleApiRequestError(response, true); return Promise.reject(error); } } @@ -1301,14 +1301,12 @@ export class ApiService implements ApiServiceAbstraction { }), ); - if (response.status !== 200) { - const error = await this.handleError(response, false, true); + if (response.status !== HttpStatusCode.Ok) { + const error = await this.handleApiRequestError(response, true); return Promise.reject(error); } } - // Helpers - async getActiveBearerToken(userId: UserId): Promise { let accessToken = await this.tokenService.getAccessToken(userId); if (await this.tokenService.tokenNeedsRefresh(userId)) { @@ -1370,7 +1368,7 @@ export class ApiService implements ApiServiceAbstraction { const body = await response.json(); return new SsoPreValidateResponse(body); } else { - const error = await this.handleError(response, false, true); + const error = await this.handleApiRequestError(response, false); return Promise.reject(error); } } @@ -1525,7 +1523,7 @@ export class ApiService implements ApiServiceAbstraction { ); return refreshedTokens.accessToken; } else { - const error = await this.handleError(response, true, true); + const error = await this.handleTokenRefreshRequestError(response); return Promise.reject(error); } } @@ -1580,6 +1578,89 @@ export class ApiService implements ApiServiceAbstraction { apiUrl?: string | null, alterHeaders?: (headers: Headers) => void, ): Promise { + // We assume that if there is a UserId making the request, it is also an authenticated + // request and we will attempt to add an access token to the request. + const userIdMakingRequest = await this.getUserIdMakingRequest(authedOrUserId); + + const environment = await firstValueFrom( + userIdMakingRequest == null + ? this.environmentService.environment$ + : this.environmentService.getEnvironment$(userIdMakingRequest), + ); + apiUrl = Utils.isNullOrWhitespace(apiUrl) ? environment.getApiUrl() : apiUrl; + + const requestUrl = await this.buildSafeApiRequestUrl(apiUrl, path); + + let request = await this.buildRequest( + method, + userIdMakingRequest, + environment, + hasResponse, + body, + alterHeaders, + ); + + let response = await this.fetch(this.httpOperations.createRequest(requestUrl, request)); + + // First, check to see if we were making an authenticated request and received an Unauthorized (401) + // response. This could mean that we attempted to make a request with an expired access token. + // If so, attempt to refresh the token and try again. + if ( + hasResponse && + userIdMakingRequest != null && + response.status === HttpStatusCode.Unauthorized + ) { + this.logService.warning( + "Unauthorized response received for request to " + path + ". Attempting request again.", + ); + request = await this.buildRequest( + method, + userIdMakingRequest, + environment, + hasResponse, + body, + alterHeaders, + ); + response = await this.fetch(this.httpOperations.createRequest(requestUrl, request)); + } + + // At this point we are processing either the initial response or the response for the retry with the refreshed + // access token. + const responseType = response.headers.get("content-type"); + const responseIsJson = responseType != null && responseType.indexOf("application/json") !== -1; + const responseIsCsv = responseType != null && responseType.indexOf("text/csv") !== -1; + if (hasResponse && response.status === HttpStatusCode.Ok && responseIsJson) { + const responseJson = await response.json(); + return responseJson; + } else if (hasResponse && response.status === HttpStatusCode.Ok && responseIsCsv) { + return await response.text(); + } else if ( + response.status !== HttpStatusCode.Ok && + response.status !== HttpStatusCode.NoContent + ) { + const error = await this.handleApiRequestError(response, userIdMakingRequest != null); + return Promise.reject(error); + } + } + + private buildSafeApiRequestUrl(apiUrl: string, path: string): string { + const pathParts = path.split("?"); + + // Check for path traversal patterns from any URL. + const fullUrlPath = apiUrl + pathParts[0] + (pathParts.length > 1 ? `?${pathParts[1]}` : ""); + + const isInvalidUrl = Utils.invalidUrlPatterns(fullUrlPath); + if (isInvalidUrl) { + throw new Error("The request URL contains dangerous patterns."); + } + + const requestUrl = + apiUrl + Utils.normalizePath(pathParts[0]) + (pathParts.length > 1 ? `?${pathParts[1]}` : ""); + + return requestUrl; + } + + private async getUserIdMakingRequest(authedOrUserId: UserId | boolean): Promise { if (authedOrUserId == null) { throw new Error("A user id was given but it was null, cannot complete API request."); } @@ -1591,29 +1672,19 @@ export class ApiService implements ApiServiceAbstraction { } else if (typeof authedOrUserId === "string") { userId = authedOrUserId; } + return userId; + } - const env = await firstValueFrom( - userId == null - ? this.environmentService.environment$ - : this.environmentService.getEnvironment$(userId), - ); - apiUrl = Utils.isNullOrWhitespace(apiUrl) ? env.getApiUrl() : apiUrl; - - const pathParts = path.split("?"); - // Check for path traversal patterns from any URL. - const fullUrlPath = apiUrl + pathParts[0] + (pathParts.length > 1 ? `?${pathParts[1]}` : ""); - - const isInvalidUrl = Utils.invalidUrlPatterns(fullUrlPath); - if (isInvalidUrl) { - throw new Error("The request URL contains dangerous patterns."); - } - - // Prevent directory traversal from malicious paths - const requestUrl = - apiUrl + Utils.normalizePath(pathParts[0]) + (pathParts.length > 1 ? `?${pathParts[1]}` : ""); - + private async buildRequest( + method: "GET" | "POST" | "PUT" | "DELETE" | "PATCH", + userForAccessToken: UserId | null, + environment: Environment, + hasResponse: boolean, + body: string, + alterHeaders?: (headers: Headers) => void, + ): Promise { const [requestHeaders, requestBody] = await this.buildHeadersAndBody( - userId, + userForAccessToken, hasResponse, body, alterHeaders, @@ -1621,29 +1692,17 @@ export class ApiService implements ApiServiceAbstraction { const requestInit: RequestInit = { cache: "no-store", - credentials: await this.getCredentials(env), + credentials: await this.getCredentials(environment), method: method, }; requestInit.headers = requestHeaders; requestInit.body = requestBody; - const response = await this.fetch(this.httpOperations.createRequest(requestUrl, requestInit)); - const responseType = response.headers.get("content-type"); - const responseIsJson = responseType != null && responseType.indexOf("application/json") !== -1; - const responseIsCsv = responseType != null && responseType.indexOf("text/csv") !== -1; - if (hasResponse && response.status === 200 && responseIsJson) { - const responseJson = await response.json(); - return responseJson; - } else if (hasResponse && response.status === 200 && responseIsCsv) { - return await response.text(); - } else if (response.status !== 200 && response.status !== 204) { - const error = await this.handleError(response, false, userId != null); - return Promise.reject(error); - } + return requestInit; } private async buildHeadersAndBody( - userToAuthenticate: UserId | null, + userForAccessToken: UserId | null, hasResponse: boolean, body: any, alterHeaders: (headers: Headers) => void, @@ -1665,8 +1724,8 @@ export class ApiService implements ApiServiceAbstraction { if (alterHeaders != null) { alterHeaders(headers); } - if (userToAuthenticate != null) { - const authHeader = await this.getActiveBearerToken(userToAuthenticate); + if (userForAccessToken != null) { + const authHeader = await this.getActiveBearerToken(userForAccessToken); headers.set("Authorization", "Bearer " + authHeader); } else { // For unauthenticated requests, we need to tell the server what the device is for flag targeting, @@ -1692,32 +1751,59 @@ export class ApiService implements ApiServiceAbstraction { return [headers, requestBody]; } - private async handleError( + /** + * Handle an error response from a request to the Bitwarden API. + * If the request is made with an access token (aka the user is authenticated), + * and we receive a 401 or 403 response, we will log the user out, as this indicates + * that the access token used on the request is either expired or does not have the appropriate permissions. + * It is unlikely that it is expired, as we attempt to refresh the token on initial failure. + * @param response The response from the API request + * @param userIsAuthenticated A boolean indicating whether this is an authenticated request. + * @returns An ErrorResponse with a message based on the response status. + */ + private async handleApiRequestError( response: Response, - tokenError: boolean, - authed: boolean, + userIsAuthenticated: boolean, ): Promise { + if ( + userIsAuthenticated && + (response.status === HttpStatusCode.Unauthorized || + response.status === HttpStatusCode.Forbidden) + ) { + await this.logoutCallback("invalidAccessToken"); + } + + const responseJson = await this.getJsonResponse(response); + return new ErrorResponse(responseJson, response.status); + } + + /** + * Handle an error response when trying to refresh an access token. + * If the error indicates that the user's session has expired, it will log the user out. + * @param response The response from the token refresh request. + * @returns An ErrorResponse with a message based on the response status. + */ + private async handleTokenRefreshRequestError(response: Response): Promise { + const responseJson = await this.getJsonResponse(response); + + // IdentityServer will return an invalid_grant response if the refresh token has expired. + // This means that the user's session has expired, and they need to log out. + // We issue the logoutCallback() to log the user out through messaging. + if (response.status === HttpStatusCode.BadRequest && responseJson?.error === "invalid_grant") { + await this.logoutCallback("sessionExpired"); + } + + return new ErrorResponse(responseJson, response.status, true); + } + + private async getJsonResponse(response: Response): Promise { let responseJson: any = null; if (this.isJsonResponse(response)) { responseJson = await response.json(); } else if (this.isTextPlainResponse(response)) { responseJson = { Message: await response.text() }; } - - if (authed) { - if ( - response.status === 401 || - response.status === 403 || - (tokenError && - response.status === 400 && - responseJson != null && - responseJson.error === "invalid_grant") - ) { - await this.logoutCallback("invalidGrantError"); - } - } - - return new ErrorResponse(responseJson, response.status, tokenError); + return responseJson; } private qsStringify(params: any): string {