refactor: simplify clustered graph query and add stop conditions for AI responses

This commit is contained in:
Manoj 2025-10-26 12:20:33 +05:30
parent 8836849310
commit 6a05ea4f37
4 changed files with 88 additions and 123 deletions

View File

@ -112,51 +112,31 @@ export const getNodeLinks = async (userId: string) => {
export const getClusteredGraphData = async (userId: string) => {
const session = driver.session();
try {
// Get the simplified graph structure: Episode, Subject, Object with Predicate as edge
// Only include entities that are connected to more than 1 episode
// Get Episode -> Entity graph, only showing entities connected to more than 1 episode
const result = await session.run(
`// Find entities connected to more than 1 episode
MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
MATCH (s)-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(ent:Entity)
WITH ent, count(DISTINCT e) as episodeCount
MATCH (e:Episode{userId: $userId})-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
WITH entity, count(DISTINCT e) as episodeCount
WHERE episodeCount > 1
WITH collect(ent.uuid) as validEntityUuids
WITH collect(entity.uuid) as validEntityUuids
// Get statements where all entities are in the valid set
MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
MATCH (s)-[:HAS_SUBJECT]->(subj:Entity)
WHERE subj.uuid IN validEntityUuids
MATCH (s)-[:HAS_PREDICATE]->(pred:Entity)
WHERE pred.uuid IN validEntityUuids
MATCH (s)-[:HAS_OBJECT]->(obj:Entity)
WHERE obj.uuid IN validEntityUuids
// Build relationships
WITH e, s, subj, pred, obj
UNWIND [
// Episode -> Subject
{source: e, sourceType: 'Episode', target: subj, targetType: 'Entity', predicate: null},
// Episode -> Object
{source: e, sourceType: 'Episode', target: obj, targetType: 'Entity', predicate: null},
// Subject -> Object (with Predicate as edge)
{source: subj, sourceType: 'Entity', target: obj, targetType: 'Entity', predicate: pred.name}
] AS rel
// Build Episode -> Entity relationships for valid entities
MATCH (e:Episode{userId: $userId})-[r:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
WHERE entity.uuid IN validEntityUuids
WITH DISTINCT e, entity, type(r) as relType,
CASE WHEN size(e.spaceIds) > 0 THEN e.spaceIds[0] ELSE null END as clusterId,
s.createdAt as createdAt
RETURN DISTINCT
rel.source.uuid as sourceUuid,
rel.source.name as sourceName,
rel.source.content as sourceContent,
rel.sourceType as sourceNodeType,
rel.target.uuid as targetUuid,
rel.target.name as targetName,
rel.targetType as targetNodeType,
rel.predicate as predicateLabel,
e.uuid as episodeUuid,
e.content as episodeContent,
e.spaceIds as spaceIds,
s.uuid as statementUuid,
s.validAt as validAt,
s.createdAt as createdAt`,
e.uuid as sourceUuid,
e.content as sourceContent,
'Episode' as sourceNodeType,
entity.uuid as targetUuid,
entity.name as targetName,
'Entity' as targetNodeType,
relType as edgeType,
clusterId,
createdAt`,
{ userId },
);
@ -165,72 +145,29 @@ export const getClusteredGraphData = async (userId: string) => {
result.records.forEach((record) => {
const sourceUuid = record.get("sourceUuid");
const sourceName = record.get("sourceName");
const sourceContent = record.get("sourceContent");
const sourceNodeType = record.get("sourceNodeType");
const targetUuid = record.get("targetUuid");
const targetName = record.get("targetName");
const targetNodeType = record.get("targetNodeType");
const predicateLabel = record.get("predicateLabel");
const episodeUuid = record.get("episodeUuid");
const clusterIds = record.get("spaceIds");
const clusterId = clusterIds ? clusterIds[0] : undefined;
const edgeType = record.get("edgeType");
const clusterId = record.get("clusterId");
const createdAt = record.get("createdAt");
// Create unique edge identifier to avoid duplicates
// For Episode->Subject edges, use generic type; for Subject->Object use predicate
const edgeType = predicateLabel || "HAS_SUBJECT";
const edgeKey = `${sourceUuid}-${targetUuid}-${edgeType}`;
if (processedEdges.has(edgeKey)) return;
processedEdges.add(edgeKey);
// Build node attributes based on type
const sourceAttributes =
sourceNodeType === "Episode"
? {
triplets.push({
sourceNode: {
uuid: sourceUuid,
labels: ["Episode"],
attributes: {
nodeType: "Episode",
content: sourceContent,
episodeUuid: sourceUuid,
clusterId,
}
: {
nodeType: "Entity",
name: sourceName,
clusterId,
};
const targetAttributes =
targetNodeType === "Episode"
? {
nodeType: "Episode",
content: sourceContent,
episodeUuid: targetUuid,
clusterId,
}
: {
nodeType: "Entity",
name: targetName,
clusterId,
};
// Build display name
const sourceDisplayName =
sourceNodeType === "Episode"
? sourceContent || episodeUuid
: sourceName || sourceUuid;
const targetDisplayName =
targetNodeType === "Episode"
? sourceContent || episodeUuid
: targetName || targetUuid;
triplets.push({
sourceNode: {
uuid: sourceUuid,
labels: [sourceNodeType],
attributes: sourceAttributes,
name: sourceDisplayName,
},
name: sourceContent || sourceUuid,
clusterId,
createdAt: createdAt || "",
},
@ -243,10 +180,14 @@ export const getClusteredGraphData = async (userId: string) => {
},
targetNode: {
uuid: targetUuid,
labels: [targetNodeType],
attributes: targetAttributes,
labels: ["Entity"],
attributes: {
nodeType: "Entity",
name: targetName,
clusterId,
},
name: targetName || targetUuid,
clusterId,
name: targetDisplayName,
createdAt: createdAt || "",
},
});

View File

@ -1,6 +1,19 @@
import { tool } from "ai";
import { StopCondition, tool } from "ai";
import z from "zod";
/**
 * Stop condition for the AI SDK step loop: halts generation once any step's
 * text contains the closing `</final_response>` tag that the system prompt's
 * "final answer" format emits.
 */
export const hasAnswer: StopCondition<any> = ({ steps }) => {
  // `some` always returns a boolean, so `?? false` on its result was dead
  // code; the coalescing belongs inside the predicate, where
  // `step.text?.includes(...)` can be `undefined` when `text` is missing.
  return steps.some(
    (step) => step.text?.includes("</final_response>") ?? false,
  );
};
/**
 * Stop condition for the AI SDK step loop: halts generation once any step's
 * text contains the closing `</question_response>` tag that the system
 * prompt's "question" format emits.
 */
export const hasQuestion: StopCondition<any> = ({ steps }) => {
  // `some` always returns a boolean, so `?? false` on its result was dead
  // code; the coalescing belongs inside the predicate, where
  // `step.text?.includes(...)` can be `undefined` when `text` is missing.
  return steps.some(
    (step) => step.text?.includes("</question_response>") ?? false,
  );
};
export const REACT_SYSTEM_PROMPT = `
You are a helpful AI assistant with access to user memory. Your primary capabilities are:
@ -128,18 +141,18 @@ PROGRESS UPDATES - During processing:
- Avoid technical jargon
QUESTIONS - When you need information:
<div>
<question_response>
<p>[Your question with HTML formatting]</p>
</div>
</question_response>
- Ask questions only when you cannot find information through memory, or tools
- Be specific about what you need to know
- Provide context for why you're asking
FINAL ANSWERS - When completing tasks:
<div>
<final_response>
<p>[Your answer with HTML formatting]</p>
</div>
</final_response>
CRITICAL:
- Use ONE format per turn

View File

@ -5,6 +5,8 @@ import {
type LanguageModel,
experimental_createMCPClient as createMCPClient,
generateId,
stepCountIs,
StopCondition,
} from "ai";
import { z } from "zod";
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
@ -19,7 +21,7 @@ import { getModel } from "~/lib/model.server";
import { UserTypeEnum } from "@core/types";
import { nanoid } from "nanoid";
import { getOrCreatePersonalAccessToken } from "~/services/personalAccessToken.server";
import { REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
import { hasAnswer, hasQuestion, REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
import { enqueueCreateConversationTitle } from "~/lib/queue-adapter.server";
import { env } from "~/env.server";
@ -84,7 +86,7 @@ const { loader, action } = createHybridActionApiRoute(
await createConversationHistory(message, body.id, UserTypeEnum.User);
}
const messages = conversationHistory.map((history) => {
const messages = conversationHistory.map((history: any) => {
return {
parts: [{ text: history.message, type: "text" }],
role: "user",
@ -94,8 +96,6 @@ const { loader, action } = createHybridActionApiRoute(
const tools = { ...(await mcpClient.tools()) };
// console.log(tools);
const finalMessages = [
...messages,
{
@ -109,6 +109,8 @@ const { loader, action } = createHybridActionApiRoute(
messages: finalMessages,
});
const result = streamText({
model: getModel() as LanguageModel,
messages: [
@ -119,6 +121,7 @@ const { loader, action } = createHybridActionApiRoute(
...convertToModelMessages(validatedMessages),
],
tools,
stopWhen: [stepCountIs(10), hasAnswer,hasQuestion],
});
result.consumeStream(); // no await

View File

@ -11,15 +11,17 @@ import {
import {
convertToModelMessages,
type CoreMessage,
generateId,
generateText,
type LanguageModel,
stepCountIs,
streamText,
tool,
validateUIMessages,
} from "ai";
import axios from "axios";
import { logger } from "~/services/logger.service";
import { getReActPrompt } from "~/lib/prompt.server";
import { getReActPrompt, hasAnswer } from "~/lib/prompt.server";
import { getModel } from "~/lib/model.server";
const DeepSearchBodySchema = z.object({
@ -39,7 +41,7 @@ export function createSearchMemoryTool(token: string) {
return tool({
description:
"Search the user's memory for relevant facts and episodes. Use this tool multiple times with different queries to gather comprehensive context.",
parameters: z.object({
inputSchema: z.object({
query: z
.string()
.describe(
@ -50,21 +52,14 @@ export function createSearchMemoryTool(token: string) {
try {
const response = await axios.post(
`${process.env.API_BASE_URL || "https://core.heysol.ai"}/api/v1/search`,
{ query },
{ query, structured: false },
{
headers: {
Authorization: `Bearer ${token}`,
},
},
);
const searchResult = response.data;
return {
facts: searchResult.facts || [],
episodes: searchResult.episodes || [],
summary: `Found ${searchResult.episodes?.length || 0} relevant memories`,
};
return response.data;
} catch (error) {
logger.error(`SearchMemory tool error: ${error}`);
return {
@ -115,14 +110,11 @@ const { action, loader } = createActionApiRoute(
searchMemory: searchTool,
};
// Build initial messages with ReAct prompt
const initialMessages: CoreMessage[] = [
{
role: "system",
content: getReActPrompt(body.metadata, body.intentOverride),
},
const initialMessages = [
{
role: "user",
content: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.`,
parts: [{ type: "text", text: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.` }],
id: generateId(),
},
];
@ -134,7 +126,15 @@ const { action, loader } = createActionApiRoute(
if (body.stream) {
const result = streamText({
model: getModel() as LanguageModel,
messages: convertToModelMessages(validatedMessages),
messages: [
{
role: "system",
content: getReActPrompt(body.metadata, body.intentOverride),
},
...convertToModelMessages(validatedMessages),
],
tools,
stopWhen: [stepCountIs(10), hasAnswer],
});
return result.toUIMessageStreamResponse({
@ -143,7 +143,15 @@ const { action, loader } = createActionApiRoute(
} else {
const { text } = await generateText({
model: getModel() as LanguageModel,
messages: convertToModelMessages(validatedMessages),
messages: [
{
role: "system",
content: getReActPrompt(body.metadata, body.intentOverride),
},
...convertToModelMessages(validatedMessages),
],
tools,
stopWhen: [stepCountIs(10), hasAnswer],
});
await deletePersonalAccessToken(pat?.id);