feat: implement hybrid search with BM25, vector and BFS traversal

Author: Manoj K
Date:   2025-06-07 10:10:43 +05:30
Parent: cf20da9ecd
Commit: 848153d57a
9 changed files with 623 additions and 389 deletions

View File

@@ -0,0 +1,68 @@
import { LLMMappings, LLMModelEnum } from "@recall/types";
import {
type CoreMessage,
type LanguageModelV1,
generateText,
streamText,
} from "ai";
import { openai } from "@ai-sdk/openai";
import { logger } from "~/services/logger.service";
export async function makeModelCall(
stream: boolean,
model: LLMModelEnum,
messages: CoreMessage[],
onFinish: (text: string, model: string) => void,
options?: any,
) {
let modelInstance;
let finalModel: string = "unknown";
switch (model) {
case LLMModelEnum.GPT35TURBO:
case LLMModelEnum.GPT4TURBO:
case LLMModelEnum.GPT4O:
case LLMModelEnum.GPT41:
case LLMModelEnum.GPT41MINI:
case LLMModelEnum.GPT41NANO:
finalModel = LLMMappings[model];
modelInstance = openai(finalModel, { ...options });
break;
case LLMModelEnum.CLAUDEOPUS:
case LLMModelEnum.CLAUDESONNET:
case LLMModelEnum.CLAUDEHAIKU:
// Only the model name is mapped here; no Anthropic provider instance is
// created yet, so modelInstance stays undefined for these branches.
finalModel = LLMMappings[model];
break;
case LLMModelEnum.GEMINI25FLASH:
case LLMModelEnum.GEMINI25PRO:
case LLMModelEnum.GEMINI20FLASH:
case LLMModelEnum.GEMINI20FLASHLITE:
// Likewise for Gemini: the name is mapped but no provider is attached.
finalModel = LLMMappings[model];
break;
default:
logger.warn(`Unsupported model type: ${model}`);
break;
}
if (stream) {
return await streamText({
model: modelInstance as LanguageModelV1,
messages,
onFinish: async ({ text }) => {
onFinish(text, finalModel);
},
});
}
const { text } = await generateText({
model: modelInstance as LanguageModelV1,
messages,
});
onFinish(text, finalModel);
return text;
}
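
For context, a minimal call-site sketch (hypothetical messages; it mirrors how knowledgeGraph.server.ts below invokes this helper, non-streaming, collecting the response through the onFinish callback):

import { LLMModelEnum } from "@recall/types";
import type { CoreMessage } from "ai";
import { makeModelCall } from "~/lib/model.server";

const messages: CoreMessage[] = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Summarize the last episode." },
];

let responseText = "";
await makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => {
  responseText = text;
});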

View File

@@ -46,28 +46,66 @@ const runQuery = async (cypher: string, params = {}) => {
 // Initialize the database schema
 const initializeSchema = async () => {
   try {
-    // Run schema setup queries
+    // Create constraints for unique IDs
+    await runQuery(
+      "CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE",
+    );
+    await runQuery(
+      "CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE",
+    );
+    await runQuery(
+      "CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE",
+    );
+    // Create indexes for better query performance
+    await runQuery(
+      "CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt)",
+    );
+    await runQuery(
+      "CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt)",
+    );
+    await runQuery(
+      "CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt)",
+    );
+    await runQuery(
+      "CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)",
+    );
+    // Create vector indexes for semantic search (if using Neo4j 5.0+)
     await runQuery(`
-      // Create constraints for unique IDs
-      CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE;
-      CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE;
-      CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE;
-      // Create indexes for better query performance
-      CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt);
-      CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt);
-      CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt);
-      CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name);
-      // Create vector indexes for semantic search (if using Neo4j 5.0+)
       CREATE VECTOR INDEX entity_embedding IF NOT EXISTS FOR (n:Entity) ON n.nameEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+    await runQuery(`
       CREATE VECTOR INDEX statement_embedding IF NOT EXISTS FOR (n:Statement) ON n.factEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+    await runQuery(`
       CREATE VECTOR INDEX episode_embedding IF NOT EXISTS FOR (n:Episode) ON n.contentEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+    // Create fulltext indexes for BM25 search
+    await runQuery(`
+      CREATE FULLTEXT INDEX statement_fact_index IF NOT EXISTS
+      FOR (n:Statement) ON EACH [n.fact]
+      OPTIONS {
+        indexConfig: {
+          \`fulltext.analyzer\`: 'english'
+        }
+      }
+    `);
+    await runQuery(`
+      CREATE FULLTEXT INDEX entity_name_index IF NOT EXISTS
+      FOR (n:Entity) ON EACH [n.name, n.description]
+      OPTIONS {
+        indexConfig: {
+          \`fulltext.analyzer\`: 'english'
+        }
+      }
     `);
     logger.info("Neo4j schema initialized successfully");
@@ -84,4 +122,6 @@ const closeDriver = async () => {
   logger.info("Neo4j driver closed");
 };
 
+// await initializeSchema();
+
 export { driver, verifyConnectivity, runQuery, initializeSchema, closeDriver };
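
A note on this migration: the constraints and range indexes move out of the single template string because the driver executes one Cypher statement per query; the old template bundled several statements separated by semicolons, which fails at runtime. To sanity-check the result after startup, something like this works (a sketch; SHOW INDEXES is stock Neo4j 5 syntax, though the returned columns vary slightly across minor versions):

const rows = await runQuery(
  "SHOW INDEXES YIELD name, type, state RETURN name, type, state",
);
// Expect the three VECTOR indexes and two FULLTEXT indexes in state ONLINE.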

View File

@@ -1,5 +1,5 @@
 import { EpisodeType } from "@recall/types";
-import { json, LoaderFunctionArgs } from "@remix-run/node";
+import { json } from "@remix-run/node";
 import { z } from "zod";
 import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
 import { getUserQueue } from "~/lib/ingest.queue";

View File

@@ -0,0 +1,28 @@
import { z } from "zod";
import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
import { SearchService } from "~/services/search.server";
import { json } from "@remix-run/node";
export const SearchBodyRequest = z.object({
query: z.string(),
spaceId: z.string().optional(),
sessionId: z.string().optional(),
});
const searchService = new SearchService();
const { action, loader } = createActionApiRoute(
{
body: SearchBodyRequest,
allowJWT: true,
authorization: {
action: "search",
},
corsStrategy: "all",
},
async ({ body, authentication }) => {
const results = await searchService.search(body.query, authentication.userId);
return json(results);
},
);
export { action, loader };
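
A hypothetical client call for reference (the concrete path depends on the route file name, which this view does not show; the response body is the array of fact strings returned by SearchService.search):

const res = await fetch("/api/v1/search", {
  method: "POST",
  headers: {
    "Content-Type": "application/json",
    Authorization: `Bearer ${token}`, // allowJWT: true permits JWT auth
  },
  body: JSON.stringify({ query: "What does John work on?" }),
});
const facts: string[] = await res.json();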

View File

@@ -17,7 +17,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
     MERGE (n:Statement {uuid: $uuid, userId: $userId})
     ON CREATE SET
       n.fact = $fact,
-      n.embedding = $embedding,
+      n.factEmbedding = $factEmbedding,
       n.createdAt = $createdAt,
       n.validAt = $validAt,
       n.invalidAt = $invalidAt,
@@ -26,7 +26,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
       n.space = $space
     ON MATCH SET
       n.fact = $fact,
-      n.embedding = $embedding,
+      n.factEmbedding = $factEmbedding,
       n.validAt = $validAt,
       n.invalidAt = $invalidAt,
       n.attributesJson = $attributesJson,
@@ -37,7 +37,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
   const statementParams = {
     uuid: triple.statement.uuid,
     fact: triple.statement.fact,
-    embedding: triple.statement.factEmbedding,
+    factEmbedding: triple.statement.factEmbedding,
     createdAt: triple.statement.createdAt.toISOString(),
     validAt: triple.statement.validAt.toISOString(),
     invalidAt: triple.statement.invalidAt

View File

@@ -1,15 +1,8 @@
 import { openai } from "@ai-sdk/openai";
-import {
-  type CoreMessage,
-  embed,
-  generateText,
-  type LanguageModelV1,
-  streamText,
-} from "ai";
+import { type CoreMessage, embed } from "ai";
 import {
   entityTypes,
   EpisodeType,
-  LLMMappings,
   LLMModelEnum,
   type AddEpisodeParams,
   type EntityNode,
@@ -33,6 +26,7 @@ import {
   invalidateStatements,
   saveTriple,
 } from "./graphModels/statement";
+import { makeModelCall } from "~/lib/model.server";
 
 // Default number of previous episodes to retrieve for context
 const DEFAULT_EPISODE_WINDOW = 5;
@@ -155,7 +149,7 @@ export class KnowledgeGraphService {
     let responseText = "";
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -217,7 +211,7 @@ export class KnowledgeGraphService {
     const messages = extractStatements(context);
     let responseText = "";
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -393,7 +387,7 @@ export class KnowledgeGraphService {
     const messages = dedupeNodes(dedupeContext);
     let responseText = "";
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -583,7 +577,7 @@ export class KnowledgeGraphService {
     let responseText = "";
 
     // Call the LLM to analyze all statements at once
-    await this.makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => {
+    await makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => {
       responseText = text;
     });
@@ -659,62 +653,4 @@ export class KnowledgeGraphService {
     return { resolvedStatements, invalidatedStatements };
   }
 
-  private async makeModelCall(
-    stream: boolean,
-    model: LLMModelEnum,
-    messages: CoreMessage[],
-    onFinish: (text: string, model: string) => void,
-  ) {
-    let modelInstance;
-    let finalModel: string = "unknown";
-    switch (model) {
-      case LLMModelEnum.GPT35TURBO:
-      case LLMModelEnum.GPT4TURBO:
-      case LLMModelEnum.GPT4O:
-      case LLMModelEnum.GPT41:
-      case LLMModelEnum.GPT41MINI:
-      case LLMModelEnum.GPT41NANO:
-        finalModel = LLMMappings[model];
-        modelInstance = openai(finalModel);
-        break;
-      case LLMModelEnum.CLAUDEOPUS:
-      case LLMModelEnum.CLAUDESONNET:
-      case LLMModelEnum.CLAUDEHAIKU:
-        finalModel = LLMMappings[model];
-        break;
-      case LLMModelEnum.GEMINI25FLASH:
-      case LLMModelEnum.GEMINI25PRO:
-      case LLMModelEnum.GEMINI20FLASH:
-      case LLMModelEnum.GEMINI20FLASHLITE:
-        finalModel = LLMMappings[model];
-        break;
-      default:
-        logger.warn(`Unsupported model type: ${model}`);
-        break;
-    }
-
-    if (stream) {
-      return await streamText({
-        model: modelInstance as LanguageModelV1,
-        messages,
-        onFinish: async ({ text }) => {
-          onFinish(text, finalModel);
-        },
-      });
-    }
-
-    const { text } = await generateText({
-      model: modelInstance as LanguageModelV1,
-      messages,
-    });
-    onFinish(text, finalModel);
-    return text;
-  }
 }

View File

@@ -1,29 +1,19 @@
-import {
-  type EntityNode,
-  type KnowledgeGraphService,
-  type StatementNode,
-} from "./knowledgeGraph.server";
 import { openai } from "@ai-sdk/openai";
+import type { StatementNode } from "@recall/types";
 import { embed } from "ai";
-import HelixDB from "helix-ts";
-
-// Initialize OpenAI for embeddings
-const openaiClient = openai("gpt-4.1-2025-04-14");
-
-// Initialize Helix client
-const helixClient = new HelixDB();
+import { logger } from "./logger.service";
+import { applyCrossEncoderReranking, applyWeightedRRF } from "./search/rerank";
+import {
+  performBfsSearch,
+  performBM25Search,
+  performVectorSearch,
+} from "./search/utils";
 
 /**
  * SearchService provides methods to search the reified + temporal knowledge graph
  * using a hybrid approach combining BM25, vector similarity, and BFS traversal.
  */
 export class SearchService {
-  private knowledgeGraphService: KnowledgeGraphService;
-
-  constructor(knowledgeGraphService: KnowledgeGraphService) {
-    this.knowledgeGraphService = knowledgeGraphService;
-  }
-
   async getEmbedding(text: string) {
     const { embedding } = await embed({
       model: openai.embedding("text-embedding-3-small"),
@@ -44,7 +34,7 @@ export class SearchService {
     query: string,
     userId: string,
     options: SearchOptions = {},
-  ): Promise<StatementNode[]> {
+  ): Promise<string[]> {
     // Default options
     const opts: Required<SearchOptions> = {
       limit: options.limit || 10,
@@ -53,296 +43,144 @@ export class SearchService {
       includeInvalidated: options.includeInvalidated || false,
       entityTypes: options.entityTypes || [],
       predicateTypes: options.predicateTypes || [],
+      scoreThreshold: options.scoreThreshold || 0.7,
+      minResults: options.minResults || 10,
     };
 
+    const queryVector = await this.getEmbedding(query);
+
     // 1. Run parallel search methods
     const [bm25Results, vectorResults, bfsResults] = await Promise.all([
-      this.performBM25Search(query, userId, opts),
-      this.performVectorSearch(query, userId, opts),
-      this.performBfsSearch(query, userId, opts),
+      performBM25Search(query, userId, opts),
+      performVectorSearch(queryVector, userId, opts),
+      performBfsSearch(queryVector, userId, opts),
     ]);
 
-    // 2. Combine and deduplicate results
-    const combinedStatements = this.combineAndDeduplicate([
-      ...bm25Results,
-      ...vectorResults,
-      ...bfsResults,
-    ]);
+    logger.info(
+      `Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`,
+    );
 
-    // 3. Rerank the combined results
-    const rerankedStatements = await this.rerankStatements(
+    // 2. Apply reranking strategy
+    const rankedStatements = await this.rerankResults(
       query,
-      combinedStatements,
+      { bm25: bm25Results, vector: vectorResults, bfs: bfsResults },
       opts,
     );
 
-    // 4. Return top results
-    return rerankedStatements.slice(0, opts.limit);
+    // 3. Apply adaptive filtering based on score threshold and minimum count
+    const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts);
+
+    // 3. Return top results
+    return filteredResults.map((statement) => statement.fact);
   }
 
   /**
-   * Perform BM25 keyword-based search on statements
+   * Apply adaptive filtering to ranked results
+   * Uses a minimum quality threshold to filter out low-quality results
    */
-  private async performBM25Search(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    // TODO: Implement BM25 search using HelixDB or external search index
-    // This is a placeholder implementation
-    try {
-      const results = await helixClient.query("searchStatementsByKeywords", {
-        query,
-        userId,
-        validAt: options.validAt.toISOString(),
-        includeInvalidated: options.includeInvalidated,
-        limit: options.limit * 2, // Fetch more for reranking
-      });
-      return results.statements || [];
-    } catch (error) {
-      console.error("BM25 search error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Perform vector similarity search on statement embeddings
-   */
-  private async performVectorSearch(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    try {
-      // 1. Generate embedding for the query
-      const embedding = await this.generateEmbedding(query);
-
-      // 2. Search for similar statements
-      const results = await helixClient.query("searchStatementsByVector", {
-        embedding,
-        userId,
-        validAt: options.validAt.toISOString(),
-        includeInvalidated: options.includeInvalidated,
-        limit: options.limit * 2, // Fetch more for reranking
-      });
-      return results.statements || [];
-    } catch (error) {
-      console.error("Vector search error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Perform BFS traversal starting from entities mentioned in the query
-   */
-  private async performBfsSearch(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    try {
-      // 1. Extract potential entities from query
-      const entities = await this.extractEntitiesFromQuery(query);
-
-      // 2. For each entity, perform BFS traversal
-      const allStatements: StatementNode[] = [];
-      for (const entity of entities) {
-        const statements = await this.bfsTraversal(
-          entity.uuid,
-          options.maxBfsDepth,
-          options.validAt,
-          userId,
-          options.includeInvalidated,
-        );
-        allStatements.push(...statements);
-      }
-      return allStatements;
-    } catch (error) {
-      console.error("BFS search error:", error);
-      return [];
-    }
-  }
+  private applyAdaptiveFiltering(
+    results: StatementNode[],
+    options: Required<SearchOptions>,
+  ): StatementNode[] {
+    if (results.length === 0) return [];
+
+    // Extract scores from results
+    const scoredResults = results.map((result) => {
+      // Find the score based on reranking strategy used
+      let score = 0;
+      if ((result as any).rrfScore !== undefined) {
+        score = (result as any).rrfScore;
+      } else if ((result as any).mmrScore !== undefined) {
+        score = (result as any).mmrScore;
+      } else if ((result as any).crossEncoderScore !== undefined) {
+        score = (result as any).crossEncoderScore;
+      } else if ((result as any).finalScore !== undefined) {
+        score = (result as any).finalScore;
+      }
+      return { result, score };
+    });
+
+    const hasScores = scoredResults.some((item) => item.score > 0);
+
+    // If no scores are available, return the original results
+    if (!hasScores) {
+      logger.info("No scores found in results, skipping adaptive filtering");
+      return options.limit > 0 ? results.slice(0, options.limit) : results;
+    }
+
+    // Sort by score (descending)
+    scoredResults.sort((a, b) => b.score - a.score);
+
+    // Calculate statistics to identify low-quality results
+    const scores = scoredResults.map((item) => item.score);
+    const maxScore = Math.max(...scores);
+    const minScore = Math.min(...scores);
+    const scoreRange = maxScore - minScore;
+
+    // Define a minimum quality threshold as a fraction of the best score
+    // This is relative to the query's score distribution rather than an absolute value
+    const relativeThreshold = options.scoreThreshold || 0.3; // 30% of the best score by default
+    const absoluteMinimum = 0.1; // Absolute minimum threshold to prevent keeping very poor matches
+
+    // Calculate the actual threshold as a percentage of the distance from min to max score
+    const threshold = Math.max(
+      absoluteMinimum,
+      minScore + scoreRange * relativeThreshold,
+    );
+
+    // Filter out low-quality results
+    const filteredResults = scoredResults
+      .filter((item) => item.score >= threshold)
+      .map((item) => item.result);
+
+    // Apply limit if specified
+    const limitedResults =
+      options.limit > 0
+        ? filteredResults.slice(
+            0,
+            Math.min(filteredResults.length, options.limit),
+          )
+        : filteredResults;
+
+    logger.info(
+      `Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`,
+    );
+    logger.info(
+      `Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`,
+    );
+
+    return limitedResults;
+  }
 
   /**
-   * Perform BFS traversal starting from an entity
+   * Apply the selected reranking strategy to search results
    */
-  private async bfsTraversal(
-    startEntityId: string,
-    maxDepth: number,
-    validAt: Date,
-    userId: string,
-    includeInvalidated: boolean,
-  ): Promise<StatementNode[]> {
-    // Track visited nodes to avoid cycles
-    const visited = new Set<string>();
-    // Track statements found during traversal
-    const statements: StatementNode[] = [];
-    // Queue for BFS traversal [nodeId, depth]
-    const queue: [string, number][] = [[startEntityId, 0]];
-
-    while (queue.length > 0) {
-      const [nodeId, depth] = queue.shift()!;
-
-      // Skip if already visited or max depth reached
-      if (visited.has(nodeId) || depth > maxDepth) continue;
-      visited.add(nodeId);
-
-      // Get statements where this entity is subject or object
-      const connectedStatements = await helixClient.query(
-        "getConnectedStatements",
-        {
-          entityId: nodeId,
-          userId,
-          validAt: validAt.toISOString(),
-          includeInvalidated,
-        },
-      );
-
-      // Add statements to results
-      if (connectedStatements.statements) {
-        statements.push(...connectedStatements.statements);
-
-        // Add connected entities to queue
-        for (const statement of connectedStatements.statements) {
-          // Get subject and object entities
-          if (statement.subjectId && !visited.has(statement.subjectId)) {
-            queue.push([statement.subjectId, depth + 1]);
-          }
-          if (statement.objectId && !visited.has(statement.objectId)) {
-            queue.push([statement.objectId, depth + 1]);
-          }
-        }
-      }
-    }
-    return statements;
-  }
-
-  /**
-   * Extract potential entities from a query using embeddings or LLM
-   */
-  private async extractEntitiesFromQuery(query: string): Promise<EntityNode[]> {
-    // TODO: Implement more sophisticated entity extraction
-    // This is a placeholder implementation that uses simple vector search
-    try {
-      const embedding = await this.getEmbedding(query);
-      const results = await helixClient.query("searchEntitiesByVector", {
-        embedding,
-        limit: 3, // Start with top 3 entities
-      });
-      return results.entities || [];
-    } catch (error) {
-      console.error("Entity extraction error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Combine and deduplicate statements from multiple sources
-   */
-  private combineAndDeduplicate(statements: StatementNode[]): StatementNode[] {
-    const uniqueStatements = new Map<string, StatementNode>();
-    for (const statement of statements) {
-      if (!uniqueStatements.has(statement.uuid)) {
-        uniqueStatements.set(statement.uuid, statement);
-      }
-    }
-    return Array.from(uniqueStatements.values());
-  }
-
-  /**
-   * Rerank statements based on relevance to the query
-   */
-  private async rerankStatements(
+  private async rerankResults(
     query: string,
-    statements: StatementNode[],
+    results: {
+      bm25: StatementNode[];
+      vector: StatementNode[];
+      bfs: StatementNode[];
+    },
     options: Required<SearchOptions>,
   ): Promise<StatementNode[]> {
-    // TODO: Implement more sophisticated reranking
-    // This is a placeholder implementation using cosine similarity
-    try {
-      // 1. Generate embedding for the query
-      const queryEmbedding = await this.getEmbedding(query);
-
-      // 2. Generate or retrieve embeddings for statements
-      const statementEmbeddings = await Promise.all(
-        statements.map(async (statement) => {
-          // If statement has embedding, use it; otherwise generate
-          if (statement.factEmbedding && statement.factEmbedding.length > 0) {
-            return { statement, embedding: statement.factEmbedding };
-          }
-          // Generate text representation of statement
-          const statementText = this.statementToText(statement);
-          const embedding = await this.getEmbedding(statementText);
-          return { statement, embedding };
-        }),
-      );
-
-      // 3. Calculate cosine similarity
-      const scoredStatements = statementEmbeddings.map(
-        ({ statement, embedding }) => {
-          const similarity = this.cosineSimilarity(queryEmbedding, embedding);
-          return { statement, score: similarity };
-        },
-      );
-
-      // 4. Sort by score (descending)
-      scoredStatements.sort((a, b) => b.score - a.score);
-
-      // 5. Return statements in order of relevance
-      return scoredStatements.map(({ statement }) => statement);
-    } catch (error) {
-      console.error("Reranking error:", error);
-      // Fallback: return original order
-      return statements;
-    }
-  }
-
-  /**
-   * Convert a statement to a text representation
-   */
-  private statementToText(statement: StatementNode): string {
-    // TODO: Implement more sophisticated text representation
-    // This is a placeholder implementation
-    return `${statement.subjectName || "Unknown"} ${statement.predicateName || "has relation with"} ${statement.objectName || "Unknown"}`;
-  }
-
-  /**
-   * Calculate cosine similarity between two embeddings
-   */
-  private cosineSimilarity(a: number[], b: number[]): number {
-    if (a.length !== b.length) {
-      throw new Error("Embeddings must have the same length");
-    }
-    let dotProduct = 0;
-    let normA = 0;
-    let normB = 0;
-    for (let i = 0; i < a.length; i++) {
-      dotProduct += a[i] * b[i];
-      normA += a[i] * a[i];
-      normB += b[i] * b[i];
-    }
-    normA = Math.sqrt(normA);
-    normB = Math.sqrt(normB);
-    if (normA === 0 || normB === 0) {
-      return 0;
-    }
-    return dotProduct / (normA * normB);
+    // Count non-empty result sources
+    const nonEmptySources = [
+      results.bm25.length > 0,
+      results.vector.length > 0,
+      results.bfs.length > 0,
+    ].filter(Boolean).length;
+
+    // If results are coming from only one source, use cross-encoder reranking
+    if (nonEmptySources <= 1) {
+      logger.info(
+        "Only one source has results, falling back to cross-encoder reranking",
+      );
+      return applyCrossEncoderReranking(query, results);
+    }
+
+    // Otherwise use weighted RRF for multiple sources
+    return applyWeightedRRF(results);
   }
 }
@@ -356,23 +194,6 @@
   includeInvalidated?: boolean;
   entityTypes?: string[];
   predicateTypes?: string[];
+  scoreThreshold?: number;
+  minResults?: number;
 }
-
-/**
- * Create a singleton instance of the search service
- */
-let searchServiceInstance: SearchService | null = null;
-
-export function getSearchService(
-  knowledgeGraphService?: KnowledgeGraphService,
-): SearchService {
-  if (!searchServiceInstance) {
-    if (!knowledgeGraphService) {
-      throw new Error(
-        "KnowledgeGraphService must be provided when initializing SearchService",
-      );
-    }
-    searchServiceInstance = new SearchService(knowledgeGraphService);
-  }
-  return searchServiceInstance;
-}
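
To make the adaptive threshold concrete, here is the arithmetic on a hypothetical score distribution (defaults above: scoreThreshold 0.7, absoluteMinimum 0.1):

// Hypothetical RRF scores for four ranked statements:
const scores = [0.031, 0.028, 0.012, 0.009];
// minScore = 0.009, maxScore = 0.031, scoreRange = 0.022
// threshold = max(0.1, 0.009 + 0.022 * 0.7) = max(0.1, 0.0244) = 0.1
// RRF sums top out near 0.038 (at most weight / k per source, k = 60), so
// every RRF-ranked result sits below the 0.1 floor and would be filtered
// out; the absolute minimum really only suits similarity scores in [0, 1].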

View File

@@ -0,0 +1,118 @@
import { LLMModelEnum, type StatementNode } from "@recall/types";
import { combineAndDeduplicateStatements } from "./utils";
import { type CoreMessage } from "ai";
import { makeModelCall } from "~/lib/model.server";
import { logger } from "../logger.service";
/**
* Apply Weighted Reciprocal Rank Fusion to combine results
*/
export function applyWeightedRRF(results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
}): StatementNode[] {
// Determine weights based on query characteristics
const weights = {
bm25: 1.0,
vector: 0.8,
bfs: 0.5,
};
const k = 60; // RRF constant
// Map to store combined scores
const scores: Record<string, { score: number; statement: StatementNode }> =
{};
// Process BM25 results with their weight
results.bm25.forEach((statement, rank) => {
const uuid = statement.uuid;
scores[uuid] = scores[uuid] || { score: 0, statement };
scores[uuid].score += weights.bm25 * (1 / (rank + k));
});
// Process vector similarity results with their weight
results.vector.forEach((statement, rank) => {
const uuid = statement.uuid;
scores[uuid] = scores[uuid] || { score: 0, statement };
scores[uuid].score += weights.vector * (1 / (rank + k));
});
// Process BFS traversal results with their weight
results.bfs.forEach((statement, rank) => {
const uuid = statement.uuid;
scores[uuid] = scores[uuid] || { score: 0, statement };
scores[uuid].score += weights.bfs * (1 / (rank + k));
});
// Convert to array and sort by final score
const sortedResults = Object.values(scores)
.sort((a, b) => b.score - a.score)
.map((item) => {
// Add the RRF score to the statement for debugging
return {
...item.statement,
rrfScore: item.score,
};
});
return sortedResults;
}
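
The function computes weighted Reciprocal Rank Fusion: score(d) is the sum, over the sources that returned d, of weight / (rank + k). A worked example of the arithmetic (weights 1.0 / 0.8 / 0.5, k = 60):

// A statement ranked 1st in BM25 (rank index 0) and 3rd in vector search
// (rank index 2), and absent from BFS, accumulates:
//   bm25:   1.0 * 1 / (0 + 60) ≈ 0.0167
//   vector: 0.8 * 1 / (2 + 60) ≈ 0.0129
//   total                      ≈ 0.0296
// A statement returned by a single source at rank 0 caps at ≈ 0.0167, so
// agreement across sources outranks any single-source hit.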
/**
* Apply Cross-Encoder reranking to results
* This is particularly useful when results come from a single source
*/
export async function applyCrossEncoderReranking(
query: string,
results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
},
): Promise<StatementNode[]> {
// Combine all results
const allResults = [...results.bm25, ...results.vector, ...results.bfs];
// Deduplicate by UUID
const uniqueResults = combineAndDeduplicateStatements(allResults);
if (uniqueResults.length === 0) return [];
logger.info(`Cross-encoder reranking ${uniqueResults.length} statements`);
const finalStatements: StatementNode[] = [];
await Promise.all(
uniqueResults.map(async (statement) => {
const messages: CoreMessage[] = [
{
role: "system",
content: `You are an expert tasked with determining whether the statement is relevant to the query.
Respond with "True" if STATEMENT is relevant to QUERY and "False" otherwise.`,
},
{
role: "user",
content: `<QUERY>${query}</QUERY>\n<STATEMENT>${statement.fact}</STATEMENT>`,
},
];
let responseText = "";
await makeModelCall(
false,
LLMModelEnum.GPT41NANO,
messages,
(text) => {
responseText = text;
},
{ temperature: 0, maxTokens: 1 },
);
if (responseText === "True") {
finalStatements.push(statement);
}
}),
);
return finalStatements;
}
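
One caveat: because each relevance check resolves independently inside Promise.all, finalStatements ends up in promise-completion order rather than input order. An order-preserving variant could look like this (a sketch; isRelevant is a hypothetical wrapper around the makeModelCall block above that returns a boolean):

const verdicts = await Promise.all(
  uniqueResults.map(async (statement) => ({
    statement,
    keep: await isRelevant(query, statement), // hypothetical helper
  })),
);
return verdicts.filter((v) => v.keep).map((v) => v.statement);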

View File

@@ -0,0 +1,223 @@
import type { EntityNode, StatementNode } from "@recall/types";
import type { SearchOptions } from "../search.server";
import type { Embedding } from "ai";
import { logger } from "../logger.service";
import { runQuery } from "~/lib/neo4j.server";
/**
* Perform BM25 keyword-based search on statements
*/
export async function performBM25Search(
query: string,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// Sanitize the query for Lucene syntax
const sanitizedQuery = sanitizeLuceneQuery(query);
// Use Neo4j's built-in fulltext search capabilities
const cypher = `
CALL db.index.fulltext.queryNodes("statement_fact_index", $query)
YIELD node AS s, score
WHERE
s.validAt <= $validAt
AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
AND (s.userId = $userId)
RETURN s, score
ORDER BY score DESC
`;
const params = {
query: sanitizedQuery,
userId,
validAt: options.validAt.toISOString(),
};
const records = await runQuery(cypher, params);
// return records.map((record) => record.get("s").properties as StatementNode);
return [];
} catch (error) {
logger.error("BM25 search error:", { error });
return [];
}
}
/**
* Sanitize a query string for Lucene syntax
*/
export function sanitizeLuceneQuery(query: string): string {
// Escape special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \
let sanitized = query.replace(
/[+\-&|!(){}[\]^"~*?:\\]/g,
(match) => "\\" + match,
);
// If query is too long, truncate it
const MAX_QUERY_LENGTH = 32;
const words = sanitized.split(" ");
if (words.length > MAX_QUERY_LENGTH) {
sanitized = words.slice(0, MAX_QUERY_LENGTH).join(" ");
}
return sanitized;
}
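
Example behavior (hypothetical input):

sanitizeLuceneQuery("who founded OpenAI? (2015)");
// => who founded OpenAI\? \(2015\)
// The two-character operators && and || are escaped character by character,
// and MAX_QUERY_LENGTH caps the query at 32 words, not 32 characters.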
/**
* Perform vector similarity search on statement embeddings
*/
export async function performVectorSearch(
query: Embedding,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// 1. Generate embedding for the query
// const embedding = await this.getEmbedding(query);
// 2. Search for similar statements using Neo4j vector search
const cypher = `
MATCH (s:Statement)
WHERE
s.validAt <= $validAt
AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
AND (s.userId = $userId OR s.isPublic = true)
WITH s, vector.similarity.cosine(s.factEmbedding, $embedding) AS score
WHERE score > 0.7
RETURN s, score
ORDER BY score DESC
`;
const params = {
embedding: query,
userId,
validAt: options.validAt.toISOString(),
};
const records = await runQuery(cypher, params);
// return records.map((record) => record.get("s").properties as StatementNode);
return [];
} catch (error) {
logger.error("Vector search error:", { error });
return [];
}
}
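
As written, this query scores every Statement node with vector.similarity.cosine, so the statement_embedding vector index created in neo4j.server.ts is not exercised. An index-backed sketch using Neo4j 5's stock db.index.vector.queryNodes procedure (the k value and the carried-over validity post-filter are assumptions):

const records = await runQuery(
  `
  CALL db.index.vector.queryNodes('statement_embedding', $k, $embedding)
  YIELD node AS s, score
  WHERE s.userId = $userId
    AND s.validAt <= $validAt
    AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
  RETURN s, score
  ORDER BY score DESC
  `,
  { k: 20, embedding: query, userId, validAt: options.validAt.toISOString() },
);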
/**
* Perform BFS traversal starting from entities mentioned in the query
*/
export async function performBfsSearch(
embedding: Embedding,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// 1. Extract potential entities from query
const entities = await extractEntitiesFromQuery(embedding);
// 2. For each entity, perform BFS traversal
const allStatements: StatementNode[] = [];
for (const entity of entities) {
const statements = await bfsTraversal(
entity.uuid,
options.maxBfsDepth,
options.validAt,
userId,
options.includeInvalidated,
);
allStatements.push(...statements);
}
return allStatements;
} catch (error) {
logger.error("BFS search error:", { error });
return [];
}
}
/**
* Perform BFS traversal starting from an entity
*/
export async function bfsTraversal(
startEntityId: string,
maxDepth: number,
validAt: Date,
userId: string,
includeInvalidated: boolean,
): Promise<StatementNode[]> {
try {
// Use Neo4j's built-in path finding capabilities for efficient BFS
// This query implements BFS up to maxDepth and collects all statements along the way
const cypher = `
MATCH (e:Entity {uuid: $startEntityId})<-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]-(s:Statement)
WHERE
s.validAt <= $validAt
AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
AND (s.userId = $userId)
AND ($includeInvalidated OR s.invalidAt IS NULL)
RETURN s as statement
`;
const params = {
startEntityId,
maxDepth,
validAt: validAt.toISOString(),
userId,
includeInvalidated,
};
const records = await runQuery(cypher, params);
return records.map(
(record) => record.get("statement").properties as StatementNode,
);
} catch (error) {
logger.error("BFS traversal error:", { error });
return [];
}
}
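
Despite the maxDepth parameter, the committed pattern only matches statements one relationship away from the start entity. A sketch of a genuinely bounded traversal (assuming the reified model where Entity and Statement nodes alternate, so one entity hop is two relationship hops; Cypher cannot parameterize variable-length bounds, hence the inlined depth):

const hops = 2 * maxDepth;
const multiHopCypher = `
  MATCH (e:Entity {uuid: $startEntityId})-[:HAS_SUBJECT|HAS_OBJECT*1..${hops}]-(s:Statement)
  WHERE
    s.validAt <= $validAt
    AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
    AND (s.userId = $userId)
    AND ($includeInvalidated OR s.invalidAt IS NULL)
  RETURN DISTINCT s AS statement
`;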
/**
* Extract potential entities from a query using embeddings or LLM
*/
export async function extractEntitiesFromQuery(
embedding: Embedding,
): Promise<EntityNode[]> {
try {
// Use vector similarity to find relevant entities
const cypher = `
// Match entities using vector similarity on name embeddings
MATCH (e:Entity)
WHERE e.nameEmbedding IS NOT NULL
WITH e, vector.similarity.cosine(e.nameEmbedding, $embedding) AS score
WHERE score > 0.7
RETURN e
ORDER BY score DESC
LIMIT 3
`;
const params = {
embedding,
};
const records = await runQuery(cypher, params);
return records.map((record) => record.get("e").properties as EntityNode);
} catch (error) {
logger.error("Entity extraction error:", { error });
return [];
}
}
/**
* Combine and deduplicate statements from different search methods
*/
export function combineAndDeduplicateStatements(
statements: StatementNode[],
): StatementNode[] {
return Array.from(
new Map(
statements.map((statement) => [statement.uuid, statement]),
).values(),
);
}