diff --git a/packages/firestore/src/platform/browser/webchannel_connection.ts b/packages/firestore/src/platform/browser/webchannel_connection.ts index 56f57aa9595..32673669b4d 100644 --- a/packages/firestore/src/platform/browser/webchannel_connection.ts +++ b/packages/firestore/src/platform/browser/webchannel_connection.ts @@ -16,7 +16,7 @@ */ import { - createWebChannelTransport, + createWebChannelTransport as internalCreateWebChannelTransport, ErrorCode, EventType, WebChannel, @@ -27,7 +27,8 @@ import { EventTarget, StatEvent, Event, - Stat + Stat, + WebChannelTransport } from '@firebase/webchannel-wrapper/webchannel-blob'; import { Token } from '../../api/credentials'; @@ -181,7 +182,7 @@ export class WebChannelConnection extends RestConnection { rpcName, '/channel' ]; - const webchannelTransport = createWebChannelTransport(); + const webchannelTransport = this.createWebChannelTransport(); const requestStats = getStatEventTarget(); const request: WebChannelOptions = { // Required for backend stickiness, routing behavior is based on this @@ -460,4 +461,29 @@ export class WebChannelConnection extends RestConnection { instance => instance === webChannel ); } + + /** + * Modifies the headers for a request, adding the api key if present, + * and then calling super.modifyHeadersForRequest + */ + protected modifyHeadersForRequest( + headers: StringMap, + authToken: Token | null, + appCheckToken: Token | null + ): void { + super.modifyHeadersForRequest(headers, authToken, appCheckToken); + + // For web channel streams, we want to send the api key in the headers. + if (this.databaseInfo.apiKey) { + headers['x-goog-api-key'] = this.databaseInfo.apiKey; + } + } + + /** + * Wrapped for mocking. + * @protected + */ + protected createWebChannelTransport(): WebChannelTransport { + return internalCreateWebChannelTransport(); + } } diff --git a/packages/firestore/src/remote/rest_connection.ts b/packages/firestore/src/remote/rest_connection.ts index d9446a733e5..83b0c572f9a 100644 --- a/packages/firestore/src/remote/rest_connection.ts +++ b/packages/firestore/src/remote/rest_connection.ts @@ -64,7 +64,6 @@ export abstract class RestConnection implements Connection { protected readonly baseUrl: string; private readonly databasePath: string; private readonly requestParams: string; - private readonly apiKey: string | undefined; get shouldResourcePathBeIncludedInRequest(): boolean { // Both `invokeRPC()` and `invokeStreamingRPC()` use their `path` arguments to determine @@ -72,7 +71,7 @@ export abstract class RestConnection implements Connection { return false; } - constructor(private readonly databaseInfo: DatabaseInfo) { + constructor(protected readonly databaseInfo: DatabaseInfo) { this.databaseId = databaseInfo.databaseId; const proto = databaseInfo.ssl ? 'https' : 'http'; const projectId = encodeURIComponent(this.databaseId.projectId); @@ -83,7 +82,6 @@ export abstract class RestConnection implements Connection { this.databaseId.database === DEFAULT_DATABASE_NAME ? `project_id=${projectId}` : `project_id=${projectId}&database_id=${databaseId}`; - this.apiKey = databaseInfo.apiKey; } invokeRPC( @@ -203,8 +201,8 @@ export abstract class RestConnection implements Connection { 'Unknown REST mapping for: ' + rpcName ); let url = `${this.baseUrl}/${RPC_URL_VERSION}/${path}:${urlRpcName}`; - if (this.apiKey) { - url = `${url}?key=${encodeURIComponent(this.apiKey)}`; + if (this.databaseInfo.apiKey) { + url = `${url}?key=${encodeURIComponent(this.databaseInfo.apiKey)}`; } return url; } diff --git a/packages/firestore/test/unit/remote/web_channel_connection.browser.test.ts b/packages/firestore/test/unit/remote/web_channel_connection.browser.test.ts new file mode 100644 index 00000000000..73f6ddf03d8 --- /dev/null +++ b/packages/firestore/test/unit/remote/web_channel_connection.browser.test.ts @@ -0,0 +1,71 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect } from 'chai'; + +import { DatabaseId, DatabaseInfo } from '../../../src/core/database_info'; +import { WebChannelConnection } from '../../../src/platform/browser/webchannel_connection'; +import { + WebChannelOptions, + WebChannelTransport +} from '@firebase/webchannel-wrapper'; + +export class TestWebChannelConnection extends WebChannelConnection { + public transport: { lastOptions?: WebChannelOptions } & WebChannelTransport = + { + lastOptions: undefined, + createWebChannel(url: string, options: WebChannelOptions): never { + this.lastOptions = options; + + // Throw here so we don't have to mock out any more of Web Channel + throw new Error('Not implemented for test'); + } + }; + protected createWebChannelTransport(): WebChannelTransport { + return this.transport; + } +} + +describe('WebChannelConnection', () => { + const testDatabaseInfo = new DatabaseInfo( + new DatabaseId('testproject'), + 'test-app-id', + 'persistenceKey', + 'example.com', + /*ssl=*/ false, + /*forceLongPolling=*/ false, + /*autoDetectLongPolling=*/ false, + /*longPollingOptions=*/ {}, + /*useFetchStreams=*/ false, + /*isUsingEmulator=*/ false, + 'wc-connection-test-api-key' + ); + + it('Passes the API Key from DatabaseInfo to makeHeaders for openStream', async () => { + const connection = new TestWebChannelConnection(testDatabaseInfo); + + expect(() => connection.openStream('mockRpc', null, null)).to.throw( + 'Not implemented for test' + ); + + const headers = connection.transport.lastOptions + ?.initMessageHeaders as unknown as { [key: string]: string }; + expect(headers['x-goog-api-key']).to.deep.equal( + 'wc-connection-test-api-key' + ); + }); +});