diff --git a/.changeset/late-plums-start.md b/.changeset/late-plums-start.md
new file mode 100644
index 000000000..7d075e481
--- /dev/null
+++ b/.changeset/late-plums-start.md
@@ -0,0 +1,5 @@
+---
+"@browserbasehq/stagehand": patch
+---
+
+Add DeepseekAIClient implementation and integrate with LLMProvider
diff --git a/packages/core/lib/v3/llm/DeepseekAIClient.ts b/packages/core/lib/v3/llm/DeepseekAIClient.ts
new file mode 100644
index 000000000..426d100e4
--- /dev/null
+++ b/packages/core/lib/v3/llm/DeepseekAIClient.ts
@@ -0,0 +1,307 @@
+import OpenAI, { ClientOptions } from "openai";
+import {
+  ChatCompletionAssistantMessageParam,
+  ChatCompletionContentPartImage,
+  ChatCompletionContentPartText,
+  ChatCompletionCreateParamsNonStreaming,
+  ChatCompletionMessageParam,
+  ChatCompletionSystemMessageParam,
+  ChatCompletionUserMessageParam,
+} from "openai/resources/chat";
+import zodToJsonSchema from "zod-to-json-schema";
+import { LogLine } from "../types/public/logs";
+import { AvailableModel } from "../types/public/model";
+import { validateZodSchema } from "../../utils";
+import {
+  ChatCompletionOptions,
+  ChatMessage,
+  CreateChatCompletionOptions,
+  LLMClient,
+  LLMResponse,
+} from "./LLMClient";
+import {
+  CreateChatCompletionResponseError,
+  ZodSchemaValidationError,
+} from "../types/public/sdkErrors";
+
+export class DeepseekAIClient extends LLMClient {
+  public type = "deepseek" as const;
+  private client: OpenAI;
+  public clientOptions: ClientOptions;
+
+  constructor({
+    modelName,
+    clientOptions,
+  }: {
+    logger: (message: LogLine) => void;
+    modelName: AvailableModel;
+    clientOptions?: ClientOptions;
+  }) {
+    super(modelName);
+    this.clientOptions = clientOptions;
+    this.client = new OpenAI({
+      ...clientOptions,
+      baseURL: "https://api.deepseek.com/v1",
+    });
+    this.modelName = modelName;
+  }
+
+  async createChatCompletion<T = LLMResponse>({
+    options,
+    logger,
+    retries = 3,
+  }: CreateChatCompletionOptions): Promise<T> {
+    const { requestId, ...optionsWithoutImageAndRequestId } = options;
+
+    logger({
+      category: "deepseek",
+      message: "creating chat completion",
+      level: 2,
+      auxiliary: {
+        options: {
+          value: JSON.stringify({
+            ...optionsWithoutImageAndRequestId,
+            requestId,
+          }),
+          type: "object",
+        },
+        modelName: {
+          value: this.modelName,
+          type: "string",
+        },
+      },
+    });
+
+    if (options.image) {
+      const screenshotMessage: ChatMessage = {
+        role: "user",
+        content: [
+          {
+            type: "image_url",
+            image_url: {
+              url: `data:image/jpeg;base64,${options.image.buffer.toString(
+                "base64",
+              )}`,
+            },
+          },
+          ...(options.image.description
+            ? [{ type: "text", text: options.image.description }]
+            : []),
+        ],
+      };
+
+      options.messages.push(screenshotMessage);
+    }
+
+    let responseFormat: { type: "json_object" } | undefined = undefined;
+    if (options.response_model) {
+      try {
+        const parsedSchema = JSON.stringify(
+          zodToJsonSchema(options.response_model.schema),
+        );
+        options.messages.push({
+          role: "user",
+          content: `Respond in this zod schema format:\n${parsedSchema}\n
+          You must respond in JSON format. respond WITH JSON. Do not include any other text, formatting or markdown in your output. Do not include \`\`\` or \`\`\`json in your response. Only the JSON object itself.`,
+        });
+        responseFormat = { type: "json_object" };
+      } catch (error) {
+        logger({
+          category: "deepseek",
+          message: "Failed to parse response model schema",
+          level: 0,
+        });
+
+        if (retries > 0) {
+          return this.createChatCompletion({
+            options: options as ChatCompletionOptions,
+            logger,
+            retries: retries - 1,
+          });
+        }
+
+        throw error;
+      }
+    }
+
+    /* eslint-disable */
+    const { response_model, ...deepseekOptions } = {
+      ...optionsWithoutImageAndRequestId,
+      model: this.modelName,
+    };
+    /* eslint-enable */
+
+    logger({
+      category: "deepseek",
+      message: "creating chat completion",
+      level: 2,
+      auxiliary: {
+        deepseekOptions: {
+          value: JSON.stringify(deepseekOptions),
+          type: "object",
+        },
+      },
+    });
+
+    const formattedMessages: ChatCompletionMessageParam[] =
+      options.messages.map((message) => {
+        if (Array.isArray(message.content)) {
+          const contentParts = message.content.map((content) => {
+            if ("image_url" in content) {
+              const imageContent: ChatCompletionContentPartImage = {
+                image_url: {
+                  url: content.image_url.url,
+                },
+                type: "image_url",
+              };
+              return imageContent;
+            } else {
+              const textContent: ChatCompletionContentPartText = {
+                text: content.text,
+                type: "text",
+              };
+              return textContent;
+            }
+          });
+
+          if (message.role === "system") {
+            const formattedMessage: ChatCompletionSystemMessageParam = {
+              ...message,
+              role: "system",
+              content: contentParts
+                .map((c) => (c.type === "text" ? c.text : ""))
+                .join("\n"),
+            };
+            return formattedMessage;
+          } else if (message.role === "user") {
+            const formattedMessage: ChatCompletionUserMessageParam = {
+              ...message,
+              role: "user",
+              content: contentParts,
+            };
+            return formattedMessage;
+          } else {
+            const formattedMessage: ChatCompletionAssistantMessageParam = {
+              ...message,
+              role: "assistant",
+              content: contentParts
+                .map((c) => (c.type === "text" ? c.text : ""))
+                .join("\n"),
+            };
+            return formattedMessage;
+          }
+        }
+
+        const formattedMessage: ChatCompletionUserMessageParam = {
+          role: "user",
+          content: message.content,
+        };
+
+        return formattedMessage;
+      });
+
+    const modelNameToUse = this.modelName.startsWith("deepseek/")
+      ? this.modelName.split("/")[1]
+      : this.modelName;
+
+    const body: ChatCompletionCreateParamsNonStreaming = {
+      ...deepseekOptions,
+      model: modelNameToUse,
+      messages: formattedMessages,
+      response_format: responseFormat,
+      stream: false,
+      tools: options.tools?.map((tool) => ({
+        function: {
+          name: tool.name,
+          description: tool.description,
+          parameters: tool.parameters,
+        },
+        type: "function",
+      })),
+    };
+
+    const response = await this.client.chat.completions.create(body);
+
+    logger({
+      category: "deepseek",
+      message: "response",
+      level: 2,
+      auxiliary: {
+        response: {
+          value: JSON.stringify(response),
+          type: "object",
+        },
+        requestId: {
+          value: requestId,
+          type: "string",
+        },
+      },
+    });
+
+    if (options.response_model) {
+      const extractedData = response.choices[0]?.message.content;
+
+      if (extractedData === null) {
+        const errorMessage = "Response content is null.";
+        logger({
+          category: "deepseek",
+          message: errorMessage,
+          level: 0,
+        });
+        if (retries > 0) {
+          return this.createChatCompletion({
+            options: options as ChatCompletionOptions,
+            logger,
+            retries: retries - 1,
+          });
+        }
+        throw new CreateChatCompletionResponseError(errorMessage);
+      }
+
+      const parsedData = JSON.parse(extractedData);
+
+      try {
+        validateZodSchema(options.response_model.schema, parsedData);
+      } catch (e) {
+        logger({
+          category: "deepseek",
+          message: "Response failed Zod schema validation",
+          level: 0,
+        });
+        if (retries > 0) {
+          return this.createChatCompletion({
+            options: options as ChatCompletionOptions,
+            logger,
+            retries: retries - 1,
+          });
+        }
+
+        if (e instanceof ZodSchemaValidationError) {
+          logger({
+            category: "deepseek",
+            message: `Error during Deepseek chat completion: ${e.message}`,
+            level: 0,
+            auxiliary: {
+              errorDetails: {
+                value: `Message: ${e.message}${
+                  e.stack ? "\nStack: " + e.stack : ""
+                }`,
"\nStack: " + e.stack : "" + }`, + type: "string", + }, + requestId: { value: requestId, type: "string" }, + }, + }); + throw new CreateChatCompletionResponseError(e.message); + } + throw e; + } + + return { + data: parsedData, + usage: response.usage, + } as T; + } + + return response as T; + } +} diff --git a/packages/core/lib/v3/llm/LLMProvider.ts b/packages/core/lib/v3/llm/LLMProvider.ts index 7c16f2118..38ac21d05 100644 --- a/packages/core/lib/v3/llm/LLMProvider.ts +++ b/packages/core/lib/v3/llm/LLMProvider.ts @@ -16,6 +16,7 @@ import { GoogleClient } from "./GoogleClient"; import { GroqClient } from "./GroqClient"; import { LLMClient } from "./LLMClient"; import { OpenAIClient } from "./OpenAIClient"; +import { DeepseekAIClient } from "./DeepseekAIClient"; import { openai, createOpenAI } from "@ai-sdk/openai"; import { anthropic, createAnthropic } from "@ai-sdk/anthropic"; import { google, createGoogleGenerativeAI } from "@ai-sdk/google"; @@ -91,6 +92,8 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gemini-2.0-flash": "google", "gemini-2.5-flash-preview-04-17": "google", "gemini-2.5-pro-preview-03-25": "google", + "deepseek/deepseek-chat": "deepseek", + "deepseek/deepseek-reasoner": "deepseek", }; export function getAISDKLanguageModel( @@ -129,8 +132,9 @@ export function getAISDKLanguageModel( export class LLMProvider { private logger: (message: LogLine) => void; - - constructor(logger: (message: LogLine) => void) { + private manual?: boolean = false; + constructor(logger: (message: LogLine) => void, manual?: boolean) { + this.manual = manual; this.logger = logger; } @@ -138,7 +142,7 @@ export class LLMProvider { modelName: AvailableModel, clientOptions?: ClientOptions, ): LLMClient { - if (modelName.includes("/")) { + if (modelName.includes("/") && !this.manual) { const firstSlashIndex = modelName.indexOf("/"); const subProvider = modelName.substring(0, firstSlashIndex); const subModelName = modelName.substring(firstSlashIndex + 1); @@ -192,6 +196,12 @@ export class LLMProvider { modelName: availableModel, clientOptions, }); + case "deepseek": + return new DeepseekAIClient({ + logger: this.logger, + modelName: availableModel, + clientOptions, + }); default: throw new UnsupportedModelProviderError([ ...new Set(Object.values(modelToProviderMap)), @@ -199,8 +209,11 @@ export class LLMProvider { } } - static getModelProvider(modelName: AvailableModel): ModelProvider { - if (modelName.includes("/")) { + static getModelProvider( + modelName: AvailableModel, + manual: boolean = false, + ): ModelProvider { + if (modelName.includes("/") && !manual) { const firstSlashIndex = modelName.indexOf("/"); const subProvider = modelName.substring(0, firstSlashIndex); if (AISDKProviders[subProvider]) { diff --git a/packages/core/lib/v3/types/public/model.ts b/packages/core/lib/v3/types/public/model.ts index ea8aa57da..777fd5ba6 100644 --- a/packages/core/lib/v3/types/public/model.ts +++ b/packages/core/lib/v3/types/public/model.ts @@ -56,15 +56,18 @@ export type AvailableModel = | "gemini-2.0-flash" | "gemini-2.5-flash-preview-04-17" | "gemini-2.5-pro-preview-03-25" + | "deepseek/deepseek-chat" + | "deepseek/deepseek-reasoner" | string; export type ModelProvider = | "openai" | "anthropic" - | "cerebras" - | "groq" | "google" - | "aisdk"; + | "groq" + | "cerebras" + | "aisdk" + | "deepseek"; export type ClientOptions = OpenAIClientOptions | AnthropicClientOptions; diff --git a/packages/core/lib/v3/types/public/options.ts b/packages/core/lib/v3/types/public/options.ts index 
--- a/packages/core/lib/v3/types/public/options.ts
+++ b/packages/core/lib/v3/types/public/options.ts
@@ -86,4 +86,5 @@ export interface V3Options {
   cacheDir?: string;
   domSettleTimeout?: number;
   disableAPI?: boolean;
+  manual?: boolean;
 }
diff --git a/packages/core/lib/v3/v3.ts b/packages/core/lib/v3/v3.ts
index b7457f000..7d92714ee 100644
--- a/packages/core/lib/v3/v3.ts
+++ b/packages/core/lib/v3/v3.ts
@@ -136,7 +136,7 @@ export class V3 {
   }
   private _onCdpClosed = (why: string) => {
     // Single place to react to the transport closing
-    this._immediateShutdown(`CDP transport closed: ${why}`).catch(() => {});
+    this._immediateShutdown(`CDP transport closed: ${why}`).catch(() => { });
   };
   public readonly experimental: boolean = false;
   public readonly logInferenceToFile: boolean = false;
@@ -173,6 +173,7 @@ export class V3 {
 
   constructor(opts: V3Options) {
     V3._installProcessGuards();
+
     this.externalLogger = opts.logger;
     this.verbose = opts.verbose ?? 1;
     this.instanceId =
@@ -213,7 +214,7 @@
     this.modelName = modelName;
     this.experimental = opts.experimental ?? false;
     this.logInferenceToFile = opts.logInferenceToFile ?? false;
-    this.llmProvider = new LLMProvider(this.logger);
+    this.llmProvider = new LLMProvider(this.logger, this.opts?.manual ?? false);
     this.domSettleTimeoutMs = opts.domSettleTimeout;
     this.disableAPI = opts.disableAPI ?? false;
     const baseClientOptions: ClientOptions = clientOptions
@@ -611,7 +612,7 @@
       kind: "LOCAL",
       // no LaunchedChrome when attaching externally; create a stub kill
       chrome: {
-        kill: async () => {},
+        kill: async () => { },
       } as unknown as import("chrome-launcher").LaunchedChrome,
       ws: lbo.cdpUrl,
     };
@@ -840,7 +841,7 @@
           downloadPath: lbo.downloadsPath,
           eventsEnabled: true,
         })
-        .catch(() => {});
+        .catch(() => { });
     }
   } catch {
     // best-effort only
@@ -1418,7 +1419,7 @@
     if (!isCua) {
       throw new Error(
         "To use the computer use agent, please provide a CUA model in the agent constructor or stagehand config. Try one of our supported CUA models: " +
-          AVAILABLE_CUA_MODELS.join(", "),
+        AVAILABLE_CUA_MODELS.join(", "),
       );
     }
diff --git a/packages/core/lib/v3Evaluator.ts b/packages/core/lib/v3Evaluator.ts
index 149e09347..b79ef0319 100644
--- a/packages/core/lib/v3Evaluator.ts
+++ b/packages/core/lib/v3Evaluator.ts
@@ -29,16 +29,19 @@ const BatchEvaluationSchema = z.array(EvaluationSchema);
 
 export class V3Evaluator {
   private v3: V3;
+  private manual?: boolean;
   private modelName: AvailableModel;
   private modelClientOptions: ClientOptions | { apiKey: string };
   private silentLogger: (message: LogLine) => void = () => {};
 
   constructor(
     v3: V3,
+    manual?: boolean,
     modelName?: AvailableModel,
     modelClientOptions?: ClientOptions,
   ) {
     this.v3 = v3;
+    this.manual = manual;
     this.modelName = modelName || ("google/gemini-2.5-flash" as AvailableModel);
     this.modelClientOptions = modelClientOptions || {
       apiKey:
@@ -50,7 +53,7 @@
 
   private getClient(): LLMClient {
     // Prefer a dedicated provider so we can override model per-evaluation
-    const provider = new LLMProvider(this.v3.logger);
+    const provider = new LLMProvider(this.v3.logger, this.manual);
     return provider.getClient(this.modelName, this.modelClientOptions);
   }
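
Reviewer note (not part of the patch): a minimal sketch of what the new `manual` flag changes, assuming the import paths from this PR's tree and that a Deepseek key is supplied via the OpenAI-compatible `apiKey` client option; the exact published import path and option shape may differ.

import { LLMProvider } from "./packages/core/lib/v3/llm/LLMProvider"; // path per this PR's tree; published path may differ
import { LogLine } from "./packages/core/lib/v3/types/public/logs";

const logger = (line: LogLine) => console.log(line.message);

// manual=true skips the generic "provider/model" AI SDK routing in getClient,
// so "deepseek/deepseek-chat" falls through to modelToProviderMap and returns
// the new DeepseekAIClient added in this diff.
const provider = new LLMProvider(logger, /* manual */ true);
const client = provider.getClient("deepseek/deepseek-chat", {
  apiKey: process.env.DEEPSEEK_API_KEY, // assumption: key passed as an OpenAI-style ClientOptions field
});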