diff --git a/apps/webapp/app/trigger/deep-search/deep-search-utils.ts b/apps/webapp/app/trigger/deep-search/deep-search-utils.ts
index 5c9df33..8d36bfc 100644
--- a/apps/webapp/app/trigger/deep-search/deep-search-utils.ts
+++ b/apps/webapp/app/trigger/deep-search/deep-search-utils.ts
@@ -1,6 +1,7 @@
 import { type CoreMessage } from "ai";
 import { logger } from "@trigger.dev/sdk/v3";
-import { generate, processTag } from "../chat/stream-utils";
+import { generate } from "./stream-utils";
+import { processTag } from "../chat/stream-utils";
 import { type AgentMessage, AgentMessageType, Message } from "../chat/types";
 import { TotalCost } from "../utils/types";
@@ -36,7 +37,7 @@ export async function* run(
     logger.info(`ReAct loop iteration ${guardLoop}, searches: ${searchCount}`);
 
     // Call LLM with current message history
-    const response = generate(messages, false, (event)=>{const usage = event.usage;
+    const response = generate(messages, (event)=>{const usage = event.usage;
       totalCost.inputTokens += usage.promptTokens;
       totalCost.outputTokens += usage.completionTokens;
     }, tools);
@@ -95,34 +96,31 @@ export async function* run(
       }
     }
 
-    // Execute tool calls sequentially
+    // Execute tool calls in parallel for better performance
     if (toolCalls.length > 0) {
+      // Notify about all searches starting
       for (const toolCall of toolCalls) {
-        // Add assistant message with tool call
-        messages.push({
-          role: "assistant",
-          content: [
-            {
-              type: "tool-call",
-              toolCallId: toolCall.toolCallId,
-              toolName: toolCall.toolName,
-              args: toolCall.args,
-            },
-          ],
-        });
-
-        // Execute the search tool
         logger.info(`Executing search: ${JSON.stringify(toolCall.args)}`);
-
-        // Notify about search starting
         yield Message("", AgentMessageType.SKILL_START);
         yield Message(
           `\nSearching memory: "${toolCall.args.query}"...\n`,
           AgentMessageType.SKILL_CHUNK
         );
         yield Message("", AgentMessageType.SKILL_END);
+      }
 
-      const result = await searchTool.execute(toolCall.args);
+      // Execute all searches in parallel
+      const searchPromises = toolCalls.map((toolCall) =>
+        searchTool.execute(toolCall.args).then((result: any) => ({
+          toolCall,
+          result,
+        }))
+      );
+
+      const searchResults = await Promise.all(searchPromises);
+
+      // Process results and add to message history
+      for (const { toolCall, result } of searchResults) {
         searchCount++;
 
         // Deduplicate episodes - track unique IDs
@@ -141,6 +139,18 @@ export async function* run(
         const episodesInThisSearch = result.episodes?.length || 0;
         totalEpisodesFound = seenEpisodeIds.size; // Use unique count
 
+        messages.push({
+          role: "assistant",
+          content: [
+            {
+              type: "tool-call",
+              toolCallId: toolCall.toolCallId,
+              toolName: toolCall.toolName,
+              args: toolCall.args,
+            },
+          ],
+        });
+
         // Add tool result to message history
         messages.push({
           role: "tool",
diff --git a/apps/webapp/app/trigger/deep-search/stream-utils.ts b/apps/webapp/app/trigger/deep-search/stream-utils.ts
new file mode 100644
index 0000000..11910c6
--- /dev/null
+++ b/apps/webapp/app/trigger/deep-search/stream-utils.ts
@@ -0,0 +1,68 @@
+import { openai } from "@ai-sdk/openai";
+import { logger } from "@trigger.dev/sdk/v3";
+import {
+  type CoreMessage,
+  type LanguageModelV1,
+  streamText,
+  type ToolSet,
+} from "ai";
+
+/**
+ * Generate LLM responses with tool calling support
+ * Simplified version for deep-search use case - NO maxSteps for manual ReAct control
+ */
+export async function* generate(
+  messages: CoreMessage[],
+  onFinish?: (event: any) => void,
+  tools?: ToolSet,
+  model?: string,
+): AsyncGenerator<
+  | string
+  | {
+      type: string;
+      toolName: string;
+      args?: any;
+      toolCallId?: string;
+    }
+> {
+  const modelToUse = model || process.env.MODEL || "gpt-4.1-2025-04-14";
+  const modelInstance = openai(modelToUse) as LanguageModelV1;
+
+  logger.info(`Starting LLM generation with model: ${modelToUse}`);
+
+  try {
+    const { textStream, fullStream } = streamText({
+      model: modelInstance,
+      messages,
+      temperature: 1,
+      tools,
+      // NO maxSteps - we handle tool execution manually in the ReAct loop
+      toolCallStreaming: true,
+      onFinish,
+    });
+
+    // Yield text chunks
+    for await (const chunk of textStream) {
+      yield chunk;
+    }
+
+    // Yield tool calls
+    for await (const fullChunk of fullStream) {
+      if (fullChunk.type === "tool-call") {
+        yield {
+          type: "tool-call",
+          toolName: fullChunk.toolName,
+          toolCallId: fullChunk.toolCallId,
+          args: fullChunk.args,
+        };
+      }
+
+      if (fullChunk.type === "error") {
+        logger.error(`LLM error: ${JSON.stringify(fullChunk)}`);
+      }
+    }
+  } catch (error) {
+    logger.error(`LLM generation error: ${error}`);
+    throw error;
+  }
+}
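
Reviewer note: a minimal sketch (not part of the diff) of how the new `generate` generator is consumed, since its `string | tool-call` union return type is what drives the manual ReAct loop in deep-search-utils.ts — string chunks are streamed answer text, object chunks are tool calls the loop collects and then runs in one parallel `Promise.all` batch. The `messages` and `tools` values here are placeholders, not taken from the PR.

```ts
import { type CoreMessage, type ToolSet } from "ai";
import { generate } from "./stream-utils";

// Placeholder inputs for illustration; the real values are built by the
// ReAct loop in deep-search-utils.ts.
const messages: CoreMessage[] = [
  { role: "user", content: "What did we decide about caching last week?" },
];
const tools: ToolSet = {}; // the deep-search searchTool would be registered here

async function consumeGenerate() {
  const toolCalls: Array<{ toolName: string; toolCallId?: string; args?: any }> = [];

  for await (const chunk of generate(messages, undefined, tools)) {
    if (typeof chunk === "string") {
      // Streamed answer text goes straight to the caller.
      process.stdout.write(chunk);
    } else if (chunk.type === "tool-call") {
      // Tool calls are collected, not executed here, so the loop can
      // dispatch them together via Promise.all as this PR does.
      toolCalls.push(chunk);
    }
  }
}
```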