From 6d42dd0f3e3572a41ecef52d2d97708fb8897b3c Mon Sep 17 00:00:00 2001 From: gbubemismith Date: Tue, 8 Apr 2025 14:45:49 -0400 Subject: [PATCH] added new function to be used for decrypting ciphers --- .../src/vault/services/cipher.service.spec.ts | 31 ++++ .../src/vault/services/cipher.service.ts | 164 ++++++++++-------- 2 files changed, 126 insertions(+), 69 deletions(-) diff --git a/libs/common/src/vault/services/cipher.service.spec.ts b/libs/common/src/vault/services/cipher.service.spec.ts index dd7faea8e8a..2781bb389c8 100644 --- a/libs/common/src/vault/services/cipher.service.spec.ts +++ b/libs/common/src/vault/services/cipher.service.spec.ts @@ -15,6 +15,7 @@ import { EncryptService } from "../../key-management/crypto/abstractions/encrypt import { UriMatchStrategy } from "../../models/domain/domain-service"; import { ConfigService } from "../../platform/abstractions/config/config.service"; import { I18nService } from "../../platform/abstractions/i18n.service"; +import { SdkService } from "../../platform/abstractions/sdk/sdk.service"; import { StateService } from "../../platform/abstractions/state.service"; import { Utils } from "../../platform/misc/utils"; import { EncArrayBuffer } from "../../platform/models/domain/enc-array-buffer"; @@ -23,6 +24,7 @@ import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypt import { ContainerService } from "../../platform/services/container.service"; import { CipherId, UserId } from "../../types/guid"; import { CipherKey, OrgKey, UserKey } from "../../types/key"; +import { CipherEncryptionService } from "../abstractions/cipher-encryption.service"; import { CipherFileUploadService } from "../abstractions/file-upload/cipher-file-upload.service"; import { FieldType } from "../enums"; import { CipherRepromptType } from "../enums/cipher-reprompt-type"; @@ -122,6 +124,8 @@ describe("Cipher Service", () => { const configService = mock(); accountService = mockAccountServiceWith(mockUserId); const stateProvider = new FakeStateProvider(accountService); + const sdkService = mock(); + const cipherEncryptionService = mock(); const userId = "TestUserId" as UserId; @@ -148,6 +152,8 @@ describe("Cipher Service", () => { configService, stateProvider, accountService, + sdkService, + cipherEncryptionService, ); cipherObj = new Cipher(cipherData); @@ -470,4 +476,29 @@ describe("Cipher Service", () => { ).rejects.toThrow("Cannot rotate ciphers when decryption failures are present"); }); }); + + describe("decryptCipherWithSdkOrLegacy", () => { + it("should call decrypt method of CipherEncryptionService when feature flag is true", async () => { + configService.getFeatureFlag.mockResolvedValue(true); + cipherEncryptionService.decrypt.mockResolvedValue(new CipherView(cipherObj)); + + const result = await cipherService.decryptCipherWithSdkOrLegacy(cipherObj, userId); + + expect(result).toEqual(new CipherView(cipherObj)); + expect(cipherEncryptionService.decrypt).toHaveBeenCalledWith(cipherObj, userId); + }); + + it("should call legacy decrypt when feature flag is false", async () => { + const mockUserKey = new SymmetricCryptoKey(new Uint8Array(32)) as UserKey; + configService.getFeatureFlag.mockResolvedValue(false); + cipherService.getKeyForCipherKeyDecryption = jest.fn().mockResolvedValue(mockUserKey); + encryptService.decryptToBytes.mockResolvedValue(new Uint8Array(32)); + jest.spyOn(cipherObj, "decrypt").mockResolvedValue(new CipherView(cipherObj)); + + const result = await cipherService.decryptCipherWithSdkOrLegacy(cipherObj, userId); + + expect(result).toEqual(new CipherView(cipherObj)); + expect(cipherObj.decrypt).toHaveBeenCalledWith(mockUserKey); + }); + }); }); diff --git a/libs/common/src/vault/services/cipher.service.ts b/libs/common/src/vault/services/cipher.service.ts index 655ffa48f52..7e6fc19dd1b 100644 --- a/libs/common/src/vault/services/cipher.service.ts +++ b/libs/common/src/vault/services/cipher.service.ts @@ -15,7 +15,6 @@ import { import { SemVer } from "semver"; import { KeyService } from "@bitwarden/key-management"; -import { CipherView as SdkCipherView } from "@bitwarden/sdk-internal"; import { ApiService } from "../../abstractions/api.service"; import { SearchService } from "../../abstractions/search.service"; @@ -41,6 +40,7 @@ import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypt import { StateProvider } from "../../platform/state"; import { CipherId, CollectionId, OrganizationId, UserId } from "../../types/guid"; import { OrgKey, UserKey } from "../../types/key"; +import { CipherEncryptionService } from "../abstractions/cipher-encryption.service"; import { CipherService as CipherServiceAbstraction } from "../abstractions/cipher.service"; import { CipherFileUploadService } from "../abstractions/file-upload/cipher-file-upload.service"; import { FieldType } from "../enums"; @@ -113,6 +113,7 @@ export class CipherService implements CipherServiceAbstraction { private stateProvider: StateProvider, private accountService: AccountService, private sdkService: SdkService, + private cipherEncryptionService: CipherEncryptionService, ) {} localData$(userId: UserId): Observable> { @@ -158,23 +159,6 @@ export class CipherService implements CipherServiceAbstraction { ); } - /** - * {@link CipherServiceAbstraction.decrypt$} - */ - decrypt$(userId: UserId, cipher: Cipher): Observable { - return this.sdkService.userClient$(userId).pipe( - map((sdk) => { - if (!sdk) { - throw new Error("SDK is undefined"); - } - - using ref = sdk.take(); - - return ref.value.vault().ciphers().decrypt(cipher.toSdkCipher()); - }), - ); - } - async setDecryptedCipherCache(value: CipherView[], userId: UserId) { // Sometimes we might prematurely decrypt the vault and that will result in no ciphers // if we cache it then we may accidentally return it when it's not right, we'd rather try decryption again. @@ -448,55 +432,70 @@ export class CipherService implements CipherServiceAbstraction { ciphers: Cipher[], userId: UserId, ): Promise<[CipherView[], CipherView[]]> { - const keys = await firstValueFrom(this.keyService.cipherDecryptionKeys$(userId, true)); - - if (keys == null || (keys.userKey == null && Object.keys(keys.orgKeys).length === 0)) { - // return early if there are no keys to decrypt with - return [[], []]; - } - - // Group ciphers by orgId or under 'null' for the user's ciphers - const grouped = ciphers.reduce( - (agg, c) => { - agg[c.organizationId] ??= []; - agg[c.organizationId].push(c); - return agg; - }, - {} as Record, - ); - - const allCipherViews = ( - await Promise.all( - Object.entries(grouped).map(async ([orgId, groupedCiphers]) => { - if (await this.configService.getFeatureFlag(FeatureFlag.PM4154_BulkEncryptionService)) { - return await this.bulkEncryptService.decryptItems( - groupedCiphers, - keys.orgKeys[orgId as OrganizationId] ?? keys.userKey, - ); - } else { - return await this.encryptService.decryptItems( - groupedCiphers, - keys.orgKeys[orgId as OrganizationId] ?? keys.userKey, - ); - } - }), + if (await this.configService.getFeatureFlag(FeatureFlag.PM19941MigrateCipherDomainToSdk)) { + return this.decryptCiphersWithSdk(ciphers, userId); + } else { + const keys = await firstValueFrom(this.keyService.cipherDecryptionKeys$(userId, true)); + if (keys == null || (keys.userKey == null && Object.keys(keys.orgKeys).length === 0)) { + // return early if there are no keys to decrypt with + return [[], []]; + } + // Group ciphers by orgId or under 'null' for the user's ciphers + const grouped = ciphers.reduce( + (agg, c) => { + agg[c.organizationId] ??= []; + agg[c.organizationId].push(c); + return agg; + }, + {} as Record, + ); + const allCipherViews = ( + await Promise.all( + Object.entries(grouped).map(async ([orgId, groupedCiphers]) => { + if (await this.configService.getFeatureFlag(FeatureFlag.PM4154_BulkEncryptionService)) { + return await this.bulkEncryptService.decryptItems( + groupedCiphers, + keys.orgKeys[orgId as OrganizationId] ?? keys.userKey, + ); + } else { + return await this.encryptService.decryptItems( + groupedCiphers, + keys.orgKeys[orgId as OrganizationId] ?? keys.userKey, + ); + } + }), + ) ) - ) - .flat() - .sort(this.getLocaleSortingFunction()); + .flat() + .sort(this.getLocaleSortingFunction()); + // Split ciphers into two arrays, one for successfully decrypted ciphers and one for ciphers that failed to decrypt + return allCipherViews.reduce( + (acc, c) => { + if (c.decryptionFailure) { + acc[1].push(c); + } else { + acc[0].push(c); + } + return acc; + }, + [[], []] as [CipherView[], CipherView[]], + ); + } + } - // Split ciphers into two arrays, one for successfully decrypted ciphers and one for ciphers that failed to decrypt - return allCipherViews.reduce( - (acc, c) => { - if (c.decryptionFailure) { - acc[1].push(c); - } else { - acc[0].push(c); - } - return acc; - }, - [[], []] as [CipherView[], CipherView[]], - ); + /** + * Decrypts a cipher using either the SDK or the legacy method based on the feature flag. + * @param cipher The cipher to decrypt. + * @param userId The user ID to use for decryption. + * @returns A promise that resolves to the decrypted cipher view. + */ + async decryptCipherWithSdkOrLegacy(cipher: Cipher, userId: UserId): Promise { + if (await this.configService.getFeatureFlag(FeatureFlag.PM19941MigrateCipherDomainToSdk)) { + return await this.cipherEncryptionService.decrypt(cipher, userId); + } else { + const encKey = await this.getKeyForCipherKeyDecryption(cipher, userId); + return await cipher.decrypt(encKey); + } } private async reindexCiphers(userId: UserId) { @@ -910,7 +909,7 @@ export class CipherService implements CipherServiceAbstraction { //then we rollback to using the user key as the main key of encryption of the item //in order to keep item and it's attachments with the same encryption level if (cipher.key != null && !cipherKeyEncryptionEnabled) { - const model = await cipher.decrypt(await this.getKeyForCipherKeyDecryption(cipher, userId)); + const model = await this.decryptCipherWithSdkOrLegacy(cipher, userId); cipher = await this.encrypt(model, userId); await this.updateWithServer(cipher); } @@ -1441,9 +1440,7 @@ export class CipherService implements CipherServiceAbstraction { originalCipher: Cipher, userId: UserId, ): Promise { - const existingCipher = await originalCipher.decrypt( - await this.getKeyForCipherKeyDecryption(originalCipher, userId), - ); + const existingCipher = await this.decryptCipherWithSdkOrLegacy(originalCipher, userId); model.passwordHistory = existingCipher.passwordHistory || []; if (model.type === CipherType.Login && existingCipher.type === CipherType.Login) { if ( @@ -1856,4 +1853,33 @@ export class CipherService implements CipherServiceAbstraction { ); return featureEnabled && meetsServerVersion; } + + /** + * Decrypts the provided ciphers using the SDK. + * @param ciphers The ciphers to decrypt. + * @param userId The user ID to use for decryption. + * @returns A tuple containing the successful and failed decrypted ciphers. + * @private + */ + private async decryptCiphersWithSdk( + ciphers: Cipher[], + userId: UserId, + ): Promise<[CipherView[], CipherView[]]> { + const decryptedViews = await Promise.all( + ciphers.map((cipher) => this.cipherEncryptionService.decrypt(cipher, userId)), + ); + + const successful: CipherView[] = []; + const failed: CipherView[] = []; + + decryptedViews.forEach((view) => { + if (view.decryptionFailure) { + failed.push(view); + } else { + successful.push(view); + } + }); + + return [successful.sort(this.getLocaleSortingFunction()), failed]; + } }