From 848153d57a5fe16c0b2dd166141a2f8811ddeb4a Mon Sep 17 00:00:00 2001 From: Manoj K Date: Sat, 7 Jun 2025 10:10:43 +0530 Subject: [PATCH] feat: implement hybrid search with BM25, vector and BFS traversal --- apps/webapp/app/lib/model.server.ts | 68 +++ apps/webapp/app/lib/neo4j.server.ts | 76 +++- apps/webapp/app/routes/ingest.tsx | 2 +- apps/webapp/app/routes/search.tsx | 28 ++ .../app/services/graphModels/statement.ts | 6 +- .../app/services/knowledgeGraph.server.ts | 76 +--- apps/webapp/app/services/search.server.ts | 415 +++++------------- apps/webapp/app/services/search/rerank.ts | 118 +++++ apps/webapp/app/services/search/utils.ts | 223 ++++++++++ 9 files changed, 623 insertions(+), 389 deletions(-) create mode 100644 apps/webapp/app/lib/model.server.ts create mode 100644 apps/webapp/app/routes/search.tsx create mode 100644 apps/webapp/app/services/search/rerank.ts create mode 100644 apps/webapp/app/services/search/utils.ts diff --git a/apps/webapp/app/lib/model.server.ts b/apps/webapp/app/lib/model.server.ts new file mode 100644 index 0000000..e567cd6 --- /dev/null +++ b/apps/webapp/app/lib/model.server.ts @@ -0,0 +1,68 @@ +import { LLMMappings, LLMModelEnum } from "@recall/types"; +import { + type CoreMessage, + type LanguageModelV1, + generateText, + streamText, +} from "ai"; +import { openai } from "@ai-sdk/openai"; +import { logger } from "~/services/logger.service"; + +export async function makeModelCall( + stream: boolean, + model: LLMModelEnum, + messages: CoreMessage[], + onFinish: (text: string, model: string) => void, + options?: any, +) { + let modelInstance; + let finalModel: string = "unknown"; + + switch (model) { + case LLMModelEnum.GPT35TURBO: + case LLMModelEnum.GPT4TURBO: + case LLMModelEnum.GPT4O: + case LLMModelEnum.GPT41: + case LLMModelEnum.GPT41MINI: + case LLMModelEnum.GPT41NANO: + finalModel = LLMMappings[model]; + modelInstance = openai(finalModel, { ...options }); + break; + + case LLMModelEnum.CLAUDEOPUS: + case LLMModelEnum.CLAUDESONNET: + case LLMModelEnum.CLAUDEHAIKU: + finalModel = LLMMappings[model]; + break; + + case LLMModelEnum.GEMINI25FLASH: + case LLMModelEnum.GEMINI25PRO: + case LLMModelEnum.GEMINI20FLASH: + case LLMModelEnum.GEMINI20FLASHLITE: + finalModel = LLMMappings[model]; + break; + + default: + logger.warn(`Unsupported model type: ${model}`); + break; + } + + if (stream) { + return await streamText({ + model: modelInstance as LanguageModelV1, + messages, + onFinish: async ({ text }) => { + onFinish(text, finalModel); + }, + }); + } + + const { text } = await generateText({ + model: modelInstance as LanguageModelV1, + messages, + }); + + onFinish(text, finalModel); + + return text; +} diff --git a/apps/webapp/app/lib/neo4j.server.ts b/apps/webapp/app/lib/neo4j.server.ts index 809ac0e..e3c5628 100644 --- a/apps/webapp/app/lib/neo4j.server.ts +++ b/apps/webapp/app/lib/neo4j.server.ts @@ -46,28 +46,66 @@ const runQuery = async (cypher: string, params = {}) => { // Initialize the database schema const initializeSchema = async () => { try { - // Run schema setup queries + // Create constraints for unique IDs + await runQuery( + "CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE", + ); + await runQuery( + "CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE", + ); + await runQuery( + "CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE", + ); + + // Create indexes for better query performance + await runQuery( + "CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt)", + ); + await runQuery( + "CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt)", + ); + await runQuery( + "CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt)", + ); + await runQuery( + "CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)", + ); + + // Create vector indexes for semantic search (if using Neo4j 5.0+) await runQuery(` - // Create constraints for unique IDs - CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE; - CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE; - CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE; - - // Create indexes for better query performance - CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt); - CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt); - CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt); - CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name); - - // Create vector indexes for semantic search (if using Neo4j 5.0+) CREATE VECTOR INDEX entity_embedding IF NOT EXISTS FOR (n:Entity) ON n.nameEmbedding - OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}}; - + OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}} + `); + + await runQuery(` CREATE VECTOR INDEX statement_embedding IF NOT EXISTS FOR (n:Statement) ON n.factEmbedding - OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}}; - + OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}} + `); + + await runQuery(` CREATE VECTOR INDEX episode_embedding IF NOT EXISTS FOR (n:Episode) ON n.contentEmbedding - OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}}; + OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}} + `); + + // Create fulltext indexes for BM25 search + await runQuery(` + CREATE FULLTEXT INDEX statement_fact_index IF NOT EXISTS + FOR (n:Statement) ON EACH [n.fact] + OPTIONS { + indexConfig: { + \`fulltext.analyzer\`: 'english' + } + } + `); + + await runQuery(` + CREATE FULLTEXT INDEX entity_name_index IF NOT EXISTS + FOR (n:Entity) ON EACH [n.name, n.description] + OPTIONS { + indexConfig: { + \`fulltext.analyzer\`: 'english' + } + } `); logger.info("Neo4j schema initialized successfully"); @@ -84,4 +122,6 @@ const closeDriver = async () => { logger.info("Neo4j driver closed"); }; +// await initializeSchema(); + export { driver, verifyConnectivity, runQuery, initializeSchema, closeDriver }; diff --git a/apps/webapp/app/routes/ingest.tsx b/apps/webapp/app/routes/ingest.tsx index 3717be7..eaf06ea 100644 --- a/apps/webapp/app/routes/ingest.tsx +++ b/apps/webapp/app/routes/ingest.tsx @@ -1,5 +1,5 @@ import { EpisodeType } from "@recall/types"; -import { json, LoaderFunctionArgs } from "@remix-run/node"; +import { json } from "@remix-run/node"; import { z } from "zod"; import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server"; import { getUserQueue } from "~/lib/ingest.queue"; diff --git a/apps/webapp/app/routes/search.tsx b/apps/webapp/app/routes/search.tsx new file mode 100644 index 0000000..d9097a8 --- /dev/null +++ b/apps/webapp/app/routes/search.tsx @@ -0,0 +1,28 @@ +import { z } from "zod"; +import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server"; +import { SearchService } from "~/services/search.server"; +import { json } from "@remix-run/node"; + +export const SearchBodyRequest = z.object({ + query: z.string(), + spaceId: z.string().optional(), + sessionId: z.string().optional(), +}); + +const searchService = new SearchService(); +const { action, loader } = createActionApiRoute( + { + body: SearchBodyRequest, + allowJWT: true, + authorization: { + action: "search", + }, + corsStrategy: "all", + }, + async ({ body, authentication }) => { + const results = await searchService.search(body.query, authentication.userId); + return json(results); + }, +); + +export { action, loader }; diff --git a/apps/webapp/app/services/graphModels/statement.ts b/apps/webapp/app/services/graphModels/statement.ts index 2e16af1..fb20f2e 100644 --- a/apps/webapp/app/services/graphModels/statement.ts +++ b/apps/webapp/app/services/graphModels/statement.ts @@ -17,7 +17,7 @@ export async function saveTriple(triple: Triple): Promise { MERGE (n:Statement {uuid: $uuid, userId: $userId}) ON CREATE SET n.fact = $fact, - n.embedding = $embedding, + n.factEmbedding = $factEmbedding, n.createdAt = $createdAt, n.validAt = $validAt, n.invalidAt = $invalidAt, @@ -26,7 +26,7 @@ export async function saveTriple(triple: Triple): Promise { n.space = $space ON MATCH SET n.fact = $fact, - n.embedding = $embedding, + n.factEmbedding = $factEmbedding, n.validAt = $validAt, n.invalidAt = $invalidAt, n.attributesJson = $attributesJson, @@ -37,7 +37,7 @@ export async function saveTriple(triple: Triple): Promise { const statementParams = { uuid: triple.statement.uuid, fact: triple.statement.fact, - embedding: triple.statement.factEmbedding, + factEmbedding: triple.statement.factEmbedding, createdAt: triple.statement.createdAt.toISOString(), validAt: triple.statement.validAt.toISOString(), invalidAt: triple.statement.invalidAt diff --git a/apps/webapp/app/services/knowledgeGraph.server.ts b/apps/webapp/app/services/knowledgeGraph.server.ts index 3c1d9c1..58acce4 100644 --- a/apps/webapp/app/services/knowledgeGraph.server.ts +++ b/apps/webapp/app/services/knowledgeGraph.server.ts @@ -1,15 +1,8 @@ import { openai } from "@ai-sdk/openai"; -import { - type CoreMessage, - embed, - generateText, - type LanguageModelV1, - streamText, -} from "ai"; +import { type CoreMessage, embed } from "ai"; import { entityTypes, EpisodeType, - LLMMappings, LLMModelEnum, type AddEpisodeParams, type EntityNode, @@ -33,6 +26,7 @@ import { invalidateStatements, saveTriple, } from "./graphModels/statement"; +import { makeModelCall } from "~/lib/model.server"; // Default number of previous episodes to retrieve for context const DEFAULT_EPISODE_WINDOW = 5; @@ -155,7 +149,7 @@ export class KnowledgeGraphService { let responseText = ""; - await this.makeModelCall( + await makeModelCall( false, LLMModelEnum.GPT41, messages as CoreMessage[], @@ -217,7 +211,7 @@ export class KnowledgeGraphService { const messages = extractStatements(context); let responseText = ""; - await this.makeModelCall( + await makeModelCall( false, LLMModelEnum.GPT41, messages as CoreMessage[], @@ -393,7 +387,7 @@ export class KnowledgeGraphService { const messages = dedupeNodes(dedupeContext); let responseText = ""; - await this.makeModelCall( + await makeModelCall( false, LLMModelEnum.GPT41, messages as CoreMessage[], @@ -583,7 +577,7 @@ export class KnowledgeGraphService { let responseText = ""; // Call the LLM to analyze all statements at once - await this.makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => { + await makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => { responseText = text; }); @@ -659,62 +653,4 @@ export class KnowledgeGraphService { return { resolvedStatements, invalidatedStatements }; } - - private async makeModelCall( - stream: boolean, - model: LLMModelEnum, - messages: CoreMessage[], - onFinish: (text: string, model: string) => void, - ) { - let modelInstance; - let finalModel: string = "unknown"; - - switch (model) { - case LLMModelEnum.GPT35TURBO: - case LLMModelEnum.GPT4TURBO: - case LLMModelEnum.GPT4O: - case LLMModelEnum.GPT41: - case LLMModelEnum.GPT41MINI: - case LLMModelEnum.GPT41NANO: - finalModel = LLMMappings[model]; - modelInstance = openai(finalModel); - break; - - case LLMModelEnum.CLAUDEOPUS: - case LLMModelEnum.CLAUDESONNET: - case LLMModelEnum.CLAUDEHAIKU: - finalModel = LLMMappings[model]; - break; - - case LLMModelEnum.GEMINI25FLASH: - case LLMModelEnum.GEMINI25PRO: - case LLMModelEnum.GEMINI20FLASH: - case LLMModelEnum.GEMINI20FLASHLITE: - finalModel = LLMMappings[model]; - break; - - default: - logger.warn(`Unsupported model type: ${model}`); - break; - } - - if (stream) { - return await streamText({ - model: modelInstance as LanguageModelV1, - messages, - onFinish: async ({ text }) => { - onFinish(text, finalModel); - }, - }); - } - - const { text } = await generateText({ - model: modelInstance as LanguageModelV1, - messages, - }); - - onFinish(text, finalModel); - - return text; - } } diff --git a/apps/webapp/app/services/search.server.ts b/apps/webapp/app/services/search.server.ts index 790d79a..3315e27 100644 --- a/apps/webapp/app/services/search.server.ts +++ b/apps/webapp/app/services/search.server.ts @@ -1,29 +1,19 @@ -import { - type EntityNode, - type KnowledgeGraphService, - type StatementNode, -} from "./knowledgeGraph.server"; import { openai } from "@ai-sdk/openai"; +import type { StatementNode } from "@recall/types"; import { embed } from "ai"; -import HelixDB from "helix-ts"; - -// Initialize OpenAI for embeddings -const openaiClient = openai("gpt-4.1-2025-04-14"); - -// Initialize Helix client -const helixClient = new HelixDB(); +import { logger } from "./logger.service"; +import { applyCrossEncoderReranking, applyWeightedRRF } from "./search/rerank"; +import { + performBfsSearch, + performBM25Search, + performVectorSearch, +} from "./search/utils"; /** * SearchService provides methods to search the reified + temporal knowledge graph * using a hybrid approach combining BM25, vector similarity, and BFS traversal. */ export class SearchService { - private knowledgeGraphService: KnowledgeGraphService; - - constructor(knowledgeGraphService: KnowledgeGraphService) { - this.knowledgeGraphService = knowledgeGraphService; - } - async getEmbedding(text: string) { const { embedding } = await embed({ model: openai.embedding("text-embedding-3-small"), @@ -44,7 +34,7 @@ export class SearchService { query: string, userId: string, options: SearchOptions = {}, - ): Promise { + ): Promise { // Default options const opts: Required = { limit: options.limit || 10, @@ -53,296 +43,144 @@ export class SearchService { includeInvalidated: options.includeInvalidated || false, entityTypes: options.entityTypes || [], predicateTypes: options.predicateTypes || [], + scoreThreshold: options.scoreThreshold || 0.7, + minResults: options.minResults || 10, }; + const queryVector = await this.getEmbedding(query); + // 1. Run parallel search methods const [bm25Results, vectorResults, bfsResults] = await Promise.all([ - this.performBM25Search(query, userId, opts), - this.performVectorSearch(query, userId, opts), - this.performBfsSearch(query, userId, opts), + performBM25Search(query, userId, opts), + performVectorSearch(queryVector, userId, opts), + performBfsSearch(queryVector, userId, opts), ]); - // 2. Combine and deduplicate results - const combinedStatements = this.combineAndDeduplicate([ - ...bm25Results, - ...vectorResults, - ...bfsResults, - ]); + logger.info( + `Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`, + ); - // 3. Rerank the combined results - const rerankedStatements = await this.rerankStatements( + // 2. Apply reranking strategy + const rankedStatements = await this.rerankResults( query, - combinedStatements, + { bm25: bm25Results, vector: vectorResults, bfs: bfsResults }, opts, ); - // 4. Return top results - return rerankedStatements.slice(0, opts.limit); + // 3. Apply adaptive filtering based on score threshold and minimum count + const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts); + + // 3. Return top results + return filteredResults.map((statement) => statement.fact); } /** - * Perform BM25 keyword-based search on statements + * Apply adaptive filtering to ranked results + * Uses a minimum quality threshold to filter out low-quality results */ - private async performBM25Search( - query: string, - userId: string, + private applyAdaptiveFiltering( + results: StatementNode[], options: Required, - ): Promise { - // TODO: Implement BM25 search using HelixDB or external search index - // This is a placeholder implementation - try { - const results = await helixClient.query("searchStatementsByKeywords", { - query, - userId, - validAt: options.validAt.toISOString(), - includeInvalidated: options.includeInvalidated, - limit: options.limit * 2, // Fetch more for reranking - }); + ): StatementNode[] { + if (results.length === 0) return []; - return results.statements || []; - } catch (error) { - console.error("BM25 search error:", error); - return []; - } - } - - /** - * Perform vector similarity search on statement embeddings - */ - private async performVectorSearch( - query: string, - userId: string, - options: Required, - ): Promise { - try { - // 1. Generate embedding for the query - const embedding = await this.generateEmbedding(query); - - // 2. Search for similar statements - const results = await helixClient.query("searchStatementsByVector", { - embedding, - userId, - validAt: options.validAt.toISOString(), - includeInvalidated: options.includeInvalidated, - limit: options.limit * 2, // Fetch more for reranking - }); - - return results.statements || []; - } catch (error) { - console.error("Vector search error:", error); - return []; - } - } - - /** - * Perform BFS traversal starting from entities mentioned in the query - */ - private async performBfsSearch( - query: string, - userId: string, - options: Required, - ): Promise { - try { - // 1. Extract potential entities from query - const entities = await this.extractEntitiesFromQuery(query); - - // 2. For each entity, perform BFS traversal - const allStatements: StatementNode[] = []; - - for (const entity of entities) { - const statements = await this.bfsTraversal( - entity.uuid, - options.maxBfsDepth, - options.validAt, - userId, - options.includeInvalidated, - ); - allStatements.push(...statements); + // Extract scores from results + const scoredResults = results.map((result) => { + // Find the score based on reranking strategy used + let score = 0; + if ((result as any).rrfScore !== undefined) { + score = (result as any).rrfScore; + } else if ((result as any).mmrScore !== undefined) { + score = (result as any).mmrScore; + } else if ((result as any).crossEncoderScore !== undefined) { + score = (result as any).crossEncoderScore; + } else if ((result as any).finalScore !== undefined) { + score = (result as any).finalScore; } - return allStatements; - } catch (error) { - console.error("BFS search error:", error); - return []; + return { result, score }; + }); + + const hasScores = scoredResults.some((item) => item.score > 0); + // If no scores are available, return the original results + if (!hasScores) { + logger.info("No scores found in results, skipping adaptive filtering"); + return options.limit > 0 ? results.slice(0, options.limit) : results; } + + // Sort by score (descending) + scoredResults.sort((a, b) => b.score - a.score); + + // Calculate statistics to identify low-quality results + const scores = scoredResults.map((item) => item.score); + const maxScore = Math.max(...scores); + const minScore = Math.min(...scores); + const scoreRange = maxScore - minScore; + + // Define a minimum quality threshold as a fraction of the best score + // This is relative to the query's score distribution rather than an absolute value + const relativeThreshold = options.scoreThreshold || 0.3; // 30% of the best score by default + const absoluteMinimum = 0.1; // Absolute minimum threshold to prevent keeping very poor matches + + // Calculate the actual threshold as a percentage of the distance from min to max score + const threshold = Math.max( + absoluteMinimum, + minScore + scoreRange * relativeThreshold, + ); + + // Filter out low-quality results + const filteredResults = scoredResults + .filter((item) => item.score >= threshold) + .map((item) => item.result); + + // Apply limit if specified + const limitedResults = + options.limit > 0 + ? filteredResults.slice( + 0, + Math.min(filteredResults.length, options.limit), + ) + : filteredResults; + + logger.info( + `Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`, + ); + logger.info( + `Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`, + ); + + return limitedResults; } /** - * Perform BFS traversal starting from an entity + * Apply the selected reranking strategy to search results */ - private async bfsTraversal( - startEntityId: string, - maxDepth: number, - validAt: Date, - userId: string, - includeInvalidated: boolean, - ): Promise { - // Track visited nodes to avoid cycles - const visited = new Set(); - // Track statements found during traversal - const statements: StatementNode[] = []; - // Queue for BFS traversal [nodeId, depth] - const queue: [string, number][] = [[startEntityId, 0]]; - - while (queue.length > 0) { - const [nodeId, depth] = queue.shift()!; - - // Skip if already visited or max depth reached - if (visited.has(nodeId) || depth > maxDepth) continue; - visited.add(nodeId); - - // Get statements where this entity is subject or object - const connectedStatements = await helixClient.query( - "getConnectedStatements", - { - entityId: nodeId, - userId, - validAt: validAt.toISOString(), - includeInvalidated, - }, - ); - - // Add statements to results - if (connectedStatements.statements) { - statements.push(...connectedStatements.statements); - - // Add connected entities to queue - for (const statement of connectedStatements.statements) { - // Get subject and object entities - if (statement.subjectId && !visited.has(statement.subjectId)) { - queue.push([statement.subjectId, depth + 1]); - } - if (statement.objectId && !visited.has(statement.objectId)) { - queue.push([statement.objectId, depth + 1]); - } - } - } - } - - return statements; - } - - /** - * Extract potential entities from a query using embeddings or LLM - */ - private async extractEntitiesFromQuery(query: string): Promise { - // TODO: Implement more sophisticated entity extraction - // This is a placeholder implementation that uses simple vector search - try { - const embedding = await this.getEmbedding(query); - - const results = await helixClient.query("searchEntitiesByVector", { - embedding, - limit: 3, // Start with top 3 entities - }); - - return results.entities || []; - } catch (error) { - console.error("Entity extraction error:", error); - return []; - } - } - - /** - * Combine and deduplicate statements from multiple sources - */ - private combineAndDeduplicate(statements: StatementNode[]): StatementNode[] { - const uniqueStatements = new Map(); - - for (const statement of statements) { - if (!uniqueStatements.has(statement.uuid)) { - uniqueStatements.set(statement.uuid, statement); - } - } - - return Array.from(uniqueStatements.values()); - } - - /** - * Rerank statements based on relevance to the query - */ - private async rerankStatements( + private async rerankResults( query: string, - statements: StatementNode[], + results: { + bm25: StatementNode[]; + vector: StatementNode[]; + bfs: StatementNode[]; + }, options: Required, ): Promise { - // TODO: Implement more sophisticated reranking - // This is a placeholder implementation using cosine similarity - try { - // 1. Generate embedding for the query - const queryEmbedding = await this.getEmbedding(query); + // Count non-empty result sources + const nonEmptySources = [ + results.bm25.length > 0, + results.vector.length > 0, + results.bfs.length > 0, + ].filter(Boolean).length; - // 2. Generate or retrieve embeddings for statements - const statementEmbeddings = await Promise.all( - statements.map(async (statement) => { - // If statement has embedding, use it; otherwise generate - if (statement.factEmbedding && statement.factEmbedding.length > 0) { - return { statement, embedding: statement.factEmbedding }; - } - - // Generate text representation of statement - const statementText = this.statementToText(statement); - const embedding = await this.getEmbedding(statementText); - - return { statement, embedding }; - }), + // If results are coming from only one source, use cross-encoder reranking + if (nonEmptySources <= 1) { + logger.info( + "Only one source has results, falling back to cross-encoder reranking", ); - - // 3. Calculate cosine similarity - const scoredStatements = statementEmbeddings.map( - ({ statement, embedding }) => { - const similarity = this.cosineSimilarity(queryEmbedding, embedding); - return { statement, score: similarity }; - }, - ); - - // 4. Sort by score (descending) - scoredStatements.sort((a, b) => b.score - a.score); - - // 5. Return statements in order of relevance - return scoredStatements.map(({ statement }) => statement); - } catch (error) { - console.error("Reranking error:", error); - // Fallback: return original order - return statements; - } - } - - /** - * Convert a statement to a text representation - */ - private statementToText(statement: StatementNode): string { - // TODO: Implement more sophisticated text representation - // This is a placeholder implementation - return `${statement.subjectName || "Unknown"} ${statement.predicateName || "has relation with"} ${statement.objectName || "Unknown"}`; - } - - /** - * Calculate cosine similarity between two embeddings - */ - private cosineSimilarity(a: number[], b: number[]): number { - if (a.length !== b.length) { - throw new Error("Embeddings must have the same length"); + return applyCrossEncoderReranking(query, results); } - let dotProduct = 0; - let normA = 0; - let normB = 0; - - for (let i = 0; i < a.length; i++) { - dotProduct += a[i] * b[i]; - normA += a[i] * a[i]; - normB += b[i] * b[i]; - } - - normA = Math.sqrt(normA); - normB = Math.sqrt(normB); - - if (normA === 0 || normB === 0) { - return 0; - } - - return dotProduct / (normA * normB); + // Otherwise use weighted RRF for multiple sources + return applyWeightedRRF(results); } } @@ -356,23 +194,6 @@ export interface SearchOptions { includeInvalidated?: boolean; entityTypes?: string[]; predicateTypes?: string[]; -} - -/** - * Create a singleton instance of the search service - */ -let searchServiceInstance: SearchService | null = null; - -export function getSearchService( - knowledgeGraphService?: KnowledgeGraphService, -): SearchService { - if (!searchServiceInstance) { - if (!knowledgeGraphService) { - throw new Error( - "KnowledgeGraphService must be provided when initializing SearchService", - ); - } - searchServiceInstance = new SearchService(knowledgeGraphService); - } - return searchServiceInstance; + scoreThreshold?: number; + minResults?: number; } diff --git a/apps/webapp/app/services/search/rerank.ts b/apps/webapp/app/services/search/rerank.ts new file mode 100644 index 0000000..18e5eb0 --- /dev/null +++ b/apps/webapp/app/services/search/rerank.ts @@ -0,0 +1,118 @@ +import { LLMModelEnum, type StatementNode } from "@recall/types"; +import { combineAndDeduplicateStatements } from "./utils"; +import { type CoreMessage } from "ai"; +import { makeModelCall } from "~/lib/model.server"; +import { logger } from "../logger.service"; + +/** + * Apply Weighted Reciprocal Rank Fusion to combine results + */ +export function applyWeightedRRF(results: { + bm25: StatementNode[]; + vector: StatementNode[]; + bfs: StatementNode[]; +}): StatementNode[] { + // Determine weights based on query characteristics + const weights = { + bm25: 1.0, + vector: 0.8, + bfs: 0.5, + }; + const k = 60; // RRF constant + + // Map to store combined scores + const scores: Record = + {}; + + // Process BM25 results with their weight + results.bm25.forEach((statement, rank) => { + const uuid = statement.uuid; + scores[uuid] = scores[uuid] || { score: 0, statement }; + scores[uuid].score += weights.bm25 * (1 / (rank + k)); + }); + + // Process vector similarity results with their weight + results.vector.forEach((statement, rank) => { + const uuid = statement.uuid; + scores[uuid] = scores[uuid] || { score: 0, statement }; + scores[uuid].score += weights.vector * (1 / (rank + k)); + }); + + // Process BFS traversal results with their weight + results.bfs.forEach((statement, rank) => { + const uuid = statement.uuid; + scores[uuid] = scores[uuid] || { score: 0, statement }; + scores[uuid].score += weights.bfs * (1 / (rank + k)); + }); + + // Convert to array and sort by final score + const sortedResults = Object.values(scores) + .sort((a, b) => b.score - a.score) + .map((item) => { + // Add the RRF score to the statement for debugging + return { + ...item.statement, + rrfScore: item.score, + }; + }); + + return sortedResults; +} + +/** + * Apply Cross-Encoder reranking to results + * This is particularly useful when results come from a single source + */ +export async function applyCrossEncoderReranking( + query: string, + results: { + bm25: StatementNode[]; + vector: StatementNode[]; + bfs: StatementNode[]; + }, +): Promise { + // Combine all results + const allResults = [...results.bm25, ...results.vector, ...results.bfs]; + + // Deduplicate by UUID + const uniqueResults = combineAndDeduplicateStatements(allResults); + + if (uniqueResults.length === 0) return []; + + logger.info(`Cross-encoder reranking ${uniqueResults.length} statements`); + + const finalStatements: StatementNode[] = []; + + await Promise.all( + uniqueResults.map(async (statement) => { + const messages: CoreMessage[] = [ + { + role: "system", + content: `You are an expert tasked with determining whether the statement is relevant to the query + Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.`, + }, + { + role: "user", + content: `${query}\n${statement.fact}`, + }, + ]; + + let responseText = ""; + await makeModelCall( + false, + LLMModelEnum.GPT41NANO, + messages, + (text) => { + responseText = text; + }, + { temperature: 0, maxTokens: 1 }, + ); + + if (responseText === "True") { + finalStatements.push(statement); + } + }), + ); + + return finalStatements; +} diff --git a/apps/webapp/app/services/search/utils.ts b/apps/webapp/app/services/search/utils.ts new file mode 100644 index 0000000..0a320fc --- /dev/null +++ b/apps/webapp/app/services/search/utils.ts @@ -0,0 +1,223 @@ +import type { EntityNode, StatementNode } from "@recall/types"; +import type { SearchOptions } from "../search.server"; +import type { Embedding } from "ai"; +import { logger } from "../logger.service"; +import { runQuery } from "~/lib/neo4j.server"; + +/** + * Perform BM25 keyword-based search on statements + */ +export async function performBM25Search( + query: string, + userId: string, + options: Required, +): Promise { + try { + // Sanitize the query for Lucene syntax + const sanitizedQuery = sanitizeLuceneQuery(query); + + // Use Neo4j's built-in fulltext search capabilities + const cypher = ` + CALL db.index.fulltext.queryNodes("statement_fact_index", $query) + YIELD node AS s, score + WHERE + s.validAt <= $validAt + AND (s.invalidAt IS NULL OR s.invalidAt > $validAt) + AND (s.userId = $userId) + RETURN s, score + ORDER BY score DESC + `; + + const params = { + query: sanitizedQuery, + userId, + validAt: options.validAt.toISOString(), + }; + + const records = await runQuery(cypher, params); + // return records.map((record) => record.get("s").properties as StatementNode); + return []; + } catch (error) { + logger.error("BM25 search error:", { error }); + return []; + } +} + +/** + * Sanitize a query string for Lucene syntax + */ +export function sanitizeLuceneQuery(query: string): string { + // Escape special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ + let sanitized = query.replace( + /[+\-&|!(){}[\]^"~*?:\\]/g, + (match) => "\\" + match, + ); + + // If query is too long, truncate it + const MAX_QUERY_LENGTH = 32; + const words = sanitized.split(" "); + if (words.length > MAX_QUERY_LENGTH) { + sanitized = words.slice(0, MAX_QUERY_LENGTH).join(" "); + } + + return sanitized; +} + +/** + * Perform vector similarity search on statement embeddings + */ +export async function performVectorSearch( + query: Embedding, + userId: string, + options: Required, +): Promise { + try { + // 1. Generate embedding for the query + // const embedding = await this.getEmbedding(query); + + // 2. Search for similar statements using Neo4j vector search + const cypher = ` + MATCH (s:Statement) + WHERE + s.validAt <= $validAt + AND (s.invalidAt IS NULL OR s.invalidAt > $validAt) + AND (s.userId = $userId OR s.isPublic = true) + WITH s, vector.similarity.cosine(s.factEmbedding, $embedding) AS score + WHERE score > 0.7 + RETURN s, score + ORDER BY score DESC + `; + + const params = { + embedding: query, + userId, + validAt: options.validAt.toISOString(), + }; + + const records = await runQuery(cypher, params); + // return records.map((record) => record.get("s").properties as StatementNode); + return []; + } catch (error) { + logger.error("Vector search error:", { error }); + return []; + } +} + +/** + * Perform BFS traversal starting from entities mentioned in the query + */ +export async function performBfsSearch( + embedding: Embedding, + userId: string, + options: Required, +): Promise { + try { + // 1. Extract potential entities from query + const entities = await extractEntitiesFromQuery(embedding); + + // 2. For each entity, perform BFS traversal + const allStatements: StatementNode[] = []; + + for (const entity of entities) { + const statements = await bfsTraversal( + entity.uuid, + options.maxBfsDepth, + options.validAt, + userId, + options.includeInvalidated, + ); + allStatements.push(...statements); + } + + return allStatements; + } catch (error) { + logger.error("BFS search error:", { error }); + return []; + } +} + +/** + * Perform BFS traversal starting from an entity + */ +export async function bfsTraversal( + startEntityId: string, + maxDepth: number, + validAt: Date, + userId: string, + includeInvalidated: boolean, +): Promise { + try { + // Use Neo4j's built-in path finding capabilities for efficient BFS + // This query implements BFS up to maxDepth and collects all statements along the way + const cypher = ` + MATCH (e:Entity {uuid: $startEntityId})<-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]-(s:Statement) + WHERE + s.validAt <= $validAt + AND (s.invalidAt IS NULL OR s.invalidAt > $validAt) + AND (s.userId = $userId) + AND ($includeInvalidated OR s.invalidAt IS NULL) + RETURN s as statement + `; + + const params = { + startEntityId, + maxDepth, + validAt: validAt.toISOString(), + userId, + includeInvalidated, + }; + + const records = await runQuery(cypher, params); + return records.map( + (record) => record.get("statement").properties as StatementNode, + ); + } catch (error) { + logger.error("BFS traversal error:", { error }); + return []; + } +} + +/** + * Extract potential entities from a query using embeddings or LLM + */ +export async function extractEntitiesFromQuery( + embedding: Embedding, +): Promise { + try { + // Use vector similarity to find relevant entities + const cypher = ` + // Match entities using vector similarity on name embeddings + MATCH (e:Entity) + WHERE e.nameEmbedding IS NOT NULL + WITH e, vector.similarity.cosine(e.nameEmbedding, $embedding) AS score + WHERE score > 0.7 + RETURN e + ORDER BY score DESC + LIMIT 3 + `; + + const params = { + embedding, + }; + + const records = await runQuery(cypher, params); + + return records.map((record) => record.get("e").properties as EntityNode); + } catch (error) { + logger.error("Entity extraction error:", { error }); + return []; + } +} + +/** + * Combine and deduplicate statements from different search methods + */ +export function combineAndDeduplicateStatements( + statements: StatementNode[], +): StatementNode[] { + return Array.from( + new Map( + statements.map((statement) => [statement.uuid, statement]), + ).values(), + ); +}