mirror of
https://github.com/eliasstepanik/core.git
synced 2026-01-22 09:18:27 +00:00
refactor: simplify clustered graph query and add stop conditions for AI responses
This commit is contained in:
parent
8836849310
commit
6a05ea4f37
@ -112,51 +112,31 @@ export const getNodeLinks = async (userId: string) => {
|
|||||||
export const getClusteredGraphData = async (userId: string) => {
|
export const getClusteredGraphData = async (userId: string) => {
|
||||||
const session = driver.session();
|
const session = driver.session();
|
||||||
try {
|
try {
|
||||||
// Get the simplified graph structure: Episode, Subject, Object with Predicate as edge
|
// Get Episode -> Entity graph, only showing entities connected to more than 1 episode
|
||||||
// Only include entities that are connected to more than 1 episode
|
|
||||||
const result = await session.run(
|
const result = await session.run(
|
||||||
`// Find entities connected to more than 1 episode
|
`// Find entities connected to more than 1 episode
|
||||||
MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
|
MATCH (e:Episode{userId: $userId})-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
|
||||||
MATCH (s)-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(ent:Entity)
|
WITH entity, count(DISTINCT e) as episodeCount
|
||||||
WITH ent, count(DISTINCT e) as episodeCount
|
|
||||||
WHERE episodeCount > 1
|
WHERE episodeCount > 1
|
||||||
WITH collect(ent.uuid) as validEntityUuids
|
WITH collect(entity.uuid) as validEntityUuids
|
||||||
|
|
||||||
// Get statements where all entities are in the valid set
|
// Build Episode -> Entity relationships for valid entities
|
||||||
MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
|
MATCH (e:Episode{userId: $userId})-[r:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
|
||||||
MATCH (s)-[:HAS_SUBJECT]->(subj:Entity)
|
WHERE entity.uuid IN validEntityUuids
|
||||||
WHERE subj.uuid IN validEntityUuids
|
WITH DISTINCT e, entity, type(r) as relType,
|
||||||
MATCH (s)-[:HAS_PREDICATE]->(pred:Entity)
|
CASE WHEN size(e.spaceIds) > 0 THEN e.spaceIds[0] ELSE null END as clusterId,
|
||||||
WHERE pred.uuid IN validEntityUuids
|
s.createdAt as createdAt
|
||||||
MATCH (s)-[:HAS_OBJECT]->(obj:Entity)
|
|
||||||
WHERE obj.uuid IN validEntityUuids
|
|
||||||
|
|
||||||
// Build relationships
|
|
||||||
WITH e, s, subj, pred, obj
|
|
||||||
UNWIND [
|
|
||||||
// Episode -> Subject
|
|
||||||
{source: e, sourceType: 'Episode', target: subj, targetType: 'Entity', predicate: null},
|
|
||||||
// Episode -> Object
|
|
||||||
{source: e, sourceType: 'Episode', target: obj, targetType: 'Entity', predicate: null},
|
|
||||||
// Subject -> Object (with Predicate as edge)
|
|
||||||
{source: subj, sourceType: 'Entity', target: obj, targetType: 'Entity', predicate: pred.name}
|
|
||||||
] AS rel
|
|
||||||
|
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
rel.source.uuid as sourceUuid,
|
e.uuid as sourceUuid,
|
||||||
rel.source.name as sourceName,
|
e.content as sourceContent,
|
||||||
rel.source.content as sourceContent,
|
'Episode' as sourceNodeType,
|
||||||
rel.sourceType as sourceNodeType,
|
entity.uuid as targetUuid,
|
||||||
rel.target.uuid as targetUuid,
|
entity.name as targetName,
|
||||||
rel.target.name as targetName,
|
'Entity' as targetNodeType,
|
||||||
rel.targetType as targetNodeType,
|
relType as edgeType,
|
||||||
rel.predicate as predicateLabel,
|
clusterId,
|
||||||
e.uuid as episodeUuid,
|
createdAt`,
|
||||||
e.content as episodeContent,
|
|
||||||
e.spaceIds as spaceIds,
|
|
||||||
s.uuid as statementUuid,
|
|
||||||
s.validAt as validAt,
|
|
||||||
s.createdAt as createdAt`,
|
|
||||||
{ userId },
|
{ userId },
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -165,72 +145,29 @@ export const getClusteredGraphData = async (userId: string) => {
|
|||||||
|
|
||||||
result.records.forEach((record) => {
|
result.records.forEach((record) => {
|
||||||
const sourceUuid = record.get("sourceUuid");
|
const sourceUuid = record.get("sourceUuid");
|
||||||
const sourceName = record.get("sourceName");
|
|
||||||
const sourceContent = record.get("sourceContent");
|
const sourceContent = record.get("sourceContent");
|
||||||
const sourceNodeType = record.get("sourceNodeType");
|
|
||||||
|
|
||||||
const targetUuid = record.get("targetUuid");
|
const targetUuid = record.get("targetUuid");
|
||||||
const targetName = record.get("targetName");
|
const targetName = record.get("targetName");
|
||||||
const targetNodeType = record.get("targetNodeType");
|
const edgeType = record.get("edgeType");
|
||||||
|
const clusterId = record.get("clusterId");
|
||||||
const predicateLabel = record.get("predicateLabel");
|
|
||||||
const episodeUuid = record.get("episodeUuid");
|
|
||||||
const clusterIds = record.get("spaceIds");
|
|
||||||
const clusterId = clusterIds ? clusterIds[0] : undefined;
|
|
||||||
const createdAt = record.get("createdAt");
|
const createdAt = record.get("createdAt");
|
||||||
|
|
||||||
// Create unique edge identifier to avoid duplicates
|
// Create unique edge identifier to avoid duplicates
|
||||||
// For Episode->Subject edges, use generic type; for Subject->Object use predicate
|
|
||||||
const edgeType = predicateLabel || "HAS_SUBJECT";
|
|
||||||
const edgeKey = `${sourceUuid}-${targetUuid}-${edgeType}`;
|
const edgeKey = `${sourceUuid}-${targetUuid}-${edgeType}`;
|
||||||
if (processedEdges.has(edgeKey)) return;
|
if (processedEdges.has(edgeKey)) return;
|
||||||
processedEdges.add(edgeKey);
|
processedEdges.add(edgeKey);
|
||||||
|
|
||||||
// Build node attributes based on type
|
|
||||||
const sourceAttributes =
|
|
||||||
sourceNodeType === "Episode"
|
|
||||||
? {
|
|
||||||
nodeType: "Episode",
|
|
||||||
content: sourceContent,
|
|
||||||
episodeUuid: sourceUuid,
|
|
||||||
clusterId,
|
|
||||||
}
|
|
||||||
: {
|
|
||||||
nodeType: "Entity",
|
|
||||||
name: sourceName,
|
|
||||||
clusterId,
|
|
||||||
};
|
|
||||||
|
|
||||||
const targetAttributes =
|
|
||||||
targetNodeType === "Episode"
|
|
||||||
? {
|
|
||||||
nodeType: "Episode",
|
|
||||||
content: sourceContent,
|
|
||||||
episodeUuid: targetUuid,
|
|
||||||
clusterId,
|
|
||||||
}
|
|
||||||
: {
|
|
||||||
nodeType: "Entity",
|
|
||||||
name: targetName,
|
|
||||||
clusterId,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Build display name
|
|
||||||
const sourceDisplayName =
|
|
||||||
sourceNodeType === "Episode"
|
|
||||||
? sourceContent || episodeUuid
|
|
||||||
: sourceName || sourceUuid;
|
|
||||||
const targetDisplayName =
|
|
||||||
targetNodeType === "Episode"
|
|
||||||
? sourceContent || episodeUuid
|
|
||||||
: targetName || targetUuid;
|
|
||||||
|
|
||||||
triplets.push({
|
triplets.push({
|
||||||
sourceNode: {
|
sourceNode: {
|
||||||
uuid: sourceUuid,
|
uuid: sourceUuid,
|
||||||
labels: [sourceNodeType],
|
labels: ["Episode"],
|
||||||
attributes: sourceAttributes,
|
attributes: {
|
||||||
name: sourceDisplayName,
|
nodeType: "Episode",
|
||||||
|
content: sourceContent,
|
||||||
|
episodeUuid: sourceUuid,
|
||||||
|
clusterId,
|
||||||
|
},
|
||||||
|
name: sourceContent || sourceUuid,
|
||||||
clusterId,
|
clusterId,
|
||||||
createdAt: createdAt || "",
|
createdAt: createdAt || "",
|
||||||
},
|
},
|
||||||
@ -243,10 +180,14 @@ export const getClusteredGraphData = async (userId: string) => {
|
|||||||
},
|
},
|
||||||
targetNode: {
|
targetNode: {
|
||||||
uuid: targetUuid,
|
uuid: targetUuid,
|
||||||
labels: [targetNodeType],
|
labels: ["Entity"],
|
||||||
attributes: targetAttributes,
|
attributes: {
|
||||||
|
nodeType: "Entity",
|
||||||
|
name: targetName,
|
||||||
|
clusterId,
|
||||||
|
},
|
||||||
|
name: targetName || targetUuid,
|
||||||
clusterId,
|
clusterId,
|
||||||
name: targetDisplayName,
|
|
||||||
createdAt: createdAt || "",
|
createdAt: createdAt || "",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
@ -1,6 +1,19 @@
|
|||||||
import { tool } from "ai";
|
import { StopCondition, tool } from "ai";
|
||||||
import z from "zod";
|
import z from "zod";
|
||||||
|
|
||||||
|
export const hasAnswer: StopCondition<any> = ({ steps }) => {
|
||||||
|
return (
|
||||||
|
steps.some((step) => step.text?.includes("</final_response>")) ?? false
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const hasQuestion: StopCondition<any> = ({ steps }) => {
|
||||||
|
return (
|
||||||
|
steps.some((step) => step.text?.includes("</question_response>")) ??
|
||||||
|
false
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export const REACT_SYSTEM_PROMPT = `
|
export const REACT_SYSTEM_PROMPT = `
|
||||||
You are a helpful AI assistant with access to user memory. Your primary capabilities are:
|
You are a helpful AI assistant with access to user memory. Your primary capabilities are:
|
||||||
|
|
||||||
@ -128,18 +141,18 @@ PROGRESS UPDATES - During processing:
|
|||||||
- Avoid technical jargon
|
- Avoid technical jargon
|
||||||
|
|
||||||
QUESTIONS - When you need information:
|
QUESTIONS - When you need information:
|
||||||
<div>
|
<question_response>
|
||||||
<p>[Your question with HTML formatting]</p>
|
<p>[Your question with HTML formatting]</p>
|
||||||
</div>
|
</question_response>
|
||||||
|
|
||||||
- Ask questions only when you cannot find information through memory, or tools
|
- Ask questions only when you cannot find information through memory, or tools
|
||||||
- Be specific about what you need to know
|
- Be specific about what you need to know
|
||||||
- Provide context for why you're asking
|
- Provide context for why you're asking
|
||||||
|
|
||||||
FINAL ANSWERS - When completing tasks:
|
FINAL ANSWERS - When completing tasks:
|
||||||
<div>
|
<final_response>
|
||||||
<p>[Your answer with HTML formatting]</p>
|
<p>[Your answer with HTML formatting]</p>
|
||||||
</div>
|
</final_response>
|
||||||
|
|
||||||
CRITICAL:
|
CRITICAL:
|
||||||
- Use ONE format per turn
|
- Use ONE format per turn
|
||||||
|
|||||||
@ -5,6 +5,8 @@ import {
|
|||||||
type LanguageModel,
|
type LanguageModel,
|
||||||
experimental_createMCPClient as createMCPClient,
|
experimental_createMCPClient as createMCPClient,
|
||||||
generateId,
|
generateId,
|
||||||
|
stepCountIs,
|
||||||
|
StopCondition,
|
||||||
} from "ai";
|
} from "ai";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
||||||
@ -19,7 +21,7 @@ import { getModel } from "~/lib/model.server";
|
|||||||
import { UserTypeEnum } from "@core/types";
|
import { UserTypeEnum } from "@core/types";
|
||||||
import { nanoid } from "nanoid";
|
import { nanoid } from "nanoid";
|
||||||
import { getOrCreatePersonalAccessToken } from "~/services/personalAccessToken.server";
|
import { getOrCreatePersonalAccessToken } from "~/services/personalAccessToken.server";
|
||||||
import { REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
|
import { hasAnswer, hasQuestion, REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
|
||||||
import { enqueueCreateConversationTitle } from "~/lib/queue-adapter.server";
|
import { enqueueCreateConversationTitle } from "~/lib/queue-adapter.server";
|
||||||
import { env } from "~/env.server";
|
import { env } from "~/env.server";
|
||||||
|
|
||||||
@ -84,7 +86,7 @@ const { loader, action } = createHybridActionApiRoute(
|
|||||||
await createConversationHistory(message, body.id, UserTypeEnum.User);
|
await createConversationHistory(message, body.id, UserTypeEnum.User);
|
||||||
}
|
}
|
||||||
|
|
||||||
const messages = conversationHistory.map((history) => {
|
const messages = conversationHistory.map((history: any) => {
|
||||||
return {
|
return {
|
||||||
parts: [{ text: history.message, type: "text" }],
|
parts: [{ text: history.message, type: "text" }],
|
||||||
role: "user",
|
role: "user",
|
||||||
@ -94,8 +96,6 @@ const { loader, action } = createHybridActionApiRoute(
|
|||||||
|
|
||||||
const tools = { ...(await mcpClient.tools()) };
|
const tools = { ...(await mcpClient.tools()) };
|
||||||
|
|
||||||
// console.log(tools);
|
|
||||||
|
|
||||||
const finalMessages = [
|
const finalMessages = [
|
||||||
...messages,
|
...messages,
|
||||||
{
|
{
|
||||||
@ -109,6 +109,8 @@ const { loader, action } = createHybridActionApiRoute(
|
|||||||
messages: finalMessages,
|
messages: finalMessages,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const result = streamText({
|
const result = streamText({
|
||||||
model: getModel() as LanguageModel,
|
model: getModel() as LanguageModel,
|
||||||
messages: [
|
messages: [
|
||||||
@ -119,6 +121,7 @@ const { loader, action } = createHybridActionApiRoute(
|
|||||||
...convertToModelMessages(validatedMessages),
|
...convertToModelMessages(validatedMessages),
|
||||||
],
|
],
|
||||||
tools,
|
tools,
|
||||||
|
stopWhen: [stepCountIs(10), hasAnswer,hasQuestion],
|
||||||
});
|
});
|
||||||
|
|
||||||
result.consumeStream(); // no await
|
result.consumeStream(); // no await
|
||||||
|
|||||||
@ -11,15 +11,17 @@ import {
|
|||||||
import {
|
import {
|
||||||
convertToModelMessages,
|
convertToModelMessages,
|
||||||
type CoreMessage,
|
type CoreMessage,
|
||||||
|
generateId,
|
||||||
generateText,
|
generateText,
|
||||||
type LanguageModel,
|
type LanguageModel,
|
||||||
|
stepCountIs,
|
||||||
streamText,
|
streamText,
|
||||||
tool,
|
tool,
|
||||||
validateUIMessages,
|
validateUIMessages,
|
||||||
} from "ai";
|
} from "ai";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
import { logger } from "~/services/logger.service";
|
import { logger } from "~/services/logger.service";
|
||||||
import { getReActPrompt } from "~/lib/prompt.server";
|
import { getReActPrompt, hasAnswer } from "~/lib/prompt.server";
|
||||||
import { getModel } from "~/lib/model.server";
|
import { getModel } from "~/lib/model.server";
|
||||||
|
|
||||||
const DeepSearchBodySchema = z.object({
|
const DeepSearchBodySchema = z.object({
|
||||||
@ -39,7 +41,7 @@ export function createSearchMemoryTool(token: string) {
|
|||||||
return tool({
|
return tool({
|
||||||
description:
|
description:
|
||||||
"Search the user's memory for relevant facts and episodes. Use this tool multiple times with different queries to gather comprehensive context.",
|
"Search the user's memory for relevant facts and episodes. Use this tool multiple times with different queries to gather comprehensive context.",
|
||||||
parameters: z.object({
|
inputSchema: z.object({
|
||||||
query: z
|
query: z
|
||||||
.string()
|
.string()
|
||||||
.describe(
|
.describe(
|
||||||
@ -50,21 +52,14 @@ export function createSearchMemoryTool(token: string) {
|
|||||||
try {
|
try {
|
||||||
const response = await axios.post(
|
const response = await axios.post(
|
||||||
`${process.env.API_BASE_URL || "https://core.heysol.ai"}/api/v1/search`,
|
`${process.env.API_BASE_URL || "https://core.heysol.ai"}/api/v1/search`,
|
||||||
{ query },
|
{ query, structured: false },
|
||||||
{
|
{
|
||||||
headers: {
|
headers: {
|
||||||
Authorization: `Bearer ${token}`,
|
Authorization: `Bearer ${token}`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
return response.data;
|
||||||
const searchResult = response.data;
|
|
||||||
|
|
||||||
return {
|
|
||||||
facts: searchResult.facts || [],
|
|
||||||
episodes: searchResult.episodes || [],
|
|
||||||
summary: `Found ${searchResult.episodes?.length || 0} relevant memories`,
|
|
||||||
};
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`SearchMemory tool error: ${error}`);
|
logger.error(`SearchMemory tool error: ${error}`);
|
||||||
return {
|
return {
|
||||||
@ -115,14 +110,11 @@ const { action, loader } = createActionApiRoute(
|
|||||||
searchMemory: searchTool,
|
searchMemory: searchTool,
|
||||||
};
|
};
|
||||||
// Build initial messages with ReAct prompt
|
// Build initial messages with ReAct prompt
|
||||||
const initialMessages: CoreMessage[] = [
|
const initialMessages = [
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: getReActPrompt(body.metadata, body.intentOverride),
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.`,
|
parts: [{ type: "text", text: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.` }],
|
||||||
|
id: generateId(),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -134,7 +126,15 @@ const { action, loader } = createActionApiRoute(
|
|||||||
if (body.stream) {
|
if (body.stream) {
|
||||||
const result = streamText({
|
const result = streamText({
|
||||||
model: getModel() as LanguageModel,
|
model: getModel() as LanguageModel,
|
||||||
messages: convertToModelMessages(validatedMessages),
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: getReActPrompt(body.metadata, body.intentOverride),
|
||||||
|
},
|
||||||
|
...convertToModelMessages(validatedMessages),
|
||||||
|
],
|
||||||
|
tools,
|
||||||
|
stopWhen: [stepCountIs(10), hasAnswer],
|
||||||
});
|
});
|
||||||
|
|
||||||
return result.toUIMessageStreamResponse({
|
return result.toUIMessageStreamResponse({
|
||||||
@ -143,7 +143,15 @@ const { action, loader } = createActionApiRoute(
|
|||||||
} else {
|
} else {
|
||||||
const { text } = await generateText({
|
const { text } = await generateText({
|
||||||
model: getModel() as LanguageModel,
|
model: getModel() as LanguageModel,
|
||||||
messages: convertToModelMessages(validatedMessages),
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: getReActPrompt(body.metadata, body.intentOverride),
|
||||||
|
},
|
||||||
|
...convertToModelMessages(validatedMessages),
|
||||||
|
],
|
||||||
|
tools,
|
||||||
|
stopWhen: [stepCountIs(10), hasAnswer],
|
||||||
});
|
});
|
||||||
|
|
||||||
await deletePersonalAccessToken(pat?.id);
|
await deletePersonalAccessToken(pat?.id);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user