diff --git a/libs/common/src/vault/abstractions/cipher-encryption.service.ts b/libs/common/src/vault/abstractions/cipher-encryption.service.ts index 35becd4b0e7..6057a91bae5 100644 --- a/libs/common/src/vault/abstractions/cipher-encryption.service.ts +++ b/libs/common/src/vault/abstractions/cipher-encryption.service.ts @@ -69,14 +69,19 @@ export abstract class CipherEncryptionService { */ abstract decryptManyLegacy(ciphers: Cipher[], userId: UserId): Promise; /** - * Decrypts many ciphers using the SDK for the given userId. + * Decrypts many ciphers using the SDK for the given userId, and returns a list of + * failures. * * @param ciphers The encrypted cipher objects * @param userId The user ID whose key will be used for decryption * - * @returns A promise that resolves to an array of decrypted cipher list views + * @returns A promise that resolves to a tuple containing an array of decrypted + * cipher list views, and an array of ciphers that failed to decrypt. */ - abstract decryptMany(ciphers: Cipher[], userId: UserId): Promise; + abstract decryptManyWithFailures( + ciphers: Cipher[], + userId: UserId, + ): Promise<[CipherListView[], Cipher[]]>; /** * Decrypts an attachment's content from a response object. * diff --git a/libs/common/src/vault/services/cipher.service.spec.ts b/libs/common/src/vault/services/cipher.service.spec.ts index 2088f50d1cc..e7437b24656 100644 --- a/libs/common/src/vault/services/cipher.service.spec.ts +++ b/libs/common/src/vault/services/cipher.service.spec.ts @@ -8,6 +8,7 @@ import { CipherResponse } from "@bitwarden/common/vault/models/response/cipher.r // eslint-disable-next-line no-restricted-imports import { CipherDecryptionKeys, KeyService } from "@bitwarden/key-management"; import { MessageSender } from "@bitwarden/messaging"; +import { CipherListView } from "@bitwarden/sdk-internal"; import { FakeAccountService, mockAccountServiceWith } from "../../../spec/fake-account-service"; import { FakeStateProvider } from "../../../spec/fake-state-provider"; @@ -117,6 +118,12 @@ describe("Cipher Service", () => { encryptService.encryptFileData.mockReturnValue(Promise.resolve(ENCRYPTED_BYTES)); encryptService.encryptString.mockReturnValue(Promise.resolve(new EncString(ENCRYPTED_TEXT))); + // Mock i18nService collator + i18nService.collator = { + compare: jest.fn().mockImplementation((a: string, b: string) => a.localeCompare(b)), + resolvedOptions: jest.fn().mockReturnValue({}), + } as any; + (window as any).bitwardenContainerService = new ContainerService(keyService, encryptService); cipherService = new CipherService( @@ -733,4 +740,80 @@ describe("Cipher Service", () => { ); }); }); + + describe("decryptCiphers", () => { + let mockCiphers: Cipher[]; + const cipher1_id = "11111111-1111-1111-1111-111111111111"; + const cipher2_id = "22222222-2222-2222-2222-222222222222"; + + beforeEach(() => { + const originalUserKey = new SymmetricCryptoKey(new Uint8Array(32)) as UserKey; + const orgKey = new SymmetricCryptoKey(new Uint8Array(32)) as OrgKey; + const keys = { + userKey: originalUserKey, + orgKeys: { [orgId]: orgKey }, + } as CipherDecryptionKeys; + keyService.cipherDecryptionKeys$.mockReturnValue(of(keys)); + + mockCiphers = [ + new Cipher({ ...cipherData, id: cipher1_id }), + new Cipher({ ...cipherData, id: cipher2_id }), + ]; + + //// Mock the SDK response + cipherEncryptionService.decryptManyWithFailures.mockResolvedValue([ + [{ id: mockCiphers[0].id, name: "Success 1" } as unknown as CipherListView], + [mockCiphers[1]], // Mock failed cipher + ]); + }); + + it("should use the SDK for decryption when SDK feature flag is enabled", async () => { + configService.getFeatureFlag + .calledWith(FeatureFlag.PM19941MigrateCipherDomainToSdk) + .mockResolvedValue(true); + + // Set up expected results + const expectedSuccessCipherViews = [ + { id: mockCiphers[0].id, name: "Success 1" } as unknown as CipherListView, + ]; + + const expectedFailedCipher = new CipherView(mockCiphers[1]); + expectedFailedCipher.name = "[error: cannot decrypt]"; + expectedFailedCipher.decryptionFailure = true; + const expectedFailedCipherViews = [expectedFailedCipher]; + + // Execute + const [successes, failures] = await (cipherService as any).decryptCiphers( + mockCiphers, + userId, + ); + + // Verify the SDK was used for decryption + expect(cipherEncryptionService.decryptManyWithFailures).toHaveBeenCalledWith( + mockCiphers, + userId, + ); + + expect(successes).toEqual(expectedSuccessCipherViews); + expect(failures).toEqual(expectedFailedCipherViews); + }); + + it("should use legacy decryption when SDK feature flag is disabled", async () => { + configService.getFeatureFlag + .calledWith(FeatureFlag.PM19941MigrateCipherDomainToSdk) + .mockResolvedValue(false); + + // Execute + const [successes, failures] = await (cipherService as any).decryptCiphers( + mockCiphers, + userId, + ); + + // Verify the SDK was not used for decryption + expect(cipherEncryptionService.decryptManyWithFailures).toHaveBeenCalledTimes(0); + + expect(successes).toHaveLength(2); + expect(failures).toHaveLength(0); + }); + }); }); diff --git a/libs/common/src/vault/services/cipher.service.ts b/libs/common/src/vault/services/cipher.service.ts index 2ad4274c235..7373bca2831 100644 --- a/libs/common/src/vault/services/cipher.service.ts +++ b/libs/common/src/vault/services/cipher.service.ts @@ -158,11 +158,9 @@ export class CipherService implements CipherServiceAbstraction { ), ), switchMap(async (ciphers) => { - // TODO: remove this once failed decrypted ciphers are handled in the SDK - await this.setFailedDecryptedCiphers([], userId); - return this.cipherEncryptionService - .decryptMany(ciphers, userId) - .then((ciphers) => ciphers.sort(this.getLocaleSortingFunction())); + const [decrypted, failures] = await this.decryptCiphersWithSdk(ciphers, userId); + await this.setFailedDecryptedCiphers(failures, userId); + return decrypted.sort(this.getLocaleSortingFunction()); }), ); }), @@ -489,14 +487,14 @@ export class CipherService implements CipherServiceAbstraction { ): Promise<[CipherView[], CipherView[]] | null> { if (await this.configService.getFeatureFlag(FeatureFlag.PM19941MigrateCipherDomainToSdk)) { const decryptStartTime = performance.now(); - const decrypted = await this.decryptCiphersWithSdk(ciphers, userId); + + const result = await this.decryptCiphersWithSdk(ciphers, userId); this.logService.measure(decryptStartTime, "Vault", "CipherService", "decrypt complete", [ ["Items", ciphers.length], ]); - // With SDK, failed ciphers are not returned - return [decrypted, []]; + return result; } const keys = await firstValueFrom(this.keyService.cipherDecryptionKeys$(userId)); @@ -2034,10 +2032,23 @@ export class CipherService implements CipherServiceAbstraction { * @returns The decrypted ciphers. * @private */ - private async decryptCiphersWithSdk(ciphers: Cipher[], userId: UserId): Promise { - const decryptedViews = await this.cipherEncryptionService.decryptManyLegacy(ciphers, userId); + private async decryptCiphersWithSdk( + ciphers: Cipher[], + userId: UserId, + ): Promise<[CipherView[], CipherView[]]> { + const [decrypted, failures] = await this.cipherEncryptionService.decryptManyWithFailures( + ciphers, + userId, + ); + const decryptedViews = await Promise.all(decrypted.map((c) => this.getFullCipherView(c))); + const failedViews = failures.map((c) => { + const cipher_view = new CipherView(c); + cipher_view.name = "[error: cannot decrypt]"; + cipher_view.decryptionFailure = true; + return cipher_view; + }); - return decryptedViews.sort(this.getLocaleSortingFunction()); + return [decryptedViews.sort(this.getLocaleSortingFunction()), failedViews]; } /** Fetches the full `CipherView` when a `CipherListView` is passed. */ diff --git a/libs/common/src/vault/services/default-cipher-encryption.service.spec.ts b/libs/common/src/vault/services/default-cipher-encryption.service.spec.ts index 12e5b0b4626..dee45d46a57 100644 --- a/libs/common/src/vault/services/default-cipher-encryption.service.spec.ts +++ b/libs/common/src/vault/services/default-cipher-encryption.service.spec.ts @@ -98,6 +98,7 @@ describe("DefaultCipherEncryptionService", () => { set_fido2_credentials: jest.fn(), decrypt: jest.fn(), decrypt_list: jest.fn(), + decrypt_list_with_failures: jest.fn(), decrypt_fido2_credentials: jest.fn(), move_to_organization: jest.fn(), }), @@ -514,36 +515,40 @@ describe("DefaultCipherEncryptionService", () => { }); }); - describe("decryptMany", () => { - it("should decrypt multiple ciphers to list views", async () => { - const ciphers = [new Cipher(cipherData), new Cipher(cipherData)]; - - const expectedListViews = [ - { id: "list1" as any, name: "List 1" } as CipherListView, - { id: "list2" as any, name: "List 2" } as CipherListView, + describe("decryptManyWithFailures", () => { + const cipher1_id = "11111111-1111-1111-1111-111111111111"; + const cipher2_id = "22222222-2222-2222-2222-222222222222"; + it("should decrypt multiple ciphers and return successes and failures", async () => { + const ciphers = [ + new Cipher({ ...cipherData, id: cipher1_id as CipherId }), + new Cipher({ ...cipherData, id: cipher2_id as CipherId }), ]; - mockSdkClient.vault().ciphers().decrypt_list.mockReturnValue(expectedListViews); + const successCipherList = { + id: cipher1_id, + name: "Decrypted Cipher 1", + } as unknown as CipherListView; + const failedCipher = { id: cipher2_id, name: "Failed Cipher" } as unknown as SdkCipher; - const result = await cipherEncryptionService.decryptMany(ciphers, userId); + const expectedFailedCiphers = [Cipher.fromSdkCipher(failedCipher)]; - expect(result).toEqual(expectedListViews); - expect(mockSdkClient.vault().ciphers().decrypt_list).toHaveBeenCalledWith( + const mockResult = { + successes: [successCipherList], + failures: [failedCipher], + }; + + mockSdkClient.vault().ciphers().decrypt_list_with_failures.mockReturnValue(mockResult); + + const result = await cipherEncryptionService.decryptManyWithFailures(ciphers, userId); + + expect(result).toEqual([[successCipherList], expectedFailedCiphers]); + expect(mockSdkClient.vault().ciphers().decrypt_list_with_failures).toHaveBeenCalledWith( expect.arrayContaining([ - expect.objectContaining({ id: cipherData.id }), - expect.objectContaining({ id: cipherData.id }), + expect.objectContaining({ id: cipher1_id }), + expect.objectContaining({ id: cipher2_id }), ]), ); - }); - - it("should throw EmptyError when SDK is not available", async () => { - sdkService.userClient$ = jest.fn().mockReturnValue(of(null)) as any; - - await expect(cipherEncryptionService.decryptMany([cipherObj], userId)).rejects.toThrow(); - - expect(logService.error).toHaveBeenCalledWith( - expect.stringContaining("Failed to decrypt cipher list"), - ); + expect(Cipher.fromSdkCipher).toHaveBeenCalledWith(failedCipher); }); }); diff --git a/libs/common/src/vault/services/default-cipher-encryption.service.ts b/libs/common/src/vault/services/default-cipher-encryption.service.ts index b7026fd4cfc..3f03e0f5e9e 100644 --- a/libs/common/src/vault/services/default-cipher-encryption.service.ts +++ b/libs/common/src/vault/services/default-cipher-encryption.service.ts @@ -6,6 +6,7 @@ import { CipherListView, BitwardenClient, CipherView as SdkCipherView, + DecryptCipherListResult, } from "@bitwarden/sdk-internal"; import { LogService } from "../../platform/abstractions/log.service"; @@ -218,7 +219,10 @@ export class DefaultCipherEncryptionService implements CipherEncryptionService { ); } - async decryptMany(ciphers: Cipher[], userId: UserId): Promise { + async decryptManyWithFailures( + ciphers: Cipher[], + userId: UserId, + ): Promise<[CipherListView[], Cipher[]]> { return firstValueFrom( this.sdkService.userClient$(userId).pipe( map((sdk) => { @@ -228,14 +232,17 @@ export class DefaultCipherEncryptionService implements CipherEncryptionService { using ref = sdk.take(); - return ref.value + const result: DecryptCipherListResult = ref.value .vault() .ciphers() - .decrypt_list(ciphers.map((cipher) => cipher.toSdkCipher())); - }), - catchError((error: unknown) => { - this.logService.error(`Failed to decrypt cipher list: ${error}`); - return EMPTY; + .decrypt_list_with_failures(ciphers.map((cipher) => cipher.toSdkCipher())); + + const decryptedCiphers = result.successes; + const failedCiphers: Cipher[] = result.failures + .map((cipher) => Cipher.fromSdkCipher(cipher)) + .filter((cipher): cipher is Cipher => cipher !== undefined); + + return [decryptedCiphers, failedCiphers]; }), ), );