refactor: simplify clustered graph query and add stop conditions for AI responses

Manoj 2025-10-26 12:20:33 +05:30
parent 8836849310
commit 6a05ea4f37
4 changed files with 88 additions and 123 deletions

View File

@@ -112,51 +112,31 @@ export const getNodeLinks = async (userId: string) => {
 export const getClusteredGraphData = async (userId: string) => {
   const session = driver.session();
   try {
-    // Get the simplified graph structure: Episode, Subject, Object with Predicate as edge
-    // Only include entities that are connected to more than 1 episode
+    // Get Episode -> Entity graph, only showing entities connected to more than 1 episode
     const result = await session.run(
       `// Find entities connected to more than 1 episode
-      MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
-      MATCH (s)-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(ent:Entity)
-      WITH ent, count(DISTINCT e) as episodeCount
+      MATCH (e:Episode{userId: $userId})-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
+      WITH entity, count(DISTINCT e) as episodeCount
       WHERE episodeCount > 1
-      WITH collect(ent.uuid) as validEntityUuids
-      // Get statements where all entities are in the valid set
-      MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
-      MATCH (s)-[:HAS_SUBJECT]->(subj:Entity)
-      WHERE subj.uuid IN validEntityUuids
-      MATCH (s)-[:HAS_PREDICATE]->(pred:Entity)
-      WHERE pred.uuid IN validEntityUuids
-      MATCH (s)-[:HAS_OBJECT]->(obj:Entity)
-      WHERE obj.uuid IN validEntityUuids
-      // Build relationships
-      WITH e, s, subj, pred, obj
-      UNWIND [
-        // Episode -> Subject
-        {source: e, sourceType: 'Episode', target: subj, targetType: 'Entity', predicate: null},
-        // Episode -> Object
-        {source: e, sourceType: 'Episode', target: obj, targetType: 'Entity', predicate: null},
-        // Subject -> Object (with Predicate as edge)
-        {source: subj, sourceType: 'Entity', target: obj, targetType: 'Entity', predicate: pred.name}
-      ] AS rel
+      WITH collect(entity.uuid) as validEntityUuids
+      // Build Episode -> Entity relationships for valid entities
+      MATCH (e:Episode{userId: $userId})-[r:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
+      WHERE entity.uuid IN validEntityUuids
+      WITH DISTINCT e, entity, type(r) as relType,
+        CASE WHEN size(e.spaceIds) > 0 THEN e.spaceIds[0] ELSE null END as clusterId,
+        s.createdAt as createdAt
       RETURN DISTINCT
-        rel.source.uuid as sourceUuid,
-        rel.source.name as sourceName,
-        rel.source.content as sourceContent,
-        rel.sourceType as sourceNodeType,
-        rel.target.uuid as targetUuid,
-        rel.target.name as targetName,
-        rel.targetType as targetNodeType,
-        rel.predicate as predicateLabel,
-        e.uuid as episodeUuid,
-        e.content as episodeContent,
-        e.spaceIds as spaceIds,
-        s.uuid as statementUuid,
-        s.validAt as validAt,
-        s.createdAt as createdAt`,
+        e.uuid as sourceUuid,
+        e.content as sourceContent,
+        'Episode' as sourceNodeType,
+        entity.uuid as targetUuid,
+        entity.name as targetName,
+        'Entity' as targetNodeType,
+        relType as edgeType,
+        clusterId,
+        createdAt`,
       { userId },
     );
@@ -165,72 +145,29 @@ export const getClusteredGraphData = async (userId: string) => {
     result.records.forEach((record) => {
       const sourceUuid = record.get("sourceUuid");
-      const sourceName = record.get("sourceName");
       const sourceContent = record.get("sourceContent");
-      const sourceNodeType = record.get("sourceNodeType");
       const targetUuid = record.get("targetUuid");
       const targetName = record.get("targetName");
-      const targetNodeType = record.get("targetNodeType");
-      const predicateLabel = record.get("predicateLabel");
-      const episodeUuid = record.get("episodeUuid");
-      const clusterIds = record.get("spaceIds");
-      const clusterId = clusterIds ? clusterIds[0] : undefined;
+      const edgeType = record.get("edgeType");
+      const clusterId = record.get("clusterId");
       const createdAt = record.get("createdAt");

       // Create unique edge identifier to avoid duplicates
-      // For Episode->Subject edges, use generic type; for Subject->Object use predicate
-      const edgeType = predicateLabel || "HAS_SUBJECT";
       const edgeKey = `${sourceUuid}-${targetUuid}-${edgeType}`;
       if (processedEdges.has(edgeKey)) return;
       processedEdges.add(edgeKey);

-      // Build node attributes based on type
-      const sourceAttributes =
-        sourceNodeType === "Episode"
-          ? {
-              nodeType: "Episode",
-              content: sourceContent,
-              episodeUuid: sourceUuid,
-              clusterId,
-            }
-          : {
-              nodeType: "Entity",
-              name: sourceName,
-              clusterId,
-            };
-      const targetAttributes =
-        targetNodeType === "Episode"
-          ? {
-              nodeType: "Episode",
-              content: sourceContent,
-              episodeUuid: targetUuid,
-              clusterId,
-            }
-          : {
-              nodeType: "Entity",
-              name: targetName,
-              clusterId,
-            };
-      // Build display name
-      const sourceDisplayName =
-        sourceNodeType === "Episode"
-          ? sourceContent || episodeUuid
-          : sourceName || sourceUuid;
-      const targetDisplayName =
-        targetNodeType === "Episode"
-          ? sourceContent || episodeUuid
-          : targetName || targetUuid;
       triplets.push({
         sourceNode: {
           uuid: sourceUuid,
-          labels: [sourceNodeType],
-          attributes: sourceAttributes,
-          name: sourceDisplayName,
+          labels: ["Episode"],
+          attributes: {
+            nodeType: "Episode",
+            content: sourceContent,
+            episodeUuid: sourceUuid,
+            clusterId,
+          },
+          name: sourceContent || sourceUuid,
           clusterId,
           createdAt: createdAt || "",
         },
@@ -243,10 +180,14 @@ export const getClusteredGraphData = async (userId: string) => {
         },
         targetNode: {
           uuid: targetUuid,
-          labels: [targetNodeType],
-          attributes: targetAttributes,
+          labels: ["Entity"],
+          attributes: {
+            nodeType: "Entity",
+            name: targetName,
+            clusterId,
+          },
+          name: targetName || targetUuid,
           clusterId,
-          name: targetDisplayName,
           createdAt: createdAt || "",
         },
       });
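After this change every record returned by the query is a single Episode → Entity edge, and the forEach above maps each record onto one triplet. Roughly, the pushed objects look like this (a type sketch reconstructed from the new code; it is not a type exported by the repo, and the edge fields between sourceNode and targetNode fall outside the hunks shown):

```ts
// Sketch of one node inside the `triplets` entries built above (reconstructed,
// not an exported type). Sources are always Episode nodes, targets Entity nodes.
interface ClusteredTripletNode {
  uuid: string;
  labels: string[];                    // ["Episode"] for sources, ["Entity"] for targets
  attributes: Record<string, unknown>; // nodeType plus content/episodeUuid or name
  name: string;                        // episode content or entity name, falling back to uuid
  clusterId: string | null;            // first spaceId of the episode, or null
  createdAt: string;                   // statement createdAt, or "" when missing
}

interface ClusteredTriplet {
  sourceNode: ClusteredTripletNode;
  // edge: { ... }  (unchanged by this commit; not shown in the diff)
  targetNode: ClusteredTripletNode;
}
```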

View File

@@ -1,6 +1,19 @@
-import { tool } from "ai";
+import { StopCondition, tool } from "ai";
 import z from "zod";

+export const hasAnswer: StopCondition<any> = ({ steps }) => {
+  return (
+    steps.some((step) => step.text?.includes("</final_response>")) ?? false
+  );
+};
+
+export const hasQuestion: StopCondition<any> = ({ steps }) => {
+  return (
+    steps.some((step) => step.text?.includes("</question_response>")) ??
+    false
+  );
+};
+
 export const REACT_SYSTEM_PROMPT = `
 You are a helpful AI assistant with access to user memory. Your primary capabilities are:
@@ -128,18 +141,18 @@ PROGRESS UPDATES - During processing:
 - Avoid technical jargon

 QUESTIONS - When you need information:
-<div>
+<question_response>
 <p>[Your question with HTML formatting]</p>
-</div>
+</question_response>
 - Ask questions only when you cannot find information through memory, or tools
 - Be specific about what you need to know
 - Provide context for why you're asking

 FINAL ANSWERS - When completing tasks:
-<div>
+<final_response>
 <p>[Your answer with HTML formatting]</p>
-</div>
+</final_response>

 CRITICAL:
 - Use ONE format per turn
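The new `<question_response>` and `<final_response>` tags exist so the stop conditions declared at the top of this file can detect a finished turn. A minimal sketch of how the pieces plug into a route, reusing the helpers this commit imports elsewhere (the user message below is a placeholder):

```ts
import { stepCountIs, streamText, type CoreMessage, type LanguageModel } from "ai";
import { hasAnswer, hasQuestion, REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
import { getModel } from "~/lib/model.server";

const messages: CoreMessage[] = [
  { role: "system", content: REACT_SYSTEM_PROMPT },
  { role: "user", content: "What did I plan for this week?" }, // hypothetical input
];

// The multi-step tool loop ends on whichever comes first: ten steps, a closed
// </final_response> tag (hasAnswer), or a closed </question_response> tag (hasQuestion).
const result = streamText({
  model: getModel() as LanguageModel,
  messages,
  // the real routes also pass `tools` from the MCP client here
  stopWhen: [stepCountIs(10), hasAnswer, hasQuestion],
});
```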

View File

@@ -5,6 +5,8 @@ import {
   type LanguageModel,
   experimental_createMCPClient as createMCPClient,
   generateId,
+  stepCountIs,
+  StopCondition,
 } from "ai";
 import { z } from "zod";
 import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
@@ -19,7 +21,7 @@ import { getModel } from "~/lib/model.server";
 import { UserTypeEnum } from "@core/types";
 import { nanoid } from "nanoid";
 import { getOrCreatePersonalAccessToken } from "~/services/personalAccessToken.server";
-import { REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
+import { hasAnswer, hasQuestion, REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
 import { enqueueCreateConversationTitle } from "~/lib/queue-adapter.server";
 import { env } from "~/env.server";
@@ -84,7 +86,7 @@ const { loader, action } = createHybridActionApiRoute(
       await createConversationHistory(message, body.id, UserTypeEnum.User);
     }

-    const messages = conversationHistory.map((history) => {
+    const messages = conversationHistory.map((history: any) => {
       return {
         parts: [{ text: history.message, type: "text" }],
         role: "user",
@@ -94,8 +96,6 @@ const { loader, action } = createHybridActionApiRoute(
     const tools = { ...(await mcpClient.tools()) };

-    // console.log(tools);
-
     const finalMessages = [
       ...messages,
       {
@@ -109,6 +109,8 @@ const { loader, action } = createHybridActionApiRoute(
       messages: finalMessages,
     });

     const result = streamText({
       model: getModel() as LanguageModel,
       messages: [
@@ -119,6 +121,7 @@ const { loader, action } = createHybridActionApiRoute(
         ...convertToModelMessages(validatedMessages),
       ],
       tools,
+      stopWhen: [stepCountIs(10), hasAnswer, hasQuestion],
     });

     result.consumeStream(); // no await

View File

@@ -11,15 +11,17 @@ import {
 import {
   convertToModelMessages,
   type CoreMessage,
+  generateId,
   generateText,
   type LanguageModel,
+  stepCountIs,
   streamText,
   tool,
   validateUIMessages,
 } from "ai";
 import axios from "axios";
 import { logger } from "~/services/logger.service";
-import { getReActPrompt } from "~/lib/prompt.server";
+import { getReActPrompt, hasAnswer } from "~/lib/prompt.server";
 import { getModel } from "~/lib/model.server";

 const DeepSearchBodySchema = z.object({
@@ -39,7 +41,7 @@ export function createSearchMemoryTool(token: string) {
   return tool({
     description:
       "Search the user's memory for relevant facts and episodes. Use this tool multiple times with different queries to gather comprehensive context.",
-    parameters: z.object({
+    inputSchema: z.object({
       query: z
         .string()
         .describe(
@@ -50,21 +52,14 @@ export function createSearchMemoryTool(token: string) {
       try {
         const response = await axios.post(
           `${process.env.API_BASE_URL || "https://core.heysol.ai"}/api/v1/search`,
-          { query },
+          { query, structured: false },
           {
             headers: {
               Authorization: `Bearer ${token}`,
             },
           },
         );
-        const searchResult = response.data;
-        return {
-          facts: searchResult.facts || [],
-          episodes: searchResult.episodes || [],
-          summary: `Found ${searchResult.episodes?.length || 0} relevant memories`,
-        };
+        return response.data;
       } catch (error) {
         logger.error(`SearchMemory tool error: ${error}`);
         return {
@@ -115,14 +110,11 @@ const { action, loader } = createActionApiRoute(
         searchMemory: searchTool,
       };

       // Build initial messages with ReAct prompt
-      const initialMessages: CoreMessage[] = [
-        {
-          role: "system",
-          content: getReActPrompt(body.metadata, body.intentOverride),
-        },
+      const initialMessages = [
         {
           role: "user",
-          content: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.`,
+          parts: [{ type: "text", text: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.` }],
+          id: generateId(),
         },
       ];
@@ -134,7 +126,15 @@ const { action, loader } = createActionApiRoute(
       if (body.stream) {
         const result = streamText({
           model: getModel() as LanguageModel,
-          messages: convertToModelMessages(validatedMessages),
+          messages: [
+            {
+              role: "system",
+              content: getReActPrompt(body.metadata, body.intentOverride),
+            },
+            ...convertToModelMessages(validatedMessages),
+          ],
+          tools,
+          stopWhen: [stepCountIs(10), hasAnswer],
         });

         return result.toUIMessageStreamResponse({
@@ -143,7 +143,15 @@ const { action, loader } = createActionApiRoute(
       } else {
         const { text } = await generateText({
           model: getModel() as LanguageModel,
-          messages: convertToModelMessages(validatedMessages),
+          messages: [
+            {
+              role: "system",
+              content: getReActPrompt(body.metadata, body.intentOverride),
+            },
+            ...convertToModelMessages(validatedMessages),
+          ],
+          tools,
+          stopWhen: [stepCountIs(10), hasAnswer],
         });

         await deletePersonalAccessToken(pat?.id);
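Pulling the scattered hunks of this file together, the memory-search tool after the commit reduces to the sketch below. The `.describe()` text for `query` is a placeholder because the diff elides it, and the try/catch error handling around the request is omitted for brevity:

```ts
import { tool } from "ai";
import { z } from "zod";
import axios from "axios";

// Condensed view of createSearchMemoryTool after this commit: the AI SDK v5
// `inputSchema` key replaces `parameters`, and the raw search response is
// returned instead of a reshaped { facts, episodes, summary } object.
export function createSearchMemoryTool(token: string) {
  return tool({
    description:
      "Search the user's memory for relevant facts and episodes. Use this tool multiple times with different queries to gather comprehensive context.",
    inputSchema: z.object({
      query: z.string().describe("A focused search query"), // placeholder description
    }),
    execute: async ({ query }) => {
      const response = await axios.post(
        `${process.env.API_BASE_URL || "https://core.heysol.ai"}/api/v1/search`,
        { query, structured: false },
        { headers: { Authorization: `Bearer ${token}` } },
      );
      return response.data; // raw API payload; callers reshape as needed
    },
  });
}
```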