mirror of
https://github.com/eliasstepanik/core.git
synced 2026-01-10 08:48:29 +00:00
493 lines
14 KiB
TypeScript
493 lines
14 KiB
TypeScript
/* eslint-disable @typescript-eslint/no-explicit-any */
|
|
|
|
import { ActionStatusEnum } from "@core/types";
|
|
import { logger } from "@trigger.dev/sdk/v3";
|
|
import {
|
|
type CoreMessage,
|
|
type DataContent,
|
|
jsonSchema,
|
|
tool,
|
|
type ToolSet,
|
|
} from "ai";
|
|
import axios from "axios";
|
|
import Handlebars from "handlebars";
|
|
|
|
import { REACT_SYSTEM_PROMPT, REACT_USER_PROMPT } from "./prompt";
|
|
import { generate, processTag } from "./stream-utils";
|
|
import { type AgentMessage, AgentMessageType, Message } from "./types";
|
|
import { type MCP } from "../utils/mcp";
|
|
import {
|
|
type ExecutionState,
|
|
type HistoryStep,
|
|
type Resource,
|
|
type TotalCost,
|
|
} from "../utils/types";
|
|
import { flattenObject } from "../utils/utils";
|
|
|
|
/**
 * Result of a single LLM round-trip.
 *
 * `response` is a stream mixing two kinds of items:
 *  - plain strings: incremental text deltas of the model's reply;
 *  - objects: structured events carrying a `type` and a `toolName`
 *    (the agent loop looks for `type === "tool-call"`).
 */
interface LLMOutputInterface {
  response: AsyncGenerator<
    | string
    | {
        /** Event kind, e.g. "tool-call". */
        type: string;
        /** Fully-qualified tool name in `agent--tool` form. */
        toolName: string;
        /** Identifier pairing this call with its later result. */
        toolCallId?: string;
        /** Arguments the model supplied for the tool. */
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        args?: any;
        message?: string;
      },
    any,
    any
  >;
}
|
|
|
|
/** Fully-qualified name under which the progress-update tool is registered. */
const PROGRESS_UPDATE_TOOL = "core--progress_update";

/**
 * Built-in tool the model can call to post a short, user-facing status note
 * while it works. No `execute` handler is attached here: the agent loop in
 * `run` recognises the call and streams `message` to the user itself.
 */
const progressUpdateTool = tool({
  description:
    "Send a progress update to the user about what has been discovered or will be done next in a crisp and user friendly way no technical terms",
  parameters: jsonSchema({
    type: "object",
    properties: {
      message: {
        type: "string",
        description: "The progress update message to send to the user",
      },
    },
    required: ["message"],
    additionalProperties: false,
  }),
});

/** Tool names (in `agent--tool` form) handled in-process rather than via MCP. */
const internalTools = [PROGRESS_UPDATE_TOOL];
|
|
|
|
/**
 * Download every attached resource and append them to the conversation as a
 * single multimodal user message. `image/*` resources become `image` parts, and
 * everything else becomes a `file` part tagged with its MIME type. Downloads run
 * concurrently.
 *
 * Each `publicURL` is cut so that it starts at its first "/api" occurrence
 * before the request is made.
 * NOTE(review): that yields a relative path, so this relies on axios being able
 * to resolve it in this environment — confirm.
 *
 * @param messages Conversation so far. Not modified; a new array is returned.
 * @param resources Attachments to fetch. Not modified.
 * @returns `messages` followed by one user message carrying all resources.
 */
async function addResources(messages: CoreMessage[], resources: Resource[]) {
  const content = await Promise.all(
    resources.map(async (resource) => {
      // Trim into a local instead of overwriting `resource.publicURL`.
      // The Resource objects belong to the caller and are re-read on every
      // agent iteration.
      let url = resource.publicURL;
      if (url) {
        const apiIndex = url.indexOf("/api");
        if (apiIndex !== -1) {
          url = url.substring(apiIndex);
        }
      }

      const response = await axios.get(url, {
        responseType: "arraybuffer",
      });
      const data = response.data as DataContent;

      return resource.fileType.startsWith("image/")
        ? { type: "image", image: data }
        : { type: "file", data, mimeType: resource.fileType };
    }),
  );

  return [...messages, { role: "user", content } as CoreMessage];
}
|
|
|
|
/**
 * Replay this interaction's steps into `messages` so the model can see its own
 * earlier actions. `messages` is mutated in place and also returned.
 *
 *  - A step with both an `observation` and a `skillId` becomes an assistant
 *    `tool-call` message followed by the matching `tool` result message.
 *  - A step with an `observation` but no `skillId` (format-correction steps)
 *    becomes a system message.
 *  - Steps without an `observation` are skipped.
 *
 * `skillInput` may be a JSON string, which is parsed back into the call
 * arguments (an empty string counts as no arguments, `{}`), or an already
 * parsed value, which is used as-is.
 */
