import type { EpisodicNode, StatementNode } from "@core/types";
import { logger } from "./logger.service";
import { applyLLMReranking } from "./search/rerank";
import {
  getEpisodesByStatements,
  performBfsSearch,
  performBM25Search,
  performVectorSearch,
} from "./search/utils";
import { getEmbedding } from "~/lib/model.server";
import { prisma } from "~/db.server";
import { runQuery } from "~/lib/neo4j.server";

/** Episode (or session compact) in the shape returned to search callers. */
interface RecalledEpisode {
  content: string;
  createdAt: Date;
  spaceIds: string[];
  isCompact?: boolean;
}

/** Fact (statement) in the shape returned to search callers. */
interface RecalledFact {
  fact: string;
  validAt: Date;
  invalidAt: Date | null;
  relevantScore: number;
}

/** A statement paired with the score extracted from its reranking metadata. */
interface ScoredStatement {
  statement: StatementNode;
  score: number;
}

/**
 * Score properties a reranker may attach to a statement, in lookup priority
 * order. The first one holding a number wins.
 */
const RANKING_SCORE_FIELDS = [
  "rrfScore",
  "mmrScore",
  "crossEncoderScore",
  "finalScore",
  "multifactorScore",
  "combinedScore",
  "cohereScore",
] as const;

/** Score properties consulted when computing the average score for recall logs. */
const RECALL_LOG_SCORE_FIELDS = [
  "rrfScore",
  "mmrScore",
  "crossEncoderScore",
  "finalScore",
  "score",
] as const;

/**
 * Read a dynamically attached numeric property from a statement.
 * Score fields are not part of the StatementNode type as used here, so they
 * are read reflectively; anything that is not a number is reported as absent.
 */
function readNumericField(
  node: StatementNode,
  field: string,
): number | undefined {
  const value: unknown = Reflect.get(node, field);
  return typeof value === "number" ? value : undefined;
}

/**
 * SearchService provides methods to search the reified + temporal knowledge graph
 * using a hybrid approach combining BM25, vector similarity, and BFS traversal.
 */
export class SearchService {
  /** Thin wrapper around the shared embedding helper. */
  async getEmbedding(text: string) {
    return getEmbedding(text);
  }

  /**
   * Search the knowledge graph using a hybrid approach
   * @param query The search query
   * @param userId The user ID for personalization
   * @param options Search options
   * @param source Label stored on the recall log entry (defaults to "search_api")
   * @returns Markdown formatted context (default) or structured JSON (if structured: true)
   */
  public async search(
    query: string,
    userId: string,
    options: SearchOptions = {},
    source?: string,
  ): Promise<string | { episodes: RecalledEpisode[]; facts: RecalledFact[] }> {
    const startTime = Date.now();

    // Default options.
    // Numeric options deliberately keep `||`: a value of 0 is treated as
    // "not provided" and replaced by the default, which downstream code
    // relies on. Everything else uses `??` so explicit values are respected
    // (notably `includeInvalidated: false`, which `|| true` used to discard).
    const opts: Required<SearchOptions> = {
      limit: options.limit || 100,
      maxBfsDepth: options.maxBfsDepth || 4,
      validAt: options.validAt ?? new Date(),
      startTime: options.startTime ?? null,
      endTime: options.endTime ?? new Date(),
      includeInvalidated: options.includeInvalidated ?? true,
      entityTypes: options.entityTypes ?? [],
      predicateTypes: options.predicateTypes ?? [],
      scoreThreshold: options.scoreThreshold || 0.7,
      minResults: options.minResults || 10,
      spaceIds: options.spaceIds ?? [],
      adaptiveFiltering: options.adaptiveFiltering ?? false,
      structured: options.structured ?? false,
    };

    const queryVector = await this.getEmbedding(query);

    // 1. Run parallel search methods
    const [bm25Results, vectorResults, bfsResults] = await Promise.all([
      performBM25Search(query, userId, opts),
      performVectorSearch(queryVector, userId, opts),
      performBfsSearch(query, queryVector, userId, opts),
    ]);

    logger.info(
      `Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`,
    );

    // 2. Apply reranking strategy
    const rankedStatements = await this.rerankResults(
      query,
      userId,
      { bm25: bm25Results, vector: vectorResults, bfs: bfsResults },
      opts,
    );

    // 3. Apply adaptive filtering based on score threshold and minimum count
    const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts);
    const statements = filteredResults.map((item) => item.statement);

    // 4. Resolve the episodes behind the surviving statements
    const episodes = await getEpisodesByStatements(statements);

