diff --git a/libs/common/src/services/api.service.spec.ts b/libs/common/src/services/api.service.spec.ts index 1fb8f86697f..d8a3d350da4 100644 --- a/libs/common/src/services/api.service.spec.ts +++ b/libs/common/src/services/api.service.spec.ts @@ -447,4 +447,464 @@ describe("ApiService", () => { ).rejects.toThrow(InsecureUrlNotAllowedError); expect(nativeFetch).not.toHaveBeenCalled(); }); + + it("retries request with refreshed token when initial request with access token returns 401", async () => { + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token"); + tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false); + + tokenService.getRefreshToken + .calledWith(testActiveUser) + .mockResolvedValue("valid_refresh_token"); + + tokenService.decodeAccessToken + .calledWith(testActiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith("new_access_token") + .mockResolvedValue({ sub: testActiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + "new_access_token", + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + "new_refresh_token", + ) + .mockResolvedValue({ accessToken: "refreshed_access_token" }); + + const nativeFetch = jest.fn, [request: Request]>(); + let callCount = 0; + + nativeFetch.mockImplementation((request) => { + callCount++; + + // First call: initial request with expired token returns 401 + if (callCount === 1) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + // Second call: token refresh request + if (callCount === 2 && request.url.includes("identity")) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: "new_access_token", + token_type: "Bearer", + refresh_token: "new_refresh_token", + }), + } satisfies Partial as Response); + } + + // Third call: retry with refreshed token succeeds + if (callCount === 3) { + expect(request.headers.get("Authorization")).toBe("Bearer refreshed_access_token"); + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve({ data: "success" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + throw new Error("Unexpected call"); + }); + + sut.nativeFetch = nativeFetch; + + const response = await sut.send("GET", "/something", null, true, true, null, null); + + expect(nativeFetch).toHaveBeenCalledTimes(3); + expect(response).toEqual({ data: "success" }); + }); + + it("does not retry when request has no access token and returns 401", async () => { + environmentService.environment$ = of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, false, true, null, null), + ).rejects.toMatchObject({ message: "Unauthorized" }); + + // Should only be called once (no retry) + expect(nativeFetch).toHaveBeenCalledTimes(1); + }); + + it("does not retry when request returns non-401 error", async () => { + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("valid_token"); + tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 400, + json: () => Promise.resolve({ message: "Bad Request" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, true, true, null, null), + ).rejects.toMatchObject({ message: "Bad Request" }); + + // Should only be called once (no retry for non-401 errors) + expect(nativeFetch).toHaveBeenCalledTimes(1); + }); + + it("does not retry when hasResponse is false", async () => { + environmentService.environment$ = of({ + getApiUrl: () => "https://example.com", + } satisfies Partial as Environment); + + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token"); + tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false); + + const nativeFetch = jest.fn, [request: Request]>(); + + nativeFetch.mockImplementation((request) => { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + }); + + sut.nativeFetch = nativeFetch; + + // When hasResponse is false, the method should throw even though no retry happens + await expect( + async () => await sut.send("POST", "/something", null, true, false, null, null), + ).rejects.toMatchObject({ message: "Unauthorized" }); + + // Should only be called once (no retry when hasResponse is false) + expect(nativeFetch).toHaveBeenCalledTimes(1); + }); + + it("throws error when retry also returns 401", async () => { + environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://example.com", + getIdentityUrl: () => "https://identity.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token"); + tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false); + + tokenService.getRefreshToken + .calledWith(testActiveUser) + .mockResolvedValue("valid_refresh_token"); + + tokenService.decodeAccessToken + .calledWith(testActiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith("new_access_token") + .mockResolvedValue({ sub: testActiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(testActiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + "new_access_token", + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + "new_refresh_token", + ) + .mockResolvedValue({ accessToken: "refreshed_access_token" }); + + const nativeFetch = jest.fn, [request: Request]>(); + let callCount = 0; + + nativeFetch.mockImplementation((request) => { + callCount++; + + // First call: initial request with expired token returns 401 + if (callCount === 1) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + // Second call: token refresh request + if (callCount === 2 && request.url.includes("identity")) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: "new_access_token", + token_type: "Bearer", + refresh_token: "new_refresh_token", + }), + } satisfies Partial as Response); + } + + // Third call: retry with refreshed token still returns 401 (user no longer has permission) + if (callCount === 3) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Still Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + throw new Error("Unexpected call"); + }); + + sut.nativeFetch = nativeFetch; + + await expect( + async () => await sut.send("GET", "/something", null, true, true, null, null), + ).rejects.toMatchObject({ message: "Still Unauthorized" }); + + expect(nativeFetch).toHaveBeenCalledTimes(3); + expect(logoutCallback).toHaveBeenCalledWith("sessionExpired"); + }); + + it("retries with refreshed token for inactive user when 401 received", async () => { + tokenService.getAccessToken + .calledWith(testInactiveUser) + .mockResolvedValue("inactive_expired_token"); + tokenService.tokenNeedsRefresh.calledWith(testInactiveUser).mockResolvedValue(false); + + tokenService.getRefreshToken + .calledWith(testInactiveUser) + .mockResolvedValue("inactive_refresh_token"); + + tokenService.decodeAccessToken + .calledWith(testInactiveUser) + .mockResolvedValue({ client_id: "web" }); + + tokenService.decodeAccessToken + .calledWith("inactive_new_access_token") + .mockResolvedValue({ sub: testInactiveUser }); + + vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$ + .calledWith(testInactiveUser) + .mockReturnValue(of(VaultTimeoutAction.Lock)); + + vaultTimeoutSettingsService.getVaultTimeoutByUserId$ + .calledWith(testInactiveUser) + .mockReturnValue(of(VaultTimeoutStringType.Never)); + + tokenService.setTokens + .calledWith( + "inactive_new_access_token", + VaultTimeoutAction.Lock, + VaultTimeoutStringType.Never, + "inactive_new_refresh_token", + ) + .mockResolvedValue({ accessToken: "inactive_refreshed_access_token" }); + + environmentService.getEnvironment$.calledWith(testInactiveUser).mockReturnValue( + of({ + getApiUrl: () => "https://inactive.example.com", + getIdentityUrl: () => "https://identity.inactive.example.com", + } satisfies Partial as Environment), + ); + + httpOperations.createRequest.mockImplementation((url, request) => { + return { + url: url, + cache: request.cache, + credentials: request.credentials, + method: request.method, + mode: request.mode, + signal: request.signal, + headers: new Headers(request.headers), + } satisfies Partial as unknown as Request; + }); + + const nativeFetch = jest.fn, [request: Request]>(); + let callCount = 0; + + nativeFetch.mockImplementation((request) => { + callCount++; + + if (callCount === 1) { + return Promise.resolve({ + ok: false, + status: 401, + json: () => Promise.resolve({ message: "Unauthorized" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + if (callCount === 2 && request.url.includes("identity")) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + access_token: "inactive_new_access_token", + token_type: "Bearer", + refresh_token: "inactive_new_refresh_token", + }), + } satisfies Partial as Response); + } + + if (callCount === 3) { + expect(request.headers.get("Authorization")).toBe("Bearer inactive_refreshed_access_token"); + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve({ data: "inactive user success" }), + headers: new Headers({ + "content-type": "application/json", + }), + } satisfies Partial as Response); + } + + throw new Error("Unexpected call"); + }); + + sut.nativeFetch = nativeFetch; + + const response = await sut.send("GET", "/something", null, testInactiveUser, true, null, null); + + expect(nativeFetch).toHaveBeenCalledTimes(3); + expect(response).toEqual({ data: "inactive user success" }); + }); }); diff --git a/libs/common/src/services/api.service.ts b/libs/common/src/services/api.service.ts index 8314e44e75f..7b99a685bbf 100644 --- a/libs/common/src/services/api.service.ts +++ b/libs/common/src/services/api.service.ts @@ -73,7 +73,7 @@ import { BillingHistoryResponse } from "../billing/models/response/billing-histo import { PaymentResponse } from "../billing/models/response/payment.response"; import { PlanResponse } from "../billing/models/response/plan.response"; import { SubscriptionResponse } from "../billing/models/response/subscription.response"; -import { ClientType, DeviceType } from "../enums"; +import { ClientType, DeviceType, HttpStatusCode } from "../enums"; import { KeyConnectorUserKeyRequest } from "../key-management/key-connector/models/key-connector-user-key.request"; import { SetKeyConnectorKeyRequest } from "../key-management/key-connector/models/set-key-connector-key.request"; import { VaultTimeoutSettingsService } from "../key-management/vault-timeout"; @@ -1246,8 +1246,8 @@ export class ApiService implements ApiServiceAbstraction { }), ); - if (response.status !== 200) { - const error = await this.handleError(response, false, true); + if (response.status !== HttpStatusCode.Ok) { + const error = await this.handleApiRequestError(response, true); return Promise.reject(error); } @@ -1277,8 +1277,8 @@ export class ApiService implements ApiServiceAbstraction { }), ); - if (response.status !== 200) { - const error = await this.handleError(response, false, true); + if (response.status !== HttpStatusCode.Ok) { + const error = await this.handleApiRequestError(response, true); return Promise.reject(error); } } @@ -1295,15 +1295,22 @@ export class ApiService implements ApiServiceAbstraction { }), ); - if (response.status !== 200) { - const error = await this.handleError(response, false, true); + if (response.status !== HttpStatusCode.Ok) { + const error = await this.handleApiRequestError(response, true); return Promise.reject(error); } } // Helpers - async getActiveBearerToken(userId: UserId): Promise { + /** + * Retrieves the bearer access token for the user, or `null` if no token exists. + * If the access token is expired or within 5 minutes of expiration, attemps to refresh the token + * and persists the refresh token to state before returning it. + * @param userId The user for whom we're retrieving the access token + * @returns The access token, or `null` if none exists for the `userId` provided. + */ + async getActiveBearerToken(userId: UserId): Promise { let accessToken = await this.tokenService.getAccessToken(userId); if (await this.tokenService.tokenNeedsRefresh(userId)) { accessToken = await this.refreshToken(userId); @@ -1359,7 +1366,7 @@ export class ApiService implements ApiServiceAbstraction { const body = await response.json(); return new SsoPreValidateResponse(body); } else { - const error = await this.handleError(response, false, true); + const error = await this.handleApiRequestError(response, false); return Promise.reject(error); } } @@ -1514,7 +1521,7 @@ export class ApiService implements ApiServiceAbstraction { ); return refreshedTokens.accessToken; } else { - const error = await this.handleError(response, true, true); + const error = await this.handleTokenRefreshRequestError(response); return Promise.reject(error); } } @@ -1569,6 +1576,68 @@ export class ApiService implements ApiServiceAbstraction { apiUrl?: string | null, alterHeaders?: (headers: Headers) => void, ): Promise { + const userId = await this.getUserIdForRequest(authedOrUserId); + + const environment = await firstValueFrom( + userId == null + ? this.environmentService.environment$ + : this.environmentService.getEnvironment$(userId), + ); + apiUrl = Utils.isNullOrWhitespace(apiUrl) ? environment.getApiUrl() : apiUrl; + + // Prevent directory traversal from malicious paths + const pathParts = path.split("?"); + const requestUrl = + apiUrl + Utils.normalizePath(pathParts[0]) + (pathParts.length > 1 ? `?${pathParts[1]}` : ""); + + const accessToken = await this.getActiveBearerToken(userId); + + let request = await this.buildRequest( + method, + accessToken, + environment, + hasResponse, + body, + alterHeaders, + ); + let response = await this.fetch(this.httpOperations.createRequest(requestUrl, request)); + + // First, check to see if we were making an authenticated request and received an Unauthorized (401) + // response. This could mean that we attempted to make a request with an expired access token. + // If so, attempt to refresh the token and try again. + if (hasResponse && accessToken != null && response.status === HttpStatusCode.Unauthorized) { + const refreshedToken = await this.refreshAccessToken(userId); + request = await this.buildRequest( + method, + refreshedToken, + environment, + hasResponse, + body, + alterHeaders, + ); + response = await this.fetch(this.httpOperations.createRequest(requestUrl, request)); + } + + // At this point we are processing either the initial response or the response for the retry with the refreshed + // access token. + const responseType = response.headers.get("content-type"); + const responseIsJson = responseType != null && responseType.indexOf("application/json") !== -1; + const responseIsCsv = responseType != null && responseType.indexOf("text/csv") !== -1; + if (hasResponse && response.status === HttpStatusCode.Ok && responseIsJson) { + const responseJson = await response.json(); + return responseJson; + } else if (hasResponse && response.status === HttpStatusCode.Ok && responseIsCsv) { + return await response.text(); + } else if ( + response.status !== HttpStatusCode.Ok && + response.status !== HttpStatusCode.NoContent + ) { + const error = await this.handleApiRequestError(response, userId != null); + return Promise.reject(error); + } + } + + private async getUserIdForRequest(authedOrUserId: UserId | boolean): Promise { if (authedOrUserId == null) { throw new Error("A user id was given but it was null, cannot complete API request."); } @@ -1580,21 +1649,19 @@ export class ApiService implements ApiServiceAbstraction { } else if (typeof authedOrUserId === "string") { userId = authedOrUserId; } + return userId; + } - const env = await firstValueFrom( - userId == null - ? this.environmentService.environment$ - : this.environmentService.getEnvironment$(userId), - ); - apiUrl = Utils.isNullOrWhitespace(apiUrl) ? env.getApiUrl() : apiUrl; - - // Prevent directory traversal from malicious paths - const pathParts = path.split("?"); - const requestUrl = - apiUrl + Utils.normalizePath(pathParts[0]) + (pathParts.length > 1 ? `?${pathParts[1]}` : ""); - + private async buildRequest( + method: "GET" | "POST" | "PUT" | "DELETE" | "PATCH", + accessToken: string | null, + environment: Environment, + hasResponse: boolean, + body: string, + alterHeaders?: (headers: Headers) => void, + ): Promise { const [requestHeaders, requestBody] = await this.buildHeadersAndBody( - userId, + accessToken, hasResponse, body, alterHeaders, @@ -1602,29 +1669,17 @@ export class ApiService implements ApiServiceAbstraction { const requestInit: RequestInit = { cache: "no-store", - credentials: await this.getCredentials(env), + credentials: await this.getCredentials(environment), method: method, }; requestInit.headers = requestHeaders; requestInit.body = requestBody; - const response = await this.fetch(this.httpOperations.createRequest(requestUrl, requestInit)); - const responseType = response.headers.get("content-type"); - const responseIsJson = responseType != null && responseType.indexOf("application/json") !== -1; - const responseIsCsv = responseType != null && responseType.indexOf("text/csv") !== -1; - if (hasResponse && response.status === 200 && responseIsJson) { - const responseJson = await response.json(); - return responseJson; - } else if (hasResponse && response.status === 200 && responseIsCsv) { - return await response.text(); - } else if (response.status !== 200 && response.status !== 204) { - const error = await this.handleError(response, false, userId != null); - return Promise.reject(error); - } + return requestInit; } private async buildHeadersAndBody( - userToAuthenticate: UserId | null, + bearerAuthenticationToken: string | null, hasResponse: boolean, body: any, alterHeaders: (headers: Headers) => void, @@ -1646,9 +1701,8 @@ export class ApiService implements ApiServiceAbstraction { if (alterHeaders != null) { alterHeaders(headers); } - if (userToAuthenticate != null) { - const authHeader = await this.getActiveBearerToken(userToAuthenticate); - headers.set("Authorization", "Bearer " + authHeader); + if (bearerAuthenticationToken != null) { + headers.set("Authorization", "Bearer " + bearerAuthenticationToken); } else { // For unauthenticated requests, we need to tell the server what the device is for flag targeting, // since it won't be able to get it from the access token. @@ -1673,32 +1727,66 @@ export class ApiService implements ApiServiceAbstraction { return [headers, requestBody]; } - private async handleError( + /** + * Handle an error response from a request to the Bitwarden API. + * If the request is made with an access token (aka the user is authenticated), + * and we receive a 401 or 403 response, we will log the user out, as this indicates + * that the access token used on the request is either expired or does not have the appropriate permissions. + * It is unlikely that it is expired, as we attempt to refresh the token on initial failure. + * @param response The response from the API request + * @param userIsAuthenticated A boolean indicating whether this is an authenticated request. + * @returns An ErrorResponse with a message based on the response status. + */ + private async handleApiRequestError( response: Response, - tokenError: boolean, - authed: boolean, + userIsAuthenticated: boolean, ): Promise { + const responseJson = await this.getJsonResponse(response); + + if ( + userIsAuthenticated && + (response.status === HttpStatusCode.Unauthorized || + response.status === HttpStatusCode.Forbidden) + ) { + await this.logoutCallback("sessionExpired"); + } + + return new ErrorResponse(responseJson, response.status); + } + + /** + * Handle an error response when trying to refresh an access token. + * If the error indicates that the user's session has expired, it will log the user out. + * @param response The response from the token refresh request. + * @returns An ErrorResponse with a message based on the response status. + */ + private async handleTokenRefreshRequestError(response: Response): Promise { + const responseJson = await this.getJsonResponse(response); + + // IdentityServer will return an invalid_grant response if the refresh token has expired. + // This means that the user's session has expired, and they need to log out. + // We issue the logoutCallback() to log the user out through messaging. + if ( + response.status === HttpStatusCode.Unauthorized || + response.status === HttpStatusCode.Forbidden || + (response.status === HttpStatusCode.BadRequest && + responseJson != null && + responseJson.error === "invalid_grant") + ) { + await this.logoutCallback("sessionExpired"); + } + + return new ErrorResponse(responseJson, response.status, true); + } + + private async getJsonResponse(response: Response): Promise { let responseJson: any = null; if (this.isJsonResponse(response)) { responseJson = await response.json(); } else if (this.isTextPlainResponse(response)) { responseJson = { Message: await response.text() }; } - - if (authed) { - if ( - response.status === 401 || - response.status === 403 || - (tokenError && - response.status === 400 && - responseJson != null && - responseJson.error === "invalid_grant") - ) { - await this.logoutCallback("invalidGrantError"); - } - } - - return new ErrorResponse(responseJson, response.status, tokenError); + return responseJson; } private qsStringify(params: any): string {