mirror of
https://github.com/eliasstepanik/core.git
synced 2026-01-11 10:08:27 +00:00
refactor: simplify clustered graph query and add stop conditions for AI responses
This commit is contained in:
parent
8836849310
commit
6a05ea4f37
@ -112,51 +112,31 @@ export const getNodeLinks = async (userId: string) => {
|
||||
export const getClusteredGraphData = async (userId: string) => {
|
||||
const session = driver.session();
|
||||
try {
|
||||
// Get the simplified graph structure: Episode, Subject, Object with Predicate as edge
|
||||
// Only include entities that are connected to more than 1 episode
|
||||
// Get Episode -> Entity graph, only showing entities connected to more than 1 episode
|
||||
const result = await session.run(
|
||||
`// Find entities connected to more than 1 episode
|
||||
MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
|
||||
MATCH (s)-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(ent:Entity)
|
||||
WITH ent, count(DISTINCT e) as episodeCount
|
||||
MATCH (e:Episode{userId: $userId})-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
|
||||
WITH entity, count(DISTINCT e) as episodeCount
|
||||
WHERE episodeCount > 1
|
||||
WITH collect(ent.uuid) as validEntityUuids
|
||||
WITH collect(entity.uuid) as validEntityUuids
|
||||
|
||||
// Get statements where all entities are in the valid set
|
||||
MATCH (e:Episode)-[:HAS_PROVENANCE]->(s:Statement {userId: $userId})
|
||||
MATCH (s)-[:HAS_SUBJECT]->(subj:Entity)
|
||||
WHERE subj.uuid IN validEntityUuids
|
||||
MATCH (s)-[:HAS_PREDICATE]->(pred:Entity)
|
||||
WHERE pred.uuid IN validEntityUuids
|
||||
MATCH (s)-[:HAS_OBJECT]->(obj:Entity)
|
||||
WHERE obj.uuid IN validEntityUuids
|
||||
|
||||
// Build relationships
|
||||
WITH e, s, subj, pred, obj
|
||||
UNWIND [
|
||||
// Episode -> Subject
|
||||
{source: e, sourceType: 'Episode', target: subj, targetType: 'Entity', predicate: null},
|
||||
// Episode -> Object
|
||||
{source: e, sourceType: 'Episode', target: obj, targetType: 'Entity', predicate: null},
|
||||
// Subject -> Object (with Predicate as edge)
|
||||
{source: subj, sourceType: 'Entity', target: obj, targetType: 'Entity', predicate: pred.name}
|
||||
] AS rel
|
||||
// Build Episode -> Entity relationships for valid entities
|
||||
MATCH (e:Episode{userId: $userId})-[r:HAS_PROVENANCE]->(s:Statement {userId: $userId})-[r:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(entity:Entity)
|
||||
WHERE entity.uuid IN validEntityUuids
|
||||
WITH DISTINCT e, entity, type(r) as relType,
|
||||
CASE WHEN size(e.spaceIds) > 0 THEN e.spaceIds[0] ELSE null END as clusterId,
|
||||
s.createdAt as createdAt
|
||||
|
||||
RETURN DISTINCT
|
||||
rel.source.uuid as sourceUuid,
|
||||
rel.source.name as sourceName,
|
||||
rel.source.content as sourceContent,
|
||||
rel.sourceType as sourceNodeType,
|
||||
rel.target.uuid as targetUuid,
|
||||
rel.target.name as targetName,
|
||||
rel.targetType as targetNodeType,
|
||||
rel.predicate as predicateLabel,
|
||||
e.uuid as episodeUuid,
|
||||
e.content as episodeContent,
|
||||
e.spaceIds as spaceIds,
|
||||
s.uuid as statementUuid,
|
||||
s.validAt as validAt,
|
||||
s.createdAt as createdAt`,
|
||||
e.uuid as sourceUuid,
|
||||
e.content as sourceContent,
|
||||
'Episode' as sourceNodeType,
|
||||
entity.uuid as targetUuid,
|
||||
entity.name as targetName,
|
||||
'Entity' as targetNodeType,
|
||||
relType as edgeType,
|
||||
clusterId,
|
||||
createdAt`,
|
||||
{ userId },
|
||||
);
|
||||
|
||||
@ -165,72 +145,29 @@ export const getClusteredGraphData = async (userId: string) => {
|
||||
|
||||
result.records.forEach((record) => {
|
||||
const sourceUuid = record.get("sourceUuid");
|
||||
const sourceName = record.get("sourceName");
|
||||
const sourceContent = record.get("sourceContent");
|
||||
const sourceNodeType = record.get("sourceNodeType");
|
||||
|
||||
const targetUuid = record.get("targetUuid");
|
||||
const targetName = record.get("targetName");
|
||||
const targetNodeType = record.get("targetNodeType");
|
||||
|
||||
const predicateLabel = record.get("predicateLabel");
|
||||
const episodeUuid = record.get("episodeUuid");
|
||||
const clusterIds = record.get("spaceIds");
|
||||
const clusterId = clusterIds ? clusterIds[0] : undefined;
|
||||
const edgeType = record.get("edgeType");
|
||||
const clusterId = record.get("clusterId");
|
||||
const createdAt = record.get("createdAt");
|
||||
|
||||
// Create unique edge identifier to avoid duplicates
|
||||
// For Episode->Subject edges, use generic type; for Subject->Object use predicate
|
||||
const edgeType = predicateLabel || "HAS_SUBJECT";
|
||||
const edgeKey = `${sourceUuid}-${targetUuid}-${edgeType}`;
|
||||
if (processedEdges.has(edgeKey)) return;
|
||||
processedEdges.add(edgeKey);
|
||||
|
||||
// Build node attributes based on type
|
||||
const sourceAttributes =
|
||||
sourceNodeType === "Episode"
|
||||
? {
|
||||
triplets.push({
|
||||
sourceNode: {
|
||||
uuid: sourceUuid,
|
||||
labels: ["Episode"],
|
||||
attributes: {
|
||||
nodeType: "Episode",
|
||||
content: sourceContent,
|
||||
episodeUuid: sourceUuid,
|
||||
clusterId,
|
||||
}
|
||||
: {
|
||||
nodeType: "Entity",
|
||||
name: sourceName,
|
||||
clusterId,
|
||||
};
|
||||
|
||||
const targetAttributes =
|
||||
targetNodeType === "Episode"
|
||||
? {
|
||||
nodeType: "Episode",
|
||||
content: sourceContent,
|
||||
episodeUuid: targetUuid,
|
||||
clusterId,
|
||||
}
|
||||
: {
|
||||
nodeType: "Entity",
|
||||
name: targetName,
|
||||
clusterId,
|
||||
};
|
||||
|
||||
// Build display name
|
||||
const sourceDisplayName =
|
||||
sourceNodeType === "Episode"
|
||||
? sourceContent || episodeUuid
|
||||
: sourceName || sourceUuid;
|
||||
const targetDisplayName =
|
||||
targetNodeType === "Episode"
|
||||
? sourceContent || episodeUuid
|
||||
: targetName || targetUuid;
|
||||
|
||||
triplets.push({
|
||||
sourceNode: {
|
||||
uuid: sourceUuid,
|
||||
labels: [sourceNodeType],
|
||||
attributes: sourceAttributes,
|
||||
name: sourceDisplayName,
|
||||
},
|
||||
name: sourceContent || sourceUuid,
|
||||
clusterId,
|
||||
createdAt: createdAt || "",
|
||||
},
|
||||
@ -243,10 +180,14 @@ export const getClusteredGraphData = async (userId: string) => {
|
||||
},
|
||||
targetNode: {
|
||||
uuid: targetUuid,
|
||||
labels: [targetNodeType],
|
||||
attributes: targetAttributes,
|
||||
labels: ["Entity"],
|
||||
attributes: {
|
||||
nodeType: "Entity",
|
||||
name: targetName,
|
||||
clusterId,
|
||||
},
|
||||
name: targetName || targetUuid,
|
||||
clusterId,
|
||||
name: targetDisplayName,
|
||||
createdAt: createdAt || "",
|
||||
},
|
||||
});
|
||||
|
||||
@ -1,6 +1,19 @@
|
||||
import { tool } from "ai";
|
||||
import { StopCondition, tool } from "ai";
|
||||
import z from "zod";
|
||||
|
||||
export const hasAnswer: StopCondition<any> = ({ steps }) => {
|
||||
return (
|
||||
steps.some((step) => step.text?.includes("</final_response>")) ?? false
|
||||
);
|
||||
};
|
||||
|
||||
export const hasQuestion: StopCondition<any> = ({ steps }) => {
|
||||
return (
|
||||
steps.some((step) => step.text?.includes("</question_response>")) ??
|
||||
false
|
||||
);
|
||||
};
|
||||
|
||||
export const REACT_SYSTEM_PROMPT = `
|
||||
You are a helpful AI assistant with access to user memory. Your primary capabilities are:
|
||||
|
||||
@ -128,18 +141,18 @@ PROGRESS UPDATES - During processing:
|
||||
- Avoid technical jargon
|
||||
|
||||
QUESTIONS - When you need information:
|
||||
<div>
|
||||
<question_response>
|
||||
<p>[Your question with HTML formatting]</p>
|
||||
</div>
|
||||
</question_response>
|
||||
|
||||
- Ask questions only when you cannot find information through memory, or tools
|
||||
- Be specific about what you need to know
|
||||
- Provide context for why you're asking
|
||||
|
||||
FINAL ANSWERS - When completing tasks:
|
||||
<div>
|
||||
<final_response>
|
||||
<p>[Your answer with HTML formatting]</p>
|
||||
</div>
|
||||
</final_response>
|
||||
|
||||
CRITICAL:
|
||||
- Use ONE format per turn
|
||||
|
||||
@ -5,6 +5,8 @@ import {
|
||||
type LanguageModel,
|
||||
experimental_createMCPClient as createMCPClient,
|
||||
generateId,
|
||||
stepCountIs,
|
||||
StopCondition,
|
||||
} from "ai";
|
||||
import { z } from "zod";
|
||||
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
||||
@ -19,7 +21,7 @@ import { getModel } from "~/lib/model.server";
|
||||
import { UserTypeEnum } from "@core/types";
|
||||
import { nanoid } from "nanoid";
|
||||
import { getOrCreatePersonalAccessToken } from "~/services/personalAccessToken.server";
|
||||
import { REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
|
||||
import { hasAnswer, hasQuestion, REACT_SYSTEM_PROMPT } from "~/lib/prompt.server";
|
||||
import { enqueueCreateConversationTitle } from "~/lib/queue-adapter.server";
|
||||
import { env } from "~/env.server";
|
||||
|
||||
@ -84,7 +86,7 @@ const { loader, action } = createHybridActionApiRoute(
|
||||
await createConversationHistory(message, body.id, UserTypeEnum.User);
|
||||
}
|
||||
|
||||
const messages = conversationHistory.map((history) => {
|
||||
const messages = conversationHistory.map((history: any) => {
|
||||
return {
|
||||
parts: [{ text: history.message, type: "text" }],
|
||||
role: "user",
|
||||
@ -94,8 +96,6 @@ const { loader, action } = createHybridActionApiRoute(
|
||||
|
||||
const tools = { ...(await mcpClient.tools()) };
|
||||
|
||||
// console.log(tools);
|
||||
|
||||
const finalMessages = [
|
||||
...messages,
|
||||
{
|
||||
@ -109,6 +109,8 @@ const { loader, action } = createHybridActionApiRoute(
|
||||
messages: finalMessages,
|
||||
});
|
||||
|
||||
|
||||
|
||||
const result = streamText({
|
||||
model: getModel() as LanguageModel,
|
||||
messages: [
|
||||
@ -119,6 +121,7 @@ const { loader, action } = createHybridActionApiRoute(
|
||||
...convertToModelMessages(validatedMessages),
|
||||
],
|
||||
tools,
|
||||
stopWhen: [stepCountIs(10), hasAnswer,hasQuestion],
|
||||
});
|
||||
|
||||
result.consumeStream(); // no await
|
||||
|
||||
@ -11,15 +11,17 @@ import {
|
||||
import {
|
||||
convertToModelMessages,
|
||||
type CoreMessage,
|
||||
generateId,
|
||||
generateText,
|
||||
type LanguageModel,
|
||||
stepCountIs,
|
||||
streamText,
|
||||
tool,
|
||||
validateUIMessages,
|
||||
} from "ai";
|
||||
import axios from "axios";
|
||||
import { logger } from "~/services/logger.service";
|
||||
import { getReActPrompt } from "~/lib/prompt.server";
|
||||
import { getReActPrompt, hasAnswer } from "~/lib/prompt.server";
|
||||
import { getModel } from "~/lib/model.server";
|
||||
|
||||
const DeepSearchBodySchema = z.object({
|
||||
@ -39,7 +41,7 @@ export function createSearchMemoryTool(token: string) {
|
||||
return tool({
|
||||
description:
|
||||
"Search the user's memory for relevant facts and episodes. Use this tool multiple times with different queries to gather comprehensive context.",
|
||||
parameters: z.object({
|
||||
inputSchema: z.object({
|
||||
query: z
|
||||
.string()
|
||||
.describe(
|
||||
@ -50,21 +52,14 @@ export function createSearchMemoryTool(token: string) {
|
||||
try {
|
||||
const response = await axios.post(
|
||||
`${process.env.API_BASE_URL || "https://core.heysol.ai"}/api/v1/search`,
|
||||
{ query },
|
||||
{ query, structured: false },
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const searchResult = response.data;
|
||||
|
||||
return {
|
||||
facts: searchResult.facts || [],
|
||||
episodes: searchResult.episodes || [],
|
||||
summary: `Found ${searchResult.episodes?.length || 0} relevant memories`,
|
||||
};
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
logger.error(`SearchMemory tool error: ${error}`);
|
||||
return {
|
||||
@ -115,14 +110,11 @@ const { action, loader } = createActionApiRoute(
|
||||
searchMemory: searchTool,
|
||||
};
|
||||
// Build initial messages with ReAct prompt
|
||||
const initialMessages: CoreMessage[] = [
|
||||
{
|
||||
role: "system",
|
||||
content: getReActPrompt(body.metadata, body.intentOverride),
|
||||
},
|
||||
const initialMessages = [
|
||||
{
|
||||
role: "user",
|
||||
content: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.`,
|
||||
parts: [{ type: "text", text: `CONTENT TO ANALYZE:\n${body.content}\n\nPlease search my memory for relevant context and synthesize what you find.` }],
|
||||
id: generateId(),
|
||||
},
|
||||
];
|
||||
|
||||
@ -134,7 +126,15 @@ const { action, loader } = createActionApiRoute(
|
||||
if (body.stream) {
|
||||
const result = streamText({
|
||||
model: getModel() as LanguageModel,
|
||||
messages: convertToModelMessages(validatedMessages),
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: getReActPrompt(body.metadata, body.intentOverride),
|
||||
},
|
||||
...convertToModelMessages(validatedMessages),
|
||||
],
|
||||
tools,
|
||||
stopWhen: [stepCountIs(10), hasAnswer],
|
||||
});
|
||||
|
||||
return result.toUIMessageStreamResponse({
|
||||
@ -143,7 +143,15 @@ const { action, loader } = createActionApiRoute(
|
||||
} else {
|
||||
const { text } = await generateText({
|
||||
model: getModel() as LanguageModel,
|
||||
messages: convertToModelMessages(validatedMessages),
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: getReActPrompt(body.metadata, body.intentOverride),
|
||||
},
|
||||
...convertToModelMessages(validatedMessages),
|
||||
],
|
||||
tools,
|
||||
stopWhen: [stepCountIs(10), hasAnswer],
|
||||
});
|
||||
|
||||
await deletePersonalAccessToken(pat?.id);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user