    // Log recall asynchronously (don't await to avoid blocking response)
    const responseTime = Date.now() - startTime;
    void this.logRecallAsync(
      query,
      userId,
      statements,
      opts,
      responseTime,
      source,
    ).catch((error: unknown) => {
      logger.error("Failed to log recall event:", error);
    });

    // Recall counters are best-effort bookkeeping: run them in the background,
    // but never let a failed graph write surface as an unhandled rejection.
    void this.updateRecallCount(userId, episodes, statements).catch(
      (error: unknown) => {
        logger.error("Failed to update recall counts:", { error });
      },
    );

    // Replace session episodes with compacts automatically
    const unifiedEpisodes = await this.replaceWithCompacts(episodes, userId);

    const factsData: RecalledFact[] = filteredResults.map((item) => ({
      fact: item.statement.fact,
      validAt: item.statement.validAt,
      invalidAt: item.statement.invalidAt ?? null,
      relevantScore: item.score,
    }));

    // Return markdown by default, structured JSON if requested
    if (opts.structured) {
      return {
        episodes: unifiedEpisodes,
        facts: factsData,
      };
    }

    // Return markdown formatted context
    return this.formatAsMarkdown(unifiedEpisodes, factsData);
  }

  /**
   * Apply adaptive filtering to ranked results
   * Uses a minimum quality threshold to filter out low-quality results
   *
   * When `adaptiveFiltering` is disabled, or there are 5 or fewer results,
   * every result is returned with its extracted score and no limit applied.
   */
  private applyAdaptiveFiltering(
    results: StatementNode[],
    options: Required<SearchOptions>,
  ): ScoredStatement[] {
    if (results.length === 0) return [];

    // RRF scores get the more lenient thresholding branch below.
    const isRRF = results.some(
      (result) => readNumericField(result, "rrfScore") !== undefined,
    );

    // Extract a score for each result based on the reranking strategy used
    const scoredResults: ScoredStatement[] = results.map((statement) => {
      for (const field of RANKING_SCORE_FIELDS) {
        const value = readNumericField(statement, field);
        if (value !== undefined) {
          return { statement, score: value };
        }
      }
      return { statement, score: 0 };
    });

    if (!options.adaptiveFiltering || results.length <= 5) {
      return scoredResults;
    }

    const hasScores = scoredResults.some((item) => item.score > 0);

    // If no scores are available, return the original results
    if (!hasScores) {
      logger.info("No scores found in results, skipping adaptive filtering");
      const unscored = options.limit > 0 ? results.slice(0, options.limit) : results;
      return unscored.map((item) => ({ statement: item, score: 0 }));
    }

    // Sort by score (descending)
    scoredResults.sort((a, b) => b.score - a.score);

    // Calculate statistics to identify low-quality results
    const scores = scoredResults.map((item) => item.score);
    const maxScore = Math.max(...scores);
    const minScore = Math.min(...scores);
    const scoreRange = maxScore - minScore;

    let threshold = 0;
    if (isRRF || scoreRange < 0.01) {
      // For RRF scores, use a more lenient adaptive approach:
      // a dynamic threshold based on the score distribution.
      // `scores` is already sorted descending, so the median can be read directly.
      const medianScore = scores[Math.floor(scores.length / 2)] ?? 0;

      // Use the smaller of: 20% of max score or 50% of median score
      // This is more lenient for broad queries while still filtering noise
      const maxBasedThreshold = maxScore * 0.2;
      const medianBasedThreshold = medianScore * 0.5;
      threshold = Math.min(maxBasedThreshold, medianBasedThreshold);

      // Ensure we keep at least minResults if available. The guard protects
      // against a non-positive or fractional minResults, which would otherwise
      // index outside the array.
      const minResultsCount = Math.min(
        options.minResults,
        scoredResults.length,
      );
      const minResultsFloor =
        minResultsCount > 0 ? scoredResults[minResultsCount - 1] : undefined;
      if (minResultsFloor) {
        threshold = Math.min(threshold, minResultsFloor.score);
      }
    } else {
      // For normal score distributions, use the relative threshold approach
      const relativeThreshold = options.scoreThreshold || 0.3;
      const absoluteMinimum = 0.1;
      threshold = Math.max(
        absoluteMinimum,
        minScore + scoreRange * relativeThreshold,
      );
    }

    // Filter out low-quality results
    const filteredResults = scoredResults.filter(
      (item) => item.score >= threshold,
    );

    // Apply limit if specified
    const limitedResults =
      options.limit > 0
        ? filteredResults.slice(0, options.limit)
        : filteredResults;

    logger.info(
      `Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`,
    );
    logger.info(
      `Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`,
    );

    return limitedResults;
  }

  /**
   * Apply the selected reranking strategy to search results
   */
  private async rerankResults(
    query: string,
    userId: string,
    results: {
      bm25: StatementNode[];
      vector: StatementNode[];
      bfs: StatementNode[];
    },
    options: Required<SearchOptions>,
  ): Promise<StatementNode[]> {
    // Fetch user profile for context
    const user = await prisma.user.findUnique({
      where: { id: userId },
      select: { name: true, id: true },
    });

    const userContext = user
      ? { name: user.name ?? undefined, userId: user.id }
      : undefined;

    return applyLLMReranking(query, results, options.limit, userContext);
  }

  /**
   * Persist a recall log entry describing this search.
   * Failures are logged and swallowed so they never affect the search response.
   */
  private async logRecallAsync(
    query: string,
    userId: string,
    results: StatementNode[],
    options: Required<SearchOptions>,
    responseTime: number,
    source?: string,
  ): Promise<void> {
    try {
      // Determine target type based on results
      let targetType = "mixed_results";
      if (results.length === 1) {
        targetType = "statement";
      } else if (results.length === 0) {
        targetType = "no_results";
      }

      // Calculate average similarity score if available.
      // For each result, take the first non-zero numeric score field.
      const scoresWithValues = results
        .map((result) => {
          for (const field of RECALL_LOG_SCORE_FIELDS) {
            const value = readNumericField(result, field);
            if (value) return value;
          }
          return null;
        })
        .filter((score): score is number => score !== null);

      const averageSimilarityScore =
        scoresWithValues.length > 0
          ? scoresWithValues.reduce((sum, score) => sum + score, 0) /
            scoresWithValues.length
          : null;

      await prisma.recallLog.create({
        data: {
          accessType: "search",
          query,
          targetType,
          searchMethod: "hybrid", // BM25 + Vector + BFS
          minSimilarity: options.scoreThreshold,
          maxResults: options.limit,
          resultCount: results.length,
          similarityScore: averageSimilarityScore,
          context: JSON.stringify({
            entityTypes: options.entityTypes,
            predicateTypes: options.predicateTypes,
            maxBfsDepth: options.maxBfsDepth,
            includeInvalidated: options.includeInvalidated,
            validAt: options.validAt.toISOString(),
            startTime: options.startTime?.toISOString() ?? null,
            endTime: options.endTime.toISOString(),
          }),
          source: source ?? "search_api",
          responseTimeMs: responseTime,
          userId,
        },
      });

      logger.debug(
        `Logged recall event for user ${userId}: ${results.length} results in ${responseTime}ms`,
      );
    } catch (error: unknown) {
      logger.error("Error creating recall log entry:", { error });
      // Don't throw - we don't want logging failures to affect the search response
    }
  }

  /**
   * Increment `recallCount` on the recalled Episode and Statement nodes
   * belonging to the user. Skips a query entirely when it has no ids to match.
   */
  private async updateRecallCount(
    userId: string,
    episodes: EpisodicNode[],
    statements: StatementNode[],
  ) {
    const episodeIds = episodes.map((episode) => episode.uuid);
    const statementIds = statements.map((statement) => statement.uuid);

    if (episodeIds.length > 0) {
      const cypher = `
      MATCH (e:Episode)
      WHERE e.uuid IN $episodeUuids and e.userId = $userId
      SET e.recallCount = coalesce(e.recallCount, 0) + 1
    `;
      await runQuery(cypher, { episodeUuids: episodeIds, userId });
    }

    if (statementIds.length > 0) {
      const cypher2 = `
      MATCH (s:Statement)
      WHERE s.uuid IN $statementUuids and s.userId = $userId
      SET s.recallCount = coalesce(s.recallCount, 0) + 1
    `;
      await runQuery(cypher2, { statementUuids: statementIds, userId });
    }
  }

  /**
   * Format search results as markdown for agent consumption
   */
  private formatAsMarkdown(
    episodes: RecalledEpisode[],
    facts: RecalledFact[],
  ): string {
    const sections: string[] = [];

    // Add episodes/compacts section
    if (episodes.length > 0) {
      sections.push("## Recalled Relevant Context\n");

      episodes.forEach((episode, index) => {
        const date = episode.createdAt.toLocaleString("en-US", {
          month: "short",
          day: "numeric",
          year: "numeric",
          hour: "2-digit",
          minute: "2-digit",
        });

        if (episode.isCompact) {
          sections.push(`### 📦 Session Compact`);
          sections.push(`**Created**: ${date}\n`);
          sections.push(episode.content);
          sections.push(""); // Empty line
        } else {
          sections.push(`### Episode ${index + 1}`);
          sections.push(`**Created**: ${date}`);
          if (episode.spaceIds.length > 0) {
            sections.push(`**Spaces**: ${episode.spaceIds.join(", ")}`);
          }
          sections.push(""); // Empty line before content
          sections.push(episode.content);
          sections.push(""); // Empty line after
        }
      });
    }

    // Add facts section
    if (facts.length > 0) {
      sections.push("## Key Facts\n");

      facts.forEach((fact) => {
        const validDate = fact.validAt.toLocaleString("en-US", {
          month: "short",
          day: "numeric",
          year: "numeric",
        });
        const invalidInfo = fact.invalidAt
          ? ` → Invalidated ${fact.invalidAt.toLocaleString("en-US", { month: "short", day: "numeric", year: "numeric" })}`
          : "";

        sections.push(`- ${fact.fact}`);
        sections.push(`  *Valid from ${validDate}${invalidInfo}*`);
      });
      sections.push(""); // Empty line after facts
    }

    // Handle empty results
    if (episodes.length === 0 && facts.length === 0) {
      sections.push("*No relevant memories found.*\n");
    }

    return sections.join("\n");
  }

  /**
   * Replace session episodes with their compacted sessions
   * Returns unified array with both regular episodes and compacts
   *
   * Output order: episodes without a session (and document chunks) first,
   * then one group per session in first-seen order.
   */
  private async replaceWithCompacts(
    episodes: EpisodicNode[],
    userId: string,
  ): Promise<RecalledEpisode[]> {
    // Group episodes by sessionId
    const sessionEpisodes = new Map<string, EpisodicNode[]>();
    const nonSessionEpisodes: EpisodicNode[] = [];

    for (const episode of episodes) {
      // Episodes carrying a documentUuid are document chunks, not session
      // episodes; episodes without a sessionId are kept as regular episodes.
      if (episode.metadata?.documentUuid || !episode.sessionId) {
        nonSessionEpisodes.push(episode);
        continue;
      }

      const group = sessionEpisodes.get(episode.sessionId);
      if (group) {
        group.push(episode);
      } else {
        sessionEpisodes.set(episode.sessionId, [episode]);
      }
    }

    const toRecalled = (episode: EpisodicNode): RecalledEpisode => ({
      content: episode.originalContent,
      createdAt: episode.createdAt,
      spaceIds: episode.spaceIds || [],
    });

    // Build unified result array, non-session episodes first
    const result: RecalledEpisode[] = nonSessionEpisodes.map(toRecalled);

    // NOTE(review): loaded lazily, presumably to avoid a circular import — confirm.
    const { getCompactedSessionBySessionId } = await import(
      "~/services/graphModels/compactedSession"
    );

    // Look up all compacts concurrently; Promise.all keeps the results aligned
    // with `sessions`, so the output order matches a sequential lookup.
    const sessions = Array.from(sessionEpisodes.entries());
    const compacts = await Promise.all(
      sessions.map(([sessionId]) =>
        getCompactedSessionBySessionId(sessionId, userId),
      ),
    );

    sessions.forEach(([sessionId, sessionEps], index) => {
      const compact = compacts[index];

      if (compact) {
        // Compact exists - add compact as episode, skip original episodes
        result.push({
          content: compact.summary,
          createdAt: compact.startTime, // Use session start time
          spaceIds: [], // Compacts don't have spaceIds directly
          isCompact: true,
        });

        logger.info(`Replaced ${sessionEps.length} episodes with compact`, {
          sessionId,
          episodeCount: sessionEps.length,
        });
      } else {
        // No compact - add original episodes
        result.push(...sessionEps.map(toRecalled));
      }
    });

    return result;
  }
}

/**
 * Search options interface
 */
export interface SearchOptions {
  limit?: number;
  maxBfsDepth?: number;
  validAt?: Date;
  startTime?: Date | null;
  endTime?: Date;
  includeInvalidated?: boolean;
  entityTypes?: string[];
  predicateTypes?: string[];
  scoreThreshold?: number;
  minResults?: number;
  spaceIds?: string[]; // Filter results by specific spaces
  adaptiveFiltering?: boolean;
  structured?: boolean; // Return structured JSON instead of markdown (default: false)
}