function toolToMessage(history: HistoryStep[], messages: CoreMessage[]) {
  for (const step of history) {
    if (!step.observation) {
      continue;
    }

    if (!step.skillId) {
      messages.push({
        role: "system",
        content: step.observation,
      });
      continue;
    }

    // Both halves of the call/result pair must carry the same tool name.
    const toolName = step.skill ?? "";

    const rawInput: unknown = step.skillInput;
    // An empty string would make JSON.parse throw and abort the whole run.
    const args =
      typeof rawInput === "string"
        ? rawInput === ""
          ? {}
          : JSON.parse(rawInput)
        : rawInput;

    messages.push({
      role: "assistant",
      content: [
        {
          type: "tool-call",
          toolCallId: step.skillId,
          toolName,
          args,
        },
      ],
    });

    messages.push({
      role: "tool",
      content: [
        {
          type: "tool-result",
          toolName,
          toolCallId: step.skillId,
          result: step.observation,
          isError: step.isError,
        },
      ],
    } as any);
  }

  return messages;
}
|
|
|
|
// The prompt templates are module constants, so compile them once instead of on
// every loop iteration. `noEscape` stops Handlebars from HTML-escaping
// interpolated values (by default `{{x}}` turns `'` into `&#x27;` and `<` into
// `&lt;`). The output is an LLM prompt, not HTML, so escaping would corrupt the
// user's message, context and memory text.
const systemPromptTemplate = Handlebars.compile(REACT_SYSTEM_PROMPT, {
  noEscape: true,
});
const userPromptTemplate = Handlebars.compile(REACT_USER_PROMPT, {
  noEscape: true,
});

/**
 * Build the prompt for the next ReAct step and start streaming the model's reply.
 *
 * Messages are ordered as follows:
 *  1. the system prompt;
 *  2. earlier exchanges (`previousHistory`);
 *  3. the current user prompt;
 *  4. the steps already taken in this interaction;
 *  5. a final user message with any downloaded resources.
 *
 * @param executionState Current query, context, user memory and step history.
 * @param TOOLS Tools offered to the model.
 * @param totalCost Running token counter. The callback handed to `generate` adds
 *   each reported usage to it in place.
 * @param guardLoop Zero-based loop iteration. On every third iteration after the
 *   first, `true` is passed as `generate`'s second argument.
 *   NOTE(review): the flag's meaning is defined in stream-utils — confirm.
 * @returns The response stream, not yet consumed.
 */
async function makeNextCall(
  executionState: ExecutionState,
  TOOLS: ToolSet,
  totalCost: TotalCost,
  guardLoop: number,
): Promise<LLMOutputInterface> {
  const { context, history, previousHistory } = executionState;

  const promptInfo = {
    USER_MESSAGE: executionState.query,
    CONTEXT: context,
    USER_MEMORY: executionState.userMemoryContext,
  };

  let messages: CoreMessage[] = [
    // Always lead with the system instructions.
    { role: "system", content: systemPromptTemplate(promptInfo) },
    // Earlier exchanges, when this is a follow-up query.
    ...(previousHistory ?? []),
    { role: "user", content: userPromptTemplate(promptInfo) },
  ];

  // Replay the steps already taken in this interaction.
  if (history.length > 0) {
    messages = toolToMessage(history, messages);
  }

  if (executionState.resources && executionState.resources.length > 0) {
    messages = await addResources(messages, executionState.resources);
  }

  const response = generate(
    messages,
    guardLoop > 0 && guardLoop % 3 === 0,
    (event) => {
      const usage = event.usage;
      totalCost.inputTokens += usage.promptTokens;
      totalCost.outputTokens += usage.completionTokens;
    },
    TOOLS,
  );

  return { response };
}
|
|
|
|
/** Upper bound on LLM round-trips per run, so a looping model cannot spin forever. */
const MAX_AGENT_ITERATIONS = 50;

/** Yield one complete user-visible text message as START / CHUNK / END events. */
function* emitTextMessage(text: string) {
  yield Message("", AgentMessageType.MESSAGE_START);
  yield Message(text, AgentMessageType.MESSAGE_CHUNK);
  yield Message("", AgentMessageType.MESSAGE_END);
}

/** Fresh streaming state for one tagged section; `processTag` mutates it. */
function createTagState() {
  return { inTag: false, message: "", messageEnded: false, lastSent: "" };
}

