From 6a05ea4f378434d38de0d7bb62d48c83712a818b Mon Sep 17 00:00:00 2001 From: Manoj Date: Sun, 26 Oct 2025 12:20:33 +0530 Subject: [PATCH] refactor: simplify clustered graph query and add stop conditions for AI responses --- apps/webapp/app/lib/neo4j.server.ts | 131 +++++------------- apps/webapp/app/lib/prompt.server.ts | 23 ++- .../app/routes/api.v1.conversation._index.tsx | 11 +- apps/webapp/app/routes/api.v1.deep-search.tsx | 46 +++--- 4 files changed, 88 insertions(+), 123 deletions(-) diff --git a/apps/webapp/app/lib/neo4j.server.ts b/apps/webapp/app/lib/neo4j.server.ts index e924a95..94fdf31 100644 --- a/apps/webapp/app/lib/neo4j.server.ts +++ b/apps/webapp/app/lib/neo4j.server.ts @@ -112,51 +112,31 @@ export const getNodeLinks = async (userId: string) => { export const getClusteredGraphData = async (userId: string) => { const session = driver.session(); try { - // Get the simplified graph structure: Episode, Subject, Object with Predicate as edge - // Only include entities that are connected to more than 1 episode + // Get Episode -> Entity graph, only showing entities connected to more than 1 episode const result = await session.run( `// Find entities connected to more than 1 episode - MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId}) - MATCH (s)-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(ent:Entity) - WITH ent, count(DISTINCT e) as episodeCount + MATCH (e:Episode{userId: $userId})-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity) + WITH entity, count(DISTINCT e) as episodeCount WHERE episodeCount > 1 - WITH collect(ent.uuid) as validEntityUuids + WITH collect(entity.uuid) as validEntityUuids - // Get statements where all entities are in the valid set - MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId}) - MATCH (s)-[:HAS_SUBJECT]->(subj:Entity) - WHERE subj.uuid IN validEntityUuids - MATCH (s)-[:HAS_PREDICATE]->(pred:Entity) - WHERE pred.uuid IN validEntityUuids - 
MATCH (s)-[:HAS_OBJECT]->(obj:Entity) - WHERE obj.uuid IN validEntityUuids - - // Build relationships - WITH e, s, subj, pred, obj - UNWIND [ - // Episode -> Subject - {source: e, sourceType: 'Episode', target: subj, targetType: 'Entity', predicate: null}, - // Episode -> Object - {source: e, sourceType: 'Episode', target: obj, targetType: 'Entity', predicate: null}, - // Subject -> Object (with Predicate as edge) - {source: subj, sourceType: 'Entity', target: obj, targetType: 'Entity', predicate: pred.name} - ] AS rel + // Build Episode -> Entity relationships for valid entities; only the Statement -> Entity relationship is bound to r so type(r) yields the edge type + MATCH (e:Episode{userId: $userId})-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity) + WHERE entity.uuid IN validEntityUuids + WITH DISTINCT e, entity, type(r) as relType, + CASE WHEN size(e.spaceIds) > 0 THEN e.spaceIds[0] ELSE null END as clusterId, + s.createdAt as createdAt RETURN DISTINCT - rel.source.uuid as sourceUuid, - rel.source.name as sourceName, - rel.source.content as sourceContent, - rel.sourceType as sourceNodeType, - rel.target.uuid as targetUuid, - rel.target.name as targetName, - rel.targetType as targetNodeType, - rel.predicate as predicateLabel, - e.uuid as episodeUuid, - e.content as episodeContent, - e.spaceIds as spaceIds, - s.uuid as statementUuid, - s.validAt as validAt, - s.createdAt as createdAt`, + e.uuid as sourceUuid, + e.content as sourceContent, + 'Episode' as sourceNodeType, + entity.uuid as targetUuid, + entity.name as targetName, + 'Entity' as targetNodeType, + relType as edgeType, + clusterId, + createdAt`, { userId }, ); @@ -165,72 +145,29 @@ export const getClusteredGraphData = async (userId: string) => { result.records.forEach((record) => { const sourceUuid = record.get("sourceUuid"); - const sourceName = record.get("sourceName"); const sourceContent = record.get("sourceContent"); - const sourceNodeType = record.get("sourceNodeType"); - const targetUuid = record.get("targetUuid"); const 
targetName = record.get("targetName"); - const targetNodeType = record.get("targetNodeType"); - - const predicateLabel = record.get("predicateLabel"); - const episodeUuid = record.get("episodeUuid"); - const clusterIds = record.get("spaceIds"); - const clusterId = clusterIds ? clusterIds[0] : undefined; + const edgeType = record.get("edgeType"); + const clusterId = record.get("clusterId"); const createdAt = record.get("createdAt"); // Create unique edge identifier to avoid duplicates - // For Episode->Subject edges, use generic type; for Subject->Object use predicate - const edgeType = predicateLabel || "HAS_SUBJECT"; const edgeKey = `${sourceUuid}-${targetUuid}-${edgeType}`; if (processedEdges.has(edgeKey)) return; processedEdges.add(edgeKey); - // Build node attributes based on type - const sourceAttributes = - sourceNodeType === "Episode" - ? { - nodeType: "Episode", - content: sourceContent, - episodeUuid: sourceUuid, - clusterId, - } - : { - nodeType: "Entity", - name: sourceName, - clusterId, - }; - - const targetAttributes = - targetNodeType === "Episode" - ? { - nodeType: "Episode", - content: sourceContent, - episodeUuid: targetUuid, - clusterId, - } - : { - nodeType: "Entity", - name: targetName, - clusterId, - }; - - // Build display name - const sourceDisplayName = - sourceNodeType === "Episode" - ? sourceContent || episodeUuid - : sourceName || sourceUuid; - const targetDisplayName = - targetNodeType === "Episode" - ? 
sourceContent || episodeUuid - : targetName || targetUuid; - triplets.push({ sourceNode: { uuid: sourceUuid, - labels: [sourceNodeType], - attributes: sourceAttributes, - name: sourceDisplayName, + labels: ["Episode"], + attributes: { + nodeType: "Episode", + content: sourceContent, + episodeUuid: sourceUuid, + clusterId, + }, + name: sourceContent || sourceUuid, clusterId, createdAt: createdAt || "", }, @@ -243,10 +180,14 @@ export const getClusteredGraphData = async (userId: string) => { }, targetNode: { uuid: targetUuid, - labels: [targetNodeType], - attributes: targetAttributes, + labels: ["Entity"], + attributes: { + nodeType: "Entity", + name: targetName, + clusterId, + }, + name: targetName || targetUuid, clusterId, - name: targetDisplayName, createdAt: createdAt || "", }, }); diff --git a/apps/webapp/app/lib/prompt.server.ts b/apps/webapp/app/lib/prompt.server.ts index 09e1dc1..dbc1d12 100644 --- a/apps/webapp/app/lib/prompt.server.ts +++ b/apps/webapp/app/lib/prompt.server.ts @@ -1,6 +1,19 @@ -import { tool } from "ai"; +import { StopCondition, tool } from "ai"; import z from "zod"; +export const hasAnswer: StopCondition = ({ steps }) => { + return ( + steps.some((step) => step.text?.includes("")) ?? false + ); +}; + +export const hasQuestion: StopCondition = ({ steps }) => { + return ( + steps.some((step) => step.text?.includes("")) ?? + false + ); +}; + export const REACT_SYSTEM_PROMPT = ` You are a helpful AI assistant with access to user memory. Your primary capabilities are: @@ -128,18 +141,18 @@ PROGRESS UPDATES - During processing: - Avoid technical jargon QUESTIONS - When you need information: -
+

[Your question with HTML formatting]

-
+ - Ask questions only when you cannot find information through memory, or tools - Be specific about what you need to know - Provide context for why you're asking FINAL ANSWERS - When completing tasks: -
+

[Your answer with HTML formatting]

-
+ CRITICAL: - Use ONE format per turn diff --git a/apps/webapp/app/routes/api.v1.conversation._index.tsx b/apps/webapp/app/routes/api.v1.conversation._index.tsx index 87c825e..5a316fc 100644 --- a/apps/webapp/app/routes/api.v1.conversation._index.tsx +++ b/apps/webapp/app/routes/api.v1.conversation._index.tsx @@ -5,6 +5,8 @@ import { type LanguageModel, experimental_createMCPClient as createMCPClient, generateId, + stepCountIs, + StopCondition, } from "ai"; import { z } from "zod"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; @@ -19,7 +21,7 @@ import { getModel } from "~/lib/model.server"; import { UserTypeEnum } from "@core/types"; import { nanoid } from "nanoid"; import { getOrCreatePersonalAccessToken } from "~/services/personalAccessToken.server"; -import { REACT_SYSTEM_PROMPT } from "~/lib/prompt.server"; +import { hasAnswer, hasQuestion, REACT_SYSTEM_PROMPT } from "~/lib/prompt.server"; import { enqueueCreateConversationTitle } from "~/lib/queue-adapter.server"; import { env } from "~/env.server"; @@ -84,7 +86,7 @@ const { loader, action } = createHybridActionApiRoute( await createConversationHistory(message, body.id, UserTypeEnum.User); } - const messages = conversationHistory.map((history) => { + const messages = conversationHistory.map((history: any) => { return { parts: [{ text: history.message, type: "text" }], role: "user", @@ -94,8 +96,6 @@ const { loader, action } = createHybridActionApiRoute( const tools = { ...(await mcpClient.tools()) }; - // console.log(tools); - const finalMessages = [ ...messages, { @@ -109,6 +109,8 @@ const { loader, action } = createHybridActionApiRoute( messages: finalMessages, }); + + const result = streamText({ model: getModel() as LanguageModel, messages: [ @@ -119,6 +121,7 @@ const { loader, action } = createHybridActionApiRoute( ...convertToModelMessages(validatedMessages), ], tools, + stopWhen: [stepCountIs(10), hasAnswer,hasQuestion], }); result.consumeStream(); // 
no await diff --git a/apps/webapp/app/routes/api.v1.deep-search.tsx b/apps/webapp/app/routes/api.v1.deep-search.tsx index 5f04659..a7bfe1e 100644 --- a/apps/webapp/app/routes/api.v1.deep-search.tsx +++ b/apps/webapp/app/routes/api.v1.deep-search.tsx @@ -11,15 +11,17 @@ import { import { convertToModelMessages, type CoreMessage, + generateId, generateText, type LanguageModel, + stepCountIs, streamText, tool, validateUIMessages, } from "ai"; import axios from "axios"; import { logger } from "~/services/logger.service"; -import { getReActPrompt } from "~/lib/prompt.server"; +import { getReActPrompt, hasAnswer } from "~/lib/prompt.server"; import { getModel } from "~/lib/model.server"; const DeepSearchBodySchema = z.object({ @@ -39,7 +41,7 @@ export function createSearchMemoryTool(token: string) { return tool({ description: "Search the user's memory for relevant facts and episodes. Use this tool multiple times with different queries to gather comprehensive context.", - parameters: z.object({ + inputSchema: z.object({ query: z .string() .describe( @@ -50,21 +52,14 @@ export function createSearchMemoryTool(token: string) { try { const response = await axios.post( `${process.env.API_BASE_URL || "https://core.heysol.ai"}/api/v1/search`, - { query }, + { query, structured: false }, { headers: { Authorization: `Bearer ${token}`, }, }, ); - - const searchResult = response.data; - - return { - facts: searchResult.facts || [], - episodes: searchResult.episodes || [], - summary: `Found ${searchResult.episodes?.length || 0} relevant memories`, - }; + return response.data; } catch (error) { logger.error(`SearchMemory tool error: ${error}`); return { @@ -115,14 +110,11 @@ const { action, loader } = createActionApiRoute( searchMemory: searchTool, }; // Build initial messages with ReAct prompt - const initialMessages: CoreMessage[] = [ - { - role: "system", - content: getReActPrompt(body.metadata, body.intentOverride), - }, + const initialMessages = [ { role: "user", - content: 
`CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.`, + parts: [{ type: "text", text: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.` }], + id: generateId(), }, ]; @@ -134,7 +126,15 @@ const { action, loader } = createActionApiRoute( if (body.stream) { const result = streamText({ model: getModel() as LanguageModel, - messages: convertToModelMessages(validatedMessages), + messages: [ + { + role: "system", + content: getReActPrompt(body.metadata, body.intentOverride), + }, + ...convertToModelMessages(validatedMessages), + ], + tools, + stopWhen: [stepCountIs(10), hasAnswer], }); return result.toUIMessageStreamResponse({ @@ -143,7 +143,15 @@ const { action, loader } = createActionApiRoute( } else { const { text } = await generateText({ model: getModel() as LanguageModel, - messages: convertToModelMessages(validatedMessages), + messages: [ + { + role: "system", + content: getReActPrompt(body.metadata, body.intentOverride), + }, + ...convertToModelMessages(validatedMessages), + ], + tools, + stopWhen: [stepCountIs(10), hasAnswer], }); await deletePersonalAccessToken(pat?.id);