1
0
mirror of https://github.com/bitwarden/browser synced 2026-02-01 09:13:54 +00:00

Update to handle 401 to refresh token.

This commit is contained in:
Todd Martin
2025-11-15 11:31:31 -05:00
parent 9cd73b8738
commit 99c73075eb
2 changed files with 607 additions and 59 deletions

View File

@@ -447,4 +447,464 @@ describe("ApiService", () => {
).rejects.toThrow(InsecureUrlNotAllowedError);
expect(nativeFetch).not.toHaveBeenCalled();
});
it("retries request with refreshed token when initial request with access token returns 401", async () => {
environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue(
of({
getApiUrl: () => "https://example.com",
getIdentityUrl: () => "https://identity.example.com",
} satisfies Partial<Environment> as Environment),
);
httpOperations.createRequest.mockImplementation((url, request) => {
return {
url: url,
cache: request.cache,
credentials: request.credentials,
method: request.method,
mode: request.mode,
signal: request.signal,
headers: new Headers(request.headers),
} satisfies Partial<Request> as unknown as Request;
});
tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token");
tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false);
tokenService.getRefreshToken
.calledWith(testActiveUser)
.mockResolvedValue("valid_refresh_token");
tokenService.decodeAccessToken
.calledWith(testActiveUser)
.mockResolvedValue({ client_id: "web" });
tokenService.decodeAccessToken
.calledWith("new_access_token")
.mockResolvedValue({ sub: testActiveUser });
vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$
.calledWith(testActiveUser)
.mockReturnValue(of(VaultTimeoutAction.Lock));
vaultTimeoutSettingsService.getVaultTimeoutByUserId$
.calledWith(testActiveUser)
.mockReturnValue(of(VaultTimeoutStringType.Never));
tokenService.setTokens
.calledWith(
"new_access_token",
VaultTimeoutAction.Lock,
VaultTimeoutStringType.Never,
"new_refresh_token",
)
.mockResolvedValue({ accessToken: "refreshed_access_token" });
const nativeFetch = jest.fn<Promise<Response>, [request: Request]>();
let callCount = 0;
nativeFetch.mockImplementation((request) => {
callCount++;
// First call: initial request with expired token returns 401
if (callCount === 1) {
return Promise.resolve({
ok: false,
status: 401,
json: () => Promise.resolve({ message: "Unauthorized" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
}
// Second call: token refresh request
if (callCount === 2 && request.url.includes("identity")) {
return Promise.resolve({
ok: true,
status: 200,
json: () =>
Promise.resolve({
access_token: "new_access_token",
token_type: "Bearer",
refresh_token: "new_refresh_token",
}),
} satisfies Partial<Response> as Response);
}
// Third call: retry with refreshed token succeeds
if (callCount === 3) {
expect(request.headers.get("Authorization")).toBe("Bearer refreshed_access_token");
return Promise.resolve({
ok: true,
status: 200,
json: () => Promise.resolve({ data: "success" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
}
throw new Error("Unexpected call");
});
sut.nativeFetch = nativeFetch;
const response = await sut.send("GET", "/something", null, true, true, null, null);
expect(nativeFetch).toHaveBeenCalledTimes(3);
expect(response).toEqual({ data: "success" });
});
it("does not retry when request has no access token and returns 401", async () => {
environmentService.environment$ = of({
getApiUrl: () => "https://example.com",
} satisfies Partial<Environment> as Environment);
httpOperations.createRequest.mockImplementation((url, request) => {
return {
url: url,
cache: request.cache,
credentials: request.credentials,
method: request.method,
mode: request.mode,
signal: request.signal,
headers: new Headers(request.headers),
} satisfies Partial<Request> as unknown as Request;
});
const nativeFetch = jest.fn<Promise<Response>, [request: Request]>();
nativeFetch.mockImplementation((request) => {
return Promise.resolve({
ok: false,
status: 401,
json: () => Promise.resolve({ message: "Unauthorized" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
});
sut.nativeFetch = nativeFetch;
await expect(
async () => await sut.send("GET", "/something", null, false, true, null, null),
).rejects.toMatchObject({ message: "Unauthorized" });
// Should only be called once (no retry)
expect(nativeFetch).toHaveBeenCalledTimes(1);
});
it("does not retry when request returns non-401 error", async () => {
environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue(
of({
getApiUrl: () => "https://example.com",
getIdentityUrl: () => "https://identity.example.com",
} satisfies Partial<Environment> as Environment),
);
httpOperations.createRequest.mockImplementation((url, request) => {
return {
url: url,
cache: request.cache,
credentials: request.credentials,
method: request.method,
mode: request.mode,
signal: request.signal,
headers: new Headers(request.headers),
} satisfies Partial<Request> as unknown as Request;
});
tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("valid_token");
tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false);
const nativeFetch = jest.fn<Promise<Response>, [request: Request]>();
nativeFetch.mockImplementation((request) => {
return Promise.resolve({
ok: false,
status: 400,
json: () => Promise.resolve({ message: "Bad Request" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
});
sut.nativeFetch = nativeFetch;
await expect(
async () => await sut.send("GET", "/something", null, true, true, null, null),
).rejects.toMatchObject({ message: "Bad Request" });
// Should only be called once (no retry for non-401 errors)
expect(nativeFetch).toHaveBeenCalledTimes(1);
});
it("does not retry when hasResponse is false", async () => {
environmentService.environment$ = of({
getApiUrl: () => "https://example.com",
} satisfies Partial<Environment> as Environment);
environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue(
of({
getApiUrl: () => "https://example.com",
getIdentityUrl: () => "https://identity.example.com",
} satisfies Partial<Environment> as Environment),
);
httpOperations.createRequest.mockImplementation((url, request) => {
return {
url: url,
cache: request.cache,
credentials: request.credentials,
method: request.method,
mode: request.mode,
signal: request.signal,
headers: new Headers(request.headers),
} satisfies Partial<Request> as unknown as Request;
});
tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token");
tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false);
const nativeFetch = jest.fn<Promise<Response>, [request: Request]>();
nativeFetch.mockImplementation((request) => {
return Promise.resolve({
ok: false,
status: 401,
json: () => Promise.resolve({ message: "Unauthorized" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
});
sut.nativeFetch = nativeFetch;
// When hasResponse is false, the method should throw even though no retry happens
await expect(
async () => await sut.send("POST", "/something", null, true, false, null, null),
).rejects.toMatchObject({ message: "Unauthorized" });
// Should only be called once (no retry when hasResponse is false)
expect(nativeFetch).toHaveBeenCalledTimes(1);
});
it("throws error when retry also returns 401", async () => {
environmentService.getEnvironment$.calledWith(testActiveUser).mockReturnValue(
of({
getApiUrl: () => "https://example.com",
getIdentityUrl: () => "https://identity.example.com",
} satisfies Partial<Environment> as Environment),
);
httpOperations.createRequest.mockImplementation((url, request) => {
return {
url: url,
cache: request.cache,
credentials: request.credentials,
method: request.method,
mode: request.mode,
signal: request.signal,
headers: new Headers(request.headers),
} satisfies Partial<Request> as unknown as Request;
});
tokenService.getAccessToken.calledWith(testActiveUser).mockResolvedValue("expired_token");
tokenService.tokenNeedsRefresh.calledWith(testActiveUser).mockResolvedValue(false);
tokenService.getRefreshToken
.calledWith(testActiveUser)
.mockResolvedValue("valid_refresh_token");
tokenService.decodeAccessToken
.calledWith(testActiveUser)
.mockResolvedValue({ client_id: "web" });
tokenService.decodeAccessToken
.calledWith("new_access_token")
.mockResolvedValue({ sub: testActiveUser });
vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$
.calledWith(testActiveUser)
.mockReturnValue(of(VaultTimeoutAction.Lock));
vaultTimeoutSettingsService.getVaultTimeoutByUserId$
.calledWith(testActiveUser)
.mockReturnValue(of(VaultTimeoutStringType.Never));
tokenService.setTokens
.calledWith(
"new_access_token",
VaultTimeoutAction.Lock,
VaultTimeoutStringType.Never,
"new_refresh_token",
)
.mockResolvedValue({ accessToken: "refreshed_access_token" });
const nativeFetch = jest.fn<Promise<Response>, [request: Request]>();
let callCount = 0;
nativeFetch.mockImplementation((request) => {
callCount++;
// First call: initial request with expired token returns 401
if (callCount === 1) {
return Promise.resolve({
ok: false,
status: 401,
json: () => Promise.resolve({ message: "Unauthorized" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
}
// Second call: token refresh request
if (callCount === 2 && request.url.includes("identity")) {
return Promise.resolve({
ok: true,
status: 200,
json: () =>
Promise.resolve({
access_token: "new_access_token",
token_type: "Bearer",
refresh_token: "new_refresh_token",
}),
} satisfies Partial<Response> as Response);
}
// Third call: retry with refreshed token still returns 401 (user no longer has permission)
if (callCount === 3) {
return Promise.resolve({
ok: false,
status: 401,
json: () => Promise.resolve({ message: "Still Unauthorized" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
}
throw new Error("Unexpected call");
});
sut.nativeFetch = nativeFetch;
await expect(
async () => await sut.send("GET", "/something", null, true, true, null, null),
).rejects.toMatchObject({ message: "Still Unauthorized" });
expect(nativeFetch).toHaveBeenCalledTimes(3);
expect(logoutCallback).toHaveBeenCalledWith("sessionExpired");
});
it("retries with refreshed token for inactive user when 401 received", async () => {
tokenService.getAccessToken
.calledWith(testInactiveUser)
.mockResolvedValue("inactive_expired_token");
tokenService.tokenNeedsRefresh.calledWith(testInactiveUser).mockResolvedValue(false);
tokenService.getRefreshToken
.calledWith(testInactiveUser)
.mockResolvedValue("inactive_refresh_token");
tokenService.decodeAccessToken
.calledWith(testInactiveUser)
.mockResolvedValue({ client_id: "web" });
tokenService.decodeAccessToken
.calledWith("inactive_new_access_token")
.mockResolvedValue({ sub: testInactiveUser });
vaultTimeoutSettingsService.getVaultTimeoutActionByUserId$
.calledWith(testInactiveUser)
.mockReturnValue(of(VaultTimeoutAction.Lock));
vaultTimeoutSettingsService.getVaultTimeoutByUserId$
.calledWith(testInactiveUser)
.mockReturnValue(of(VaultTimeoutStringType.Never));
tokenService.setTokens
.calledWith(
"inactive_new_access_token",
VaultTimeoutAction.Lock,
VaultTimeoutStringType.Never,
"inactive_new_refresh_token",
)
.mockResolvedValue({ accessToken: "inactive_refreshed_access_token" });
environmentService.getEnvironment$.calledWith(testInactiveUser).mockReturnValue(
of({
getApiUrl: () => "https://inactive.example.com",
getIdentityUrl: () => "https://identity.inactive.example.com",
} satisfies Partial<Environment> as Environment),
);
httpOperations.createRequest.mockImplementation((url, request) => {
return {
url: url,
cache: request.cache,
credentials: request.credentials,
method: request.method,
mode: request.mode,
signal: request.signal,
headers: new Headers(request.headers),
} satisfies Partial<Request> as unknown as Request;
});
const nativeFetch = jest.fn<Promise<Response>, [request: Request]>();
let callCount = 0;
nativeFetch.mockImplementation((request) => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
ok: false,
status: 401,
json: () => Promise.resolve({ message: "Unauthorized" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
}
if (callCount === 2 && request.url.includes("identity")) {
return Promise.resolve({
ok: true,
status: 200,
json: () =>
Promise.resolve({
access_token: "inactive_new_access_token",
token_type: "Bearer",
refresh_token: "inactive_new_refresh_token",
}),
} satisfies Partial<Response> as Response);
}
if (callCount === 3) {
expect(request.headers.get("Authorization")).toBe("Bearer inactive_refreshed_access_token");
return Promise.resolve({
ok: true,
status: 200,
json: () => Promise.resolve({ data: "inactive user success" }),
headers: new Headers({
"content-type": "application/json",
}),
} satisfies Partial<Response> as Response);
}
throw new Error("Unexpected call");
});
sut.nativeFetch = nativeFetch;
const response = await sut.send("GET", "/something", null, testInactiveUser, true, null, null);
expect(nativeFetch).toHaveBeenCalledTimes(3);
expect(response).toEqual({ data: "inactive user success" });
});
});

View File

@@ -73,7 +73,7 @@ import { BillingHistoryResponse } from "../billing/models/response/billing-histo
import { PaymentResponse } from "../billing/models/response/payment.response";
import { PlanResponse } from "../billing/models/response/plan.response";
import { SubscriptionResponse } from "../billing/models/response/subscription.response";
import { ClientType, DeviceType } from "../enums";
import { ClientType, DeviceType, HttpStatusCode } from "../enums";
import { KeyConnectorUserKeyRequest } from "../key-management/key-connector/models/key-connector-user-key.request";
import { SetKeyConnectorKeyRequest } from "../key-management/key-connector/models/set-key-connector-key.request";
import { VaultTimeoutSettingsService } from "../key-management/vault-timeout";
@@ -1246,8 +1246,8 @@ export class ApiService implements ApiServiceAbstraction {
}),
);
if (response.status !== 200) {
const error = await this.handleError(response, false, true);
if (response.status !== HttpStatusCode.Ok) {
const error = await this.handleApiRequestError(response, true);
return Promise.reject(error);
}
@@ -1277,8 +1277,8 @@ export class ApiService implements ApiServiceAbstraction {
}),
);
if (response.status !== 200) {
const error = await this.handleError(response, false, true);
if (response.status !== HttpStatusCode.Ok) {
const error = await this.handleApiRequestError(response, true);
return Promise.reject(error);
}
}
@@ -1295,15 +1295,22 @@ export class ApiService implements ApiServiceAbstraction {
}),
);
if (response.status !== 200) {
const error = await this.handleError(response, false, true);
if (response.status !== HttpStatusCode.Ok) {
const error = await this.handleApiRequestError(response, true);
return Promise.reject(error);
}
}
// Helpers
async getActiveBearerToken(userId: UserId): Promise<string> {
/**
* Retrieves the bearer access token for the user, or `null` if no token exists.
* If the access token is expired or within 5 minutes of expiration, attemps to refresh the token
* and persists the refresh token to state before returning it.
* @param userId The user for whom we're retrieving the access token
* @returns The access token, or `null` if none exists for the `userId` provided.
*/
async getActiveBearerToken(userId: UserId): Promise<string | null> {
let accessToken = await this.tokenService.getAccessToken(userId);
if (await this.tokenService.tokenNeedsRefresh(userId)) {
accessToken = await this.refreshToken(userId);
@@ -1359,7 +1366,7 @@ export class ApiService implements ApiServiceAbstraction {
const body = await response.json();
return new SsoPreValidateResponse(body);
} else {
const error = await this.handleError(response, false, true);
const error = await this.handleApiRequestError(response, false);
return Promise.reject(error);
}
}
@@ -1514,7 +1521,7 @@ export class ApiService implements ApiServiceAbstraction {
);
return refreshedTokens.accessToken;
} else {
const error = await this.handleError(response, true, true);
const error = await this.handleTokenRefreshRequestError(response);
return Promise.reject(error);
}
}
@@ -1569,6 +1576,68 @@ export class ApiService implements ApiServiceAbstraction {
apiUrl?: string | null,
alterHeaders?: (headers: Headers) => void,
): Promise<any> {
const userId = await this.getUserIdForRequest(authedOrUserId);
const environment = await firstValueFrom(
userId == null
? this.environmentService.environment$
: this.environmentService.getEnvironment$(userId),
);
apiUrl = Utils.isNullOrWhitespace(apiUrl) ? environment.getApiUrl() : apiUrl;
// Prevent directory traversal from malicious paths
const pathParts = path.split("?");
const requestUrl =
apiUrl + Utils.normalizePath(pathParts[0]) + (pathParts.length > 1 ? `?${pathParts[1]}` : "");
const accessToken = await this.getActiveBearerToken(userId);
let request = await this.buildRequest(
method,
accessToken,
environment,
hasResponse,
body,
alterHeaders,
);
let response = await this.fetch(this.httpOperations.createRequest(requestUrl, request));
// First, check to see if we were making an authenticated request and received an Unauthorized (401)
// response. This could mean that we attempted to make a request with an expired access token.
// If so, attempt to refresh the token and try again.
if (hasResponse && accessToken != null && response.status === HttpStatusCode.Unauthorized) {
const refreshedToken = await this.refreshAccessToken(userId);
request = await this.buildRequest(
method,
refreshedToken,
environment,
hasResponse,
body,
alterHeaders,
);
response = await this.fetch(this.httpOperations.createRequest(requestUrl, request));
}
// At this point we are processing either the initial response or the response for the retry with the refreshed
// access token.
const responseType = response.headers.get("content-type");
const responseIsJson = responseType != null && responseType.indexOf("application/json") !== -1;
const responseIsCsv = responseType != null && responseType.indexOf("text/csv") !== -1;
if (hasResponse && response.status === HttpStatusCode.Ok && responseIsJson) {
const responseJson = await response.json();
return responseJson;
} else if (hasResponse && response.status === HttpStatusCode.Ok && responseIsCsv) {
return await response.text();
} else if (
response.status !== HttpStatusCode.Ok &&
response.status !== HttpStatusCode.NoContent
) {
const error = await this.handleApiRequestError(response, userId != null);
return Promise.reject(error);
}
}
private async getUserIdForRequest(authedOrUserId: UserId | boolean): Promise<UserId> {
if (authedOrUserId == null) {
throw new Error("A user id was given but it was null, cannot complete API request.");
}
@@ -1580,21 +1649,19 @@ export class ApiService implements ApiServiceAbstraction {
} else if (typeof authedOrUserId === "string") {
userId = authedOrUserId;
}
return userId;
}
const env = await firstValueFrom(
userId == null
? this.environmentService.environment$
: this.environmentService.getEnvironment$(userId),
);
apiUrl = Utils.isNullOrWhitespace(apiUrl) ? env.getApiUrl() : apiUrl;
// Prevent directory traversal from malicious paths
const pathParts = path.split("?");
const requestUrl =
apiUrl + Utils.normalizePath(pathParts[0]) + (pathParts.length > 1 ? `?${pathParts[1]}` : "");
private async buildRequest(
method: "GET" | "POST" | "PUT" | "DELETE" | "PATCH",
accessToken: string | null,
environment: Environment,
hasResponse: boolean,
body: string,
alterHeaders?: (headers: Headers) => void,
): Promise<RequestInit> {
const [requestHeaders, requestBody] = await this.buildHeadersAndBody(
userId,
accessToken,
hasResponse,
body,
alterHeaders,
@@ -1602,29 +1669,17 @@ export class ApiService implements ApiServiceAbstraction {
const requestInit: RequestInit = {
cache: "no-store",
credentials: await this.getCredentials(env),
credentials: await this.getCredentials(environment),
method: method,
};
requestInit.headers = requestHeaders;
requestInit.body = requestBody;
const response = await this.fetch(this.httpOperations.createRequest(requestUrl, requestInit));
const responseType = response.headers.get("content-type");
const responseIsJson = responseType != null && responseType.indexOf("application/json") !== -1;
const responseIsCsv = responseType != null && responseType.indexOf("text/csv") !== -1;
if (hasResponse && response.status === 200 && responseIsJson) {
const responseJson = await response.json();
return responseJson;
} else if (hasResponse && response.status === 200 && responseIsCsv) {
return await response.text();
} else if (response.status !== 200 && response.status !== 204) {
const error = await this.handleError(response, false, userId != null);
return Promise.reject(error);
}
return requestInit;
}
private async buildHeadersAndBody(
userToAuthenticate: UserId | null,
bearerAuthenticationToken: string | null,
hasResponse: boolean,
body: any,
alterHeaders: (headers: Headers) => void,
@@ -1646,9 +1701,8 @@ export class ApiService implements ApiServiceAbstraction {
if (alterHeaders != null) {
alterHeaders(headers);
}
if (userToAuthenticate != null) {
const authHeader = await this.getActiveBearerToken(userToAuthenticate);
headers.set("Authorization", "Bearer " + authHeader);
if (bearerAuthenticationToken != null) {
headers.set("Authorization", "Bearer " + bearerAuthenticationToken);
} else {
// For unauthenticated requests, we need to tell the server what the device is for flag targeting,
// since it won't be able to get it from the access token.
@@ -1673,32 +1727,66 @@ export class ApiService implements ApiServiceAbstraction {
return [headers, requestBody];
}
private async handleError(
/**
* Handle an error response from a request to the Bitwarden API.
* If the request is made with an access token (aka the user is authenticated),
* and we receive a 401 or 403 response, we will log the user out, as this indicates
* that the access token used on the request is either expired or does not have the appropriate permissions.
* It is unlikely that it is expired, as we attempt to refresh the token on initial failure.
* @param response The response from the API request
* @param userIsAuthenticated A boolean indicating whether this is an authenticated request.
* @returns An ErrorResponse with a message based on the response status.
*/
private async handleApiRequestError(
response: Response,
tokenError: boolean,
authed: boolean,
userIsAuthenticated: boolean,
): Promise<ErrorResponse> {
const responseJson = await this.getJsonResponse(response);
if (
userIsAuthenticated &&
(response.status === HttpStatusCode.Unauthorized ||
response.status === HttpStatusCode.Forbidden)
) {
await this.logoutCallback("sessionExpired");
}
return new ErrorResponse(responseJson, response.status);
}
/**
* Handle an error response when trying to refresh an access token.
* If the error indicates that the user's session has expired, it will log the user out.
* @param response The response from the token refresh request.
* @returns An ErrorResponse with a message based on the response status.
*/
private async handleTokenRefreshRequestError(response: Response): Promise<ErrorResponse> {
const responseJson = await this.getJsonResponse(response);
// IdentityServer will return an invalid_grant response if the refresh token has expired.
// This means that the user's session has expired, and they need to log out.
// We issue the logoutCallback() to log the user out through messaging.
if (
response.status === HttpStatusCode.Unauthorized ||
response.status === HttpStatusCode.Forbidden ||
(response.status === HttpStatusCode.BadRequest &&
responseJson != null &&
responseJson.error === "invalid_grant")
) {
await this.logoutCallback("sessionExpired");
}
return new ErrorResponse(responseJson, response.status, true);
}
private async getJsonResponse(response: Response): Promise<any> {
let responseJson: any = null;
if (this.isJsonResponse(response)) {
responseJson = await response.json();
} else if (this.isTextPlainResponse(response)) {
responseJson = { Message: await response.text() };
}
if (authed) {
if (
response.status === 401 ||
response.status === 403 ||
(tokenError &&
response.status === 400 &&
responseJson != null &&
responseJson.error === "invalid_grant")
) {
await this.logoutCallback("invalidGrantError");
}
}
return new ErrorResponse(responseJson, response.status, tokenError);
return responseJson;
}
private qsStringify(params: any): string {