/**
 * Serialise a thrown value into the text fed back to the model as a tool error.
 * `message`/`name` on Error are non-enumerable, so a bare JSON.stringify(error)
 * would produce "{}".
 */
function describeError(error: unknown): string {
  if (error instanceof Error) {
    return JSON.stringify({
      ...error,
      name: error.name,
      message: error.message,
    });
  }
  return JSON.stringify(error) ?? String(error);
}

/**
 * Run the ReAct agent loop for one user message, streaming progress as AgentMessages.
 *
 * Each iteration asks the LLM for its next move and streams any
 * `<final_response>` / `<question_response>` text to the caller. The iteration
 * then does one of three things:
 *  - stops on a final answer or a question, recorded as the last history step;
 *  - executes the requested tool calls one after another, recording each result;
 *  - records a format-correction step when the reply had neither a tag nor a
 *    tool call, then retries.
 * At most `MAX_AGENT_ITERATIONS` LLM calls are made. Once the loop has started,
 * the final message is always STREAM_END, preceded by an ERROR message if
 * something threw.
 *
 * @param message The user's query.
 * @param context Extra context. Its `resources` entry (if any) is sent as
 *   attachments; the remaining entries are flattened into prompt text.
 *   The object is not modified.
 * @param previousHistory Messages from earlier exchanges, replayed before the query.
 * @param mcp MCP client that lists external tools and executes their calls.
 * @param stepHistory Steps already taken in this interaction. New steps are
 *   appended to this array.
 */
export async function* run(
  message: string,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  context: Record<string, any>,
  previousHistory: CoreMessage[],
  mcp: MCP,
  stepHistory: HistoryStep[],
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
): AsyncGenerator<AgentMessage, any, any> {
  let guardLoop = 0;

  const tools = {
    ...(await mcp.allTools()),
    "core--progress_update": progressUpdateTool,
  };

  logger.info("Tools have been formed");

  // Split the attachments out of the context without mutating the caller's object.
  let contextText = "";
  let resources: Resource[] = [];
  if (context) {
    const { resources: contextResources, ...remainingContext } = context;
    resources = contextResources || [];
    contextText = flattenObject(remainingContext).join("\n");
  }

  const executionState: ExecutionState = {
    query: message,
    context: contextText,
    resources,
    previousHistory,
    history: stepHistory, // Track the full ReAct history
    completed: false,
  };

  const totalCost: TotalCost = { inputTokens: 0, outputTokens: 0, cost: 0 };

  try {
    while (!executionState.completed && guardLoop < MAX_AGENT_ITERATIONS) {
      logger.info(`Starting the loop: ${guardLoop}`);

      const { response: llmResponse } = await makeNextCall(
        executionState,
        tools,
        totalCost,
        guardLoop,
      );

      const messageState = createTagState();
      const questionState = createTagState();
      let totalMessage = "";
      const toolCalls = [];

      // Stream the model's reply, forwarding tagged sections as they arrive.
      for await (const chunk of llmResponse) {
        if (typeof chunk !== "string") {
          // Structured events carry no reply text. Collect the tool calls and
          // keep them out of `totalMessage` and the tag parsers.
          if (chunk.type === "tool-call") {
            toolCalls.push(chunk);
          }
          continue;
        }

        totalMessage += chunk;

        if (!messageState.messageEnded) {
          yield* processTag(
            messageState,
            totalMessage,
            chunk,
            "<final_response>",
            "</final_response>",
            {
              start: AgentMessageType.MESSAGE_START,
              chunk: AgentMessageType.MESSAGE_CHUNK,
              end: AgentMessageType.MESSAGE_END,
            },
          );
        }

        if (!questionState.messageEnded) {
          yield* processTag(
            questionState,
            totalMessage,
            chunk,
            "<question_response>",
            "</question_response>",
            {
              start: AgentMessageType.MESSAGE_START,
              chunk: AgentMessageType.MESSAGE_CHUNK,
              end: AgentMessageType.MESSAGE_END,
            },
          );
        }
      }

      logger.info(`Cost for thought: ${JSON.stringify(totalCost)}`);

      const hasFinalResponse = totalMessage.includes("final_response");
      const hasQuestionResponse = totalMessage.includes("question_response");

      // Self-correction: the reply had neither a tagged answer nor a tool call.
      // Record a correction step (it has an observation but no skillId) and
      // loop again instead of failing the run.
      if (!hasFinalResponse && !hasQuestionResponse && toolCalls.length === 0) {
        logger.info(
          `Invalid response format detected. Attempting to get proper format.`,
        );

        // Strip tags so the model sees only the raw content it produced.
        const rawContent = totalMessage
          .replace(/(<[^>]*>|<\/[^>]*>)/g, "")
          .trim();

        const stepRecord: HistoryStep = {
          thought: "",
          skill: "",
          skillId: "",
          userMessage: "Core agent error, retrying \n",
          isQuestion: false,
          isFinal: false,
          tokenCount: totalCost,
          skillInput: "",
          observation: `Your last response was not in a valid format. You must respond with EXACTLY ONE of the required formats: either a tool call, <question_response> tags, or <final_response> tags. Please reformat your previous response using the correct format:\n\n${rawContent}`,
        };

        yield* emitTextMessage(stepRecord.userMessage as string);

        yield Message(JSON.stringify(stepRecord), AgentMessageType.STEP);
        executionState.history.push(stepRecord);

        logger.info(`Added format correction request to history.`);

        // Still counts toward the iteration limit.
        guardLoop++;
        continue;
      }

      // A final answer or a question ends the run. If both tags appear, the
      // final answer wins.
      if (hasFinalResponse || hasQuestionResponse) {
        const stepRecord: HistoryStep = {
          thought: "",
          skill: "",
          skillId: "",
          userMessage: hasFinalResponse
            ? messageState.message
            : questionState.message,
          isQuestion: !hasFinalResponse,
          isFinal: hasFinalResponse,
          tokenCount: totalCost,
          skillInput: "",
          finalTokenCount: totalCost,
          skillStatus: hasFinalResponse
            ? ActionStatusEnum.SUCCESS
            : ActionStatusEnum.QUESTION,
        };

        executionState.completed = true;
        yield Message(JSON.stringify(stepRecord), AgentMessageType.STEP);
        executionState.history.push(stepRecord);
        break;
      }

      // Execute the requested tools one at a time, in the order the model emitted them.
      for (const call of toolCalls) {
        const skillName = call.toolName;
        const skillId = call.toolCallId;
        const skillInput = call.args;

        const [agent, toolName] = skillName.split("--");
        // Internal tools are rendered as plain messages, not as skill widgets.
        const isInternal = internalTools.includes(skillName);

        const stepRecord: HistoryStep = {
          agent,
          thought: "",
          skill: skillName,
          skillId,
          userMessage: "",
          isQuestion: false,
          isFinal: false,
          tokenCount: totalCost,
          skillInput: JSON.stringify(skillInput),
        };

        if (!isInternal) {
          const skillMessageToSend = `\n<skill id="${skillId}" name="${toolName}" agent="${agent}"></skill>\n`;
          stepRecord.userMessage += skillMessageToSend;
          yield* emitTextMessage(skillMessageToSend);
        }

        let result;
        try {
          logger.info(`Executing skill: ${skillName}`);
          logger.info(`Input parameters: ${JSON.stringify(skillInput)}`);

          if (!isInternal) {
            yield Message(
              JSON.stringify({ skillId, status: "start" }),
              AgentMessageType.SKILL_START,
            );
          }

          if (agent === "core") {
            // CORE tools are handled in-process.
            if (toolName === "progress_update") {
              yield* emitTextMessage(skillInput.message);
              stepRecord.userMessage += skillInput.message;
              result = "Progress update sent successfully";
            }
          } else {
            // Everything else goes through MCP.
            result = await mcp.callTool(skillName, skillInput);
            yield Message(
              JSON.stringify({ result, skillId }),
              AgentMessageType.SKILL_CHUNK,
            );
          }

          stepRecord.skillOutput =
            typeof result === "object"
              ? JSON.stringify(result, null, 2)
              : result;
          stepRecord.observation = stepRecord.skillOutput;
        } catch (e) {
          console.log(e);
          const errorText = describeError(e);
          logger.error(errorText);
          stepRecord.skillInput = skillInput;
          stepRecord.observation = errorText;
          stepRecord.isError = true;
        }

        // Close the skill on success and on failure alike, so a started skill
        // never stays open in the UI.
        if (!isInternal) {
          yield Message(
            JSON.stringify({ skillId, status: "end" }),
            AgentMessageType.SKILL_END,
          );
        }

        logger.info(`Skill step: ${JSON.stringify(stepRecord)}`);

        yield Message(JSON.stringify(stepRecord), AgentMessageType.STEP);
        executionState.history.push(stepRecord);
      }

      guardLoop++;
    }

    if (!executionState.completed) {
      logger.warn(
        `Agent stopped after ${guardLoop} iterations without a final or question response`,
      );
    }

    yield Message("Stream ended", AgentMessageType.STREAM_END);
  } catch (e) {
    const errorMessage = e instanceof Error ? e.message : String(e);
    logger.error(errorMessage);
    yield Message(errorMessage, AgentMessageType.ERROR);
    yield Message("Stream ended", AgentMessageType.STREAM_END);
  }
}
|