feat: implement hybrid search with BM25, vector and BFS traversal
parent cf20da9ecd · commit 848153d57a
apps/webapp/app/lib/model.server.ts (new file, 68 lines)
@@ -0,0 +1,68 @@
import { LLMMappings, LLMModelEnum } from "@recall/types";
import {
  type CoreMessage,
  type LanguageModelV1,
  generateText,
  streamText,
} from "ai";
import { openai } from "@ai-sdk/openai";
import { logger } from "~/services/logger.service";

export async function makeModelCall(
  stream: boolean,
  model: LLMModelEnum,
  messages: CoreMessage[],
  onFinish: (text: string, model: string) => void,
  options?: any,
) {
  let modelInstance;
  let finalModel: string = "unknown";

  switch (model) {
    case LLMModelEnum.GPT35TURBO:
    case LLMModelEnum.GPT4TURBO:
    case LLMModelEnum.GPT4O:
    case LLMModelEnum.GPT41:
    case LLMModelEnum.GPT41MINI:
    case LLMModelEnum.GPT41NANO:
      finalModel = LLMMappings[model];
      modelInstance = openai(finalModel, { ...options });
      break;

    case LLMModelEnum.CLAUDEOPUS:
    case LLMModelEnum.CLAUDESONNET:
    case LLMModelEnum.CLAUDEHAIKU:
      finalModel = LLMMappings[model];
      break;

    case LLMModelEnum.GEMINI25FLASH:
    case LLMModelEnum.GEMINI25PRO:
    case LLMModelEnum.GEMINI20FLASH:
    case LLMModelEnum.GEMINI20FLASHLITE:
      finalModel = LLMMappings[model];
      break;

    default:
      logger.warn(`Unsupported model type: ${model}`);
      break;
  }

  if (stream) {
    return await streamText({
      model: modelInstance as LanguageModelV1,
      messages,
      onFinish: async ({ text }) => {
        onFinish(text, finalModel);
      },
    });
  }

  const { text } = await generateText({
    model: modelInstance as LanguageModelV1,
    messages,
  });

  onFinish(text, finalModel);

  return text;
}
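For orientation, a minimal sketch of a call site for the new shared helper (a hypothetical usage, not part of this commit): the onFinish callback receives the final text plus the resolved model name on both the streaming and non-streaming paths. Note that only the OpenAI cases construct modelInstance above; the Claude and Gemini branches resolve a model name but leave the instance undefined for now.

import { LLMModelEnum } from "@recall/types";
import type { CoreMessage } from "ai";
import { makeModelCall } from "~/lib/model.server";

const messages: CoreMessage[] = [
  { role: "user", content: "Summarize the latest episode." },
];

// Non-streaming call: resolves with the generated text and fires onFinish once.
const text = await makeModelCall(
  false,
  LLMModelEnum.GPT41,
  messages,
  (output, model) => console.log(`LLM call finished on ${model}`),
);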
@@ -46,28 +46,66 @@ const runQuery = async (cypher: string, params = {}) => {
 // Initialize the database schema
 const initializeSchema = async () => {
   try {
-    // Run schema setup queries
-    await runQuery(`
-      // Create constraints for unique IDs
-      CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE;
-      CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE;
-      CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE;
-
-      // Create indexes for better query performance
-      CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt);
-      CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt);
-      CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt);
-      CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name);
-
-      // Create vector indexes for semantic search (if using Neo4j 5.0+)
-      CREATE VECTOR INDEX entity_embedding IF NOT EXISTS FOR (n:Entity) ON n.nameEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
-
-      CREATE VECTOR INDEX statement_embedding IF NOT EXISTS FOR (n:Statement) ON n.factEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
-
-      CREATE VECTOR INDEX episode_embedding IF NOT EXISTS FOR (n:Episode) ON n.contentEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
-    `);
+    // Create constraints for unique IDs
+    await runQuery(
+      "CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE",
+    );
+    await runQuery(
+      "CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE",
+    );
+    await runQuery(
+      "CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE",
+    );
+
+    // Create indexes for better query performance
+    await runQuery(
+      "CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt)",
+    );
+    await runQuery(
+      "CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt)",
+    );
+    await runQuery(
+      "CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt)",
+    );
+    await runQuery(
+      "CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)",
+    );
+
+    // Create vector indexes for semantic search (if using Neo4j 5.0+)
+    await runQuery(`
+      CREATE VECTOR INDEX entity_embedding IF NOT EXISTS FOR (n:Entity) ON n.nameEmbedding
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+
+    await runQuery(`
+      CREATE VECTOR INDEX statement_embedding IF NOT EXISTS FOR (n:Statement) ON n.factEmbedding
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+
+    await runQuery(`
+      CREATE VECTOR INDEX episode_embedding IF NOT EXISTS FOR (n:Episode) ON n.contentEmbedding
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+
+    // Create fulltext indexes for BM25 search
+    await runQuery(`
+      CREATE FULLTEXT INDEX statement_fact_index IF NOT EXISTS
+      FOR (n:Statement) ON EACH [n.fact]
+      OPTIONS {
+        indexConfig: {
+          \`fulltext.analyzer\`: 'english'
+        }
+      }
+    `);
+
+    await runQuery(`
+      CREATE FULLTEXT INDEX entity_name_index IF NOT EXISTS
+      FOR (n:Entity) ON EACH [n.name, n.description]
+      OPTIONS {
+        indexConfig: {
+          \`fulltext.analyzer\`: 'english'
+        }
+      }
+    `);
 
     logger.info("Neo4j schema initialized successfully");
@@ -84,4 +122,6 @@ const closeDriver = async () => {
   logger.info("Neo4j driver closed");
 };
 
+// await initializeSchema();
+
 export { driver, verifyConnectivity, runQuery, initializeSchema, closeDriver };
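The schema hunk above replaces a single multi-statement template (the Neo4j driver executes one Cypher statement per query, so the combined form would fail) with one idempotent runQuery call per constraint and index, and corrects the vector-index options to the `vector.dimensions` / `vector.similarity_function` keys. Since the initializeSchema() call itself stays commented out, wiring it into startup might look like this (a sketch, assuming only the helpers exported above):

import {
  closeDriver,
  initializeSchema,
  verifyConnectivity,
} from "~/lib/neo4j.server";

async function bootstrap() {
  await verifyConnectivity(); // fail fast if Neo4j is unreachable
  await initializeSchema(); // safe to re-run: every statement uses IF NOT EXISTS
}

bootstrap().catch(async (error) => {
  console.error("Neo4j bootstrap failed", error);
  await closeDriver();
});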
@@ -1,5 +1,5 @@
 import { EpisodeType } from "@recall/types";
-import { json, LoaderFunctionArgs } from "@remix-run/node";
+import { json } from "@remix-run/node";
 import { z } from "zod";
 import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
 import { getUserQueue } from "~/lib/ingest.queue";
apps/webapp/app/routes/search.tsx (new file, 28 lines)
@@ -0,0 +1,28 @@
import { z } from "zod";
import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
import { SearchService } from "~/services/search.server";
import { json } from "@remix-run/node";

export const SearchBodyRequest = z.object({
  query: z.string(),
  spaceId: z.string().optional(),
  sessionId: z.string().optional(),
});

const searchService = new SearchService();
const { action, loader } = createActionApiRoute(
  {
    body: SearchBodyRequest,
    allowJWT: true,
    authorization: {
      action: "search",
    },
    corsStrategy: "all",
  },
  async ({ body, authentication }) => {
    const results = await searchService.search(body.query, authentication.userId);
    return json(results);
  },
);

export { action, loader };
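A hypothetical client call against the new route (URL, port, and token are assumptions, not from the diff); the route allows JWTs and, with the reworked SearchService below, the response body is an array of statement facts:

const jwt = process.env.RECALL_TOKEN; // obtained from your auth flow

const res = await fetch("http://localhost:3000/search", {
  method: "POST",
  headers: {
    "Content-Type": "application/json",
    Authorization: `Bearer ${jwt}`,
  },
  body: JSON.stringify({ query: "Where does the user work?" }),
});

const facts: string[] = await res.json(); // SearchService.search resolves to facts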
@@ -17,7 +17,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
     MERGE (n:Statement {uuid: $uuid, userId: $userId})
     ON CREATE SET
       n.fact = $fact,
-      n.embedding = $embedding,
+      n.factEmbedding = $factEmbedding,
       n.createdAt = $createdAt,
       n.validAt = $validAt,
       n.invalidAt = $invalidAt,
@@ -26,7 +26,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
       n.space = $space
     ON MATCH SET
       n.fact = $fact,
-      n.embedding = $embedding,
+      n.factEmbedding = $factEmbedding,
       n.validAt = $validAt,
       n.invalidAt = $invalidAt,
       n.attributesJson = $attributesJson,
@@ -37,7 +37,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
   const statementParams = {
     uuid: triple.statement.uuid,
     fact: triple.statement.fact,
-    embedding: triple.statement.factEmbedding,
+    factEmbedding: triple.statement.factEmbedding,
     createdAt: triple.statement.createdAt.toISOString(),
     validAt: triple.statement.validAt.toISOString(),
     invalidAt: triple.statement.invalidAt
@@ -1,15 +1,8 @@
-import { openai } from "@ai-sdk/openai";
-import {
-  type CoreMessage,
-  embed,
-  generateText,
-  type LanguageModelV1,
-  streamText,
-} from "ai";
+import { type CoreMessage, embed } from "ai";
 import {
   entityTypes,
   EpisodeType,
   LLMMappings,
   LLMModelEnum,
   type AddEpisodeParams,
   type EntityNode,
@@ -33,6 +26,7 @@ import {
   invalidateStatements,
   saveTriple,
 } from "./graphModels/statement";
+import { makeModelCall } from "~/lib/model.server";
 
 // Default number of previous episodes to retrieve for context
 const DEFAULT_EPISODE_WINDOW = 5;
@@ -155,7 +149,7 @@ export class KnowledgeGraphService {
 
     let responseText = "";
 
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -217,7 +211,7 @@ export class KnowledgeGraphService {
     const messages = extractStatements(context);
 
     let responseText = "";
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -393,7 +387,7 @@ export class KnowledgeGraphService {
     const messages = dedupeNodes(dedupeContext);
     let responseText = "";
 
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -583,7 +577,7 @@ export class KnowledgeGraphService {
     let responseText = "";
 
     // Call the LLM to analyze all statements at once
-    await this.makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => {
+    await makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => {
       responseText = text;
     });
 
@@ -659,62 +653,4 @@ export class KnowledgeGraphService {
 
     return { resolvedStatements, invalidatedStatements };
   }
-
-  private async makeModelCall(
-    stream: boolean,
-    model: LLMModelEnum,
-    messages: CoreMessage[],
-    onFinish: (text: string, model: string) => void,
-  ) {
-    let modelInstance;
-    let finalModel: string = "unknown";
-
-    switch (model) {
-      case LLMModelEnum.GPT35TURBO:
-      case LLMModelEnum.GPT4TURBO:
-      case LLMModelEnum.GPT4O:
-      case LLMModelEnum.GPT41:
-      case LLMModelEnum.GPT41MINI:
-      case LLMModelEnum.GPT41NANO:
-        finalModel = LLMMappings[model];
-        modelInstance = openai(finalModel);
-        break;
-
-      case LLMModelEnum.CLAUDEOPUS:
-      case LLMModelEnum.CLAUDESONNET:
-      case LLMModelEnum.CLAUDEHAIKU:
-        finalModel = LLMMappings[model];
-        break;
-
-      case LLMModelEnum.GEMINI25FLASH:
-      case LLMModelEnum.GEMINI25PRO:
-      case LLMModelEnum.GEMINI20FLASH:
-      case LLMModelEnum.GEMINI20FLASHLITE:
-        finalModel = LLMMappings[model];
-        break;
-
-      default:
-        logger.warn(`Unsupported model type: ${model}`);
-        break;
-    }
-
-    if (stream) {
-      return await streamText({
-        model: modelInstance as LanguageModelV1,
-        messages,
-        onFinish: async ({ text }) => {
-          onFinish(text, finalModel);
-        },
-      });
-    }
-
-    const { text } = await generateText({
-      model: modelInstance as LanguageModelV1,
-      messages,
-    });
-
-    onFinish(text, finalModel);
-
-    return text;
-  }
 }
@@ -1,29 +1,19 @@
-import {
-  type EntityNode,
-  type KnowledgeGraphService,
-  type StatementNode,
-} from "./knowledgeGraph.server";
 import { openai } from "@ai-sdk/openai";
+import type { StatementNode } from "@recall/types";
 import { embed } from "ai";
-import HelixDB from "helix-ts";
-
-// Initialize OpenAI for embeddings
-const openaiClient = openai("gpt-4.1-2025-04-14");
-
-// Initialize Helix client
-const helixClient = new HelixDB();
+import { logger } from "./logger.service";
+import { applyCrossEncoderReranking, applyWeightedRRF } from "./search/rerank";
+import {
+  performBfsSearch,
+  performBM25Search,
+  performVectorSearch,
+} from "./search/utils";
 
 /**
  * SearchService provides methods to search the reified + temporal knowledge graph
  * using a hybrid approach combining BM25, vector similarity, and BFS traversal.
  */
 export class SearchService {
-  private knowledgeGraphService: KnowledgeGraphService;
-
-  constructor(knowledgeGraphService: KnowledgeGraphService) {
-    this.knowledgeGraphService = knowledgeGraphService;
-  }
-
   async getEmbedding(text: string) {
     const { embedding } = await embed({
       model: openai.embedding("text-embedding-3-small"),
@@ -44,7 +34,7 @@ export class SearchService {
     query: string,
     userId: string,
     options: SearchOptions = {},
-  ): Promise<StatementNode[]> {
+  ): Promise<string[]> {
     // Default options
     const opts: Required<SearchOptions> = {
       limit: options.limit || 10,
@@ -53,296 +43,144 @@
       includeInvalidated: options.includeInvalidated || false,
       entityTypes: options.entityTypes || [],
       predicateTypes: options.predicateTypes || [],
+      scoreThreshold: options.scoreThreshold || 0.7,
+      minResults: options.minResults || 10,
     };
 
+    const queryVector = await this.getEmbedding(query);
+
     // 1. Run parallel search methods
     const [bm25Results, vectorResults, bfsResults] = await Promise.all([
-      this.performBM25Search(query, userId, opts),
-      this.performVectorSearch(query, userId, opts),
-      this.performBfsSearch(query, userId, opts),
+      performBM25Search(query, userId, opts),
+      performVectorSearch(queryVector, userId, opts),
+      performBfsSearch(queryVector, userId, opts),
     ]);
 
-    // 2. Combine and deduplicate results
-    const combinedStatements = this.combineAndDeduplicate([
-      ...bm25Results,
-      ...vectorResults,
-      ...bfsResults,
-    ]);
+    logger.info(
+      `Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`,
+    );
 
-    // 3. Rerank the combined results
-    const rerankedStatements = await this.rerankStatements(
+    // 2. Apply reranking strategy
+    const rankedStatements = await this.rerankResults(
       query,
-      combinedStatements,
+      { bm25: bm25Results, vector: vectorResults, bfs: bfsResults },
       opts,
     );
 
-    // 4. Return top results
-    return rerankedStatements.slice(0, opts.limit);
+    // 3. Apply adaptive filtering based on score threshold and minimum count
+    const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts);
+
+    // 3. Return top results
+    return filteredResults.map((statement) => statement.fact);
   }
 
   /**
-   * Perform BM25 keyword-based search on statements
+   * Apply adaptive filtering to ranked results
+   * Uses a minimum quality threshold to filter out low-quality results
    */
-  private async performBM25Search(
-    query: string,
-    userId: string,
+  private applyAdaptiveFiltering(
+    results: StatementNode[],
     options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    // TODO: Implement BM25 search using HelixDB or external search index
-    // This is a placeholder implementation
-    try {
-      const results = await helixClient.query("searchStatementsByKeywords", {
-        query,
-        userId,
-        validAt: options.validAt.toISOString(),
-        includeInvalidated: options.includeInvalidated,
-        limit: options.limit * 2, // Fetch more for reranking
-      });
-
-      return results.statements || [];
-    } catch (error) {
-      console.error("BM25 search error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Perform vector similarity search on statement embeddings
-   */
-  private async performVectorSearch(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    try {
-      // 1. Generate embedding for the query
-      const embedding = await this.generateEmbedding(query);
-
-      // 2. Search for similar statements
-      const results = await helixClient.query("searchStatementsByVector", {
-        embedding,
-        userId,
-        validAt: options.validAt.toISOString(),
-        includeInvalidated: options.includeInvalidated,
-        limit: options.limit * 2, // Fetch more for reranking
-      });
-
-      return results.statements || [];
-    } catch (error) {
-      console.error("Vector search error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Perform BFS traversal starting from entities mentioned in the query
-   */
-  private async performBfsSearch(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    try {
-      // 1. Extract potential entities from query
-      const entities = await this.extractEntitiesFromQuery(query);
-
-      // 2. For each entity, perform BFS traversal
-      const allStatements: StatementNode[] = [];
-
-      for (const entity of entities) {
-        const statements = await this.bfsTraversal(
-          entity.uuid,
-          options.maxBfsDepth,
-          options.validAt,
-          userId,
-          options.includeInvalidated,
-        );
-        allStatements.push(...statements);
-      }
-
-      return allStatements;
-    } catch (error) {
-      console.error("BFS search error:", error);
-      return [];
-    }
-  }
+  ): StatementNode[] {
+    if (results.length === 0) return [];
+
+    // Extract scores from results
+    const scoredResults = results.map((result) => {
+      // Find the score based on reranking strategy used
+      let score = 0;
+      if ((result as any).rrfScore !== undefined) {
+        score = (result as any).rrfScore;
+      } else if ((result as any).mmrScore !== undefined) {
+        score = (result as any).mmrScore;
+      } else if ((result as any).crossEncoderScore !== undefined) {
+        score = (result as any).crossEncoderScore;
+      } else if ((result as any).finalScore !== undefined) {
+        score = (result as any).finalScore;
+      }
+
+      return { result, score };
+    });
+
+    const hasScores = scoredResults.some((item) => item.score > 0);
+    // If no scores are available, return the original results
+    if (!hasScores) {
+      logger.info("No scores found in results, skipping adaptive filtering");
+      return options.limit > 0 ? results.slice(0, options.limit) : results;
+    }
+
+    // Sort by score (descending)
+    scoredResults.sort((a, b) => b.score - a.score);
+
+    // Calculate statistics to identify low-quality results
+    const scores = scoredResults.map((item) => item.score);
+    const maxScore = Math.max(...scores);
+    const minScore = Math.min(...scores);
+    const scoreRange = maxScore - minScore;
+
+    // Define a minimum quality threshold as a fraction of the best score
+    // This is relative to the query's score distribution rather than an absolute value
+    const relativeThreshold = options.scoreThreshold || 0.3; // 30% of the best score by default
+    const absoluteMinimum = 0.1; // Absolute minimum threshold to prevent keeping very poor matches
+
+    // Calculate the actual threshold as a percentage of the distance from min to max score
+    const threshold = Math.max(
+      absoluteMinimum,
+      minScore + scoreRange * relativeThreshold,
+    );
+
+    // Filter out low-quality results
+    const filteredResults = scoredResults
+      .filter((item) => item.score >= threshold)
+      .map((item) => item.result);
+
+    // Apply limit if specified
+    const limitedResults =
+      options.limit > 0
+        ? filteredResults.slice(
+            0,
+            Math.min(filteredResults.length, options.limit),
+          )
+        : filteredResults;
+
+    logger.info(
+      `Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`,
+    );
+    logger.info(
+      `Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`,
+    );
+
+    return limitedResults;
+  }
 
   /**
-   * Perform BFS traversal starting from an entity
+   * Apply the selected reranking strategy to search results
    */
-  private async bfsTraversal(
-    startEntityId: string,
-    maxDepth: number,
-    validAt: Date,
-    userId: string,
-    includeInvalidated: boolean,
-  ): Promise<StatementNode[]> {
-    // Track visited nodes to avoid cycles
-    const visited = new Set<string>();
-    // Track statements found during traversal
-    const statements: StatementNode[] = [];
-    // Queue for BFS traversal [nodeId, depth]
-    const queue: [string, number][] = [[startEntityId, 0]];
-
-    while (queue.length > 0) {
-      const [nodeId, depth] = queue.shift()!;
-
-      // Skip if already visited or max depth reached
-      if (visited.has(nodeId) || depth > maxDepth) continue;
-      visited.add(nodeId);
-
-      // Get statements where this entity is subject or object
-      const connectedStatements = await helixClient.query(
-        "getConnectedStatements",
-        {
-          entityId: nodeId,
-          userId,
-          validAt: validAt.toISOString(),
-          includeInvalidated,
-        },
-      );
-
-      // Add statements to results
-      if (connectedStatements.statements) {
-        statements.push(...connectedStatements.statements);
-
-        // Add connected entities to queue
-        for (const statement of connectedStatements.statements) {
-          // Get subject and object entities
-          if (statement.subjectId && !visited.has(statement.subjectId)) {
-            queue.push([statement.subjectId, depth + 1]);
-          }
-          if (statement.objectId && !visited.has(statement.objectId)) {
-            queue.push([statement.objectId, depth + 1]);
-          }
-        }
-      }
-    }
-
-    return statements;
-  }
-
-  /**
-   * Extract potential entities from a query using embeddings or LLM
-   */
-  private async extractEntitiesFromQuery(query: string): Promise<EntityNode[]> {
-    // TODO: Implement more sophisticated entity extraction
-    // This is a placeholder implementation that uses simple vector search
-    try {
-      const embedding = await this.getEmbedding(query);
-
-      const results = await helixClient.query("searchEntitiesByVector", {
-        embedding,
-        limit: 3, // Start with top 3 entities
-      });
-
-      return results.entities || [];
-    } catch (error) {
-      console.error("Entity extraction error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Combine and deduplicate statements from multiple sources
-   */
-  private combineAndDeduplicate(statements: StatementNode[]): StatementNode[] {
-    const uniqueStatements = new Map<string, StatementNode>();
-
-    for (const statement of statements) {
-      if (!uniqueStatements.has(statement.uuid)) {
-        uniqueStatements.set(statement.uuid, statement);
-      }
-    }
-
-    return Array.from(uniqueStatements.values());
-  }
-
-  /**
-   * Rerank statements based on relevance to the query
-   */
-  private async rerankStatements(
+  private async rerankResults(
     query: string,
-    statements: StatementNode[],
+    results: {
+      bm25: StatementNode[];
+      vector: StatementNode[];
+      bfs: StatementNode[];
+    },
     options: Required<SearchOptions>,
   ): Promise<StatementNode[]> {
-    // TODO: Implement more sophisticated reranking
-    // This is a placeholder implementation using cosine similarity
-    try {
-      // 1. Generate embedding for the query
-      const queryEmbedding = await this.getEmbedding(query);
+    // Count non-empty result sources
+    const nonEmptySources = [
+      results.bm25.length > 0,
+      results.vector.length > 0,
+      results.bfs.length > 0,
+    ].filter(Boolean).length;
 
-      // 2. Generate or retrieve embeddings for statements
-      const statementEmbeddings = await Promise.all(
-        statements.map(async (statement) => {
-          // If statement has embedding, use it; otherwise generate
-          if (statement.factEmbedding && statement.factEmbedding.length > 0) {
-            return { statement, embedding: statement.factEmbedding };
-          }
-
-          // Generate text representation of statement
-          const statementText = this.statementToText(statement);
-          const embedding = await this.getEmbedding(statementText);
-
-          return { statement, embedding };
-        }),
-      );
+    // If results are coming from only one source, use cross-encoder reranking
+    if (nonEmptySources <= 1) {
+      logger.info(
+        "Only one source has results, falling back to cross-encoder reranking",
+      );
+      return applyCrossEncoderReranking(query, results);
+    }
 
-      // 3. Calculate cosine similarity
-      const scoredStatements = statementEmbeddings.map(
-        ({ statement, embedding }) => {
-          const similarity = this.cosineSimilarity(queryEmbedding, embedding);
-          return { statement, score: similarity };
-        },
-      );
-
-      // 4. Sort by score (descending)
-      scoredStatements.sort((a, b) => b.score - a.score);
-
-      // 5. Return statements in order of relevance
-      return scoredStatements.map(({ statement }) => statement);
-    } catch (error) {
-      console.error("Reranking error:", error);
-      // Fallback: return original order
-      return statements;
-    }
-  }
-
-  /**
-   * Convert a statement to a text representation
-   */
-  private statementToText(statement: StatementNode): string {
-    // TODO: Implement more sophisticated text representation
-    // This is a placeholder implementation
-    return `${statement.subjectName || "Unknown"} ${statement.predicateName || "has relation with"} ${statement.objectName || "Unknown"}`;
-  }
-
-  /**
-   * Calculate cosine similarity between two embeddings
-   */
-  private cosineSimilarity(a: number[], b: number[]): number {
-    if (a.length !== b.length) {
-      throw new Error("Embeddings must have the same length");
-    }
-
-    let dotProduct = 0;
-    let normA = 0;
-    let normB = 0;
-
-    for (let i = 0; i < a.length; i++) {
-      dotProduct += a[i] * b[i];
-      normA += a[i] * a[i];
-      normB += b[i] * b[i];
-    }
-
-    normA = Math.sqrt(normA);
-    normB = Math.sqrt(normB);
-
-    if (normA === 0 || normB === 0) {
-      return 0;
-    }
-
-    return dotProduct / (normA * normB);
-  }
+    // Otherwise use weighted RRF for multiple sources
+    return applyWeightedRRF(results);
+  }
 }
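To make the adaptive threshold concrete: if reranking yields scores [0.9, 0.6, 0.25, 0.2] and scoreThreshold is 0.3, then minScore = 0.2, scoreRange = 0.7, and threshold = max(0.1, 0.2 + 0.7 × 0.3) = 0.41, so only the 0.9 and 0.6 results survive. The cut-off tracks each query's own score distribution rather than a fixed absolute value, which keeps tightly-clustered result sets intact while trimming long low-quality tails.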
@@ -356,23 +194,6 @@ export interface SearchOptions {
   includeInvalidated?: boolean;
   entityTypes?: string[];
   predicateTypes?: string[];
+  scoreThreshold?: number;
+  minResults?: number;
 }
-
-/**
- * Create a singleton instance of the search service
- */
-let searchServiceInstance: SearchService | null = null;
-
-export function getSearchService(
-  knowledgeGraphService?: KnowledgeGraphService,
-): SearchService {
-  if (!searchServiceInstance) {
-    if (!knowledgeGraphService) {
-      throw new Error(
-        "KnowledgeGraphService must be provided when initializing SearchService",
-      );
-    }
-    searchServiceInstance = new SearchService(knowledgeGraphService);
-  }
-  return searchServiceInstance;
-}
apps/webapp/app/services/search/rerank.ts (new file, 118 lines)
@@ -0,0 +1,118 @@
import { LLMModelEnum, type StatementNode } from "@recall/types";
import { combineAndDeduplicateStatements } from "./utils";
import { type CoreMessage } from "ai";
import { makeModelCall } from "~/lib/model.server";
import { logger } from "../logger.service";

/**
 * Apply Weighted Reciprocal Rank Fusion to combine results
 */
export function applyWeightedRRF(results: {
  bm25: StatementNode[];
  vector: StatementNode[];
  bfs: StatementNode[];
}): StatementNode[] {
  // Determine weights based on query characteristics
  const weights = {
    bm25: 1.0,
    vector: 0.8,
    bfs: 0.5,
  };
  const k = 60; // RRF constant

  // Map to store combined scores
  const scores: Record<string, { score: number; statement: StatementNode }> =
    {};

  // Process BM25 results with their weight
  results.bm25.forEach((statement, rank) => {
    const uuid = statement.uuid;
    scores[uuid] = scores[uuid] || { score: 0, statement };
    scores[uuid].score += weights.bm25 * (1 / (rank + k));
  });

  // Process vector similarity results with their weight
  results.vector.forEach((statement, rank) => {
    const uuid = statement.uuid;
    scores[uuid] = scores[uuid] || { score: 0, statement };
    scores[uuid].score += weights.vector * (1 / (rank + k));
  });

  // Process BFS traversal results with their weight
  results.bfs.forEach((statement, rank) => {
    const uuid = statement.uuid;
    scores[uuid] = scores[uuid] || { score: 0, statement };
    scores[uuid].score += weights.bfs * (1 / (rank + k));
  });

  // Convert to array and sort by final score
  const sortedResults = Object.values(scores)
    .sort((a, b) => b.score - a.score)
    .map((item) => {
      // Add the RRF score to the statement for debugging
      return {
        ...item.statement,
        rrfScore: item.score,
      };
    });

  return sortedResults;
}
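Hand-tracing the fusion with the constants above (k = 60, weights 1.0 / 0.8 / 0.5): a statement ranked first (rank 0) by BM25 and third (rank 2) by vector search accumulates 1.0/60 + 0.8/62 ≈ 0.0167 + 0.0129 = 0.0296, while a statement found only by vector search at rank 0 gets 0.8/60 ≈ 0.0133. Agreement across sources therefore outranks a single strong source, and the large k keeps the gap between adjacent ranks gentle.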
/**
 * Apply Cross-Encoder reranking to results
 * This is particularly useful when results come from a single source
 */
export async function applyCrossEncoderReranking(
  query: string,
  results: {
    bm25: StatementNode[];
    vector: StatementNode[];
    bfs: StatementNode[];
  },
): Promise<StatementNode[]> {
  // Combine all results
  const allResults = [...results.bm25, ...results.vector, ...results.bfs];

  // Deduplicate by UUID
  const uniqueResults = combineAndDeduplicateStatements(allResults);

  if (uniqueResults.length === 0) return [];

  logger.info(`Cross-encoder reranking ${uniqueResults.length} statements`);

  const finalStatements: StatementNode[] = [];

  await Promise.all(
    uniqueResults.map(async (statement) => {
      const messages: CoreMessage[] = [
        {
          role: "system",
          content: `You are an expert tasked with determining whether the statement is relevant to the query
Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.`,
        },
        {
          role: "user",
          content: `<QUERY>${query}</QUERY>\n<STATEMENT>${statement.fact}</STATEMENT>`,
        },
      ];

      let responseText = "";
      await makeModelCall(
        false,
        LLMModelEnum.GPT41NANO,
        messages,
        (text) => {
          responseText = text;
        },
        { temperature: 0, maxTokens: 1 },
      );

      if (responseText === "True") {
        finalStatements.push(statement);
      }
    }),
  );

  return finalStatements;
}
apps/webapp/app/services/search/utils.ts (new file, 223 lines)
@@ -0,0 +1,223 @@
import type { EntityNode, StatementNode } from "@recall/types";
import type { SearchOptions } from "../search.server";
import type { Embedding } from "ai";
import { logger } from "../logger.service";
import { runQuery } from "~/lib/neo4j.server";

/**
 * Perform BM25 keyword-based search on statements
 */
export async function performBM25Search(
  query: string,
  userId: string,
  options: Required<SearchOptions>,
): Promise<StatementNode[]> {
  try {
    // Sanitize the query for Lucene syntax
    const sanitizedQuery = sanitizeLuceneQuery(query);

    // Use Neo4j's built-in fulltext search capabilities
    const cypher = `
      CALL db.index.fulltext.queryNodes("statement_fact_index", $query)
      YIELD node AS s, score
      WHERE
        s.validAt <= $validAt
        AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
        AND (s.userId = $userId)
      RETURN s, score
      ORDER BY score DESC
    `;

    const params = {
      query: sanitizedQuery,
      userId,
      validAt: options.validAt.toISOString(),
    };

    const records = await runQuery(cypher, params);
    // return records.map((record) => record.get("s").properties as StatementNode);
    return [];
  } catch (error) {
    logger.error("BM25 search error:", { error });
    return [];
  }
}

/**
 * Sanitize a query string for Lucene syntax
 */
export function sanitizeLuceneQuery(query: string): string {
  // Escape special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \
  let sanitized = query.replace(
    /[+\-&|!(){}[\]^"~*?:\\]/g,
    (match) => "\\" + match,
  );

  // If query is too long, truncate it
  const MAX_QUERY_LENGTH = 32;
  const words = sanitized.split(" ");
  if (words.length > MAX_QUERY_LENGTH) {
    sanitized = words.slice(0, MAX_QUERY_LENGTH).join(" ");
  }

  return sanitized;
}
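A hand-traced call (hypothetical input, not from the diff): each reserved Lucene character is escaped with a backslash, and note that the MAX_QUERY_LENGTH cap counts words rather than characters.

sanitizeLuceneQuery('who founded "Recall" (the memory app)?');
// => who founded \"Recall\" \(the memory app\)\?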
/**
 * Perform vector similarity search on statement embeddings
 */
export async function performVectorSearch(
  query: Embedding,
  userId: string,
  options: Required<SearchOptions>,
): Promise<StatementNode[]> {
  try {
    // 1. Generate embedding for the query
    // const embedding = await this.getEmbedding(query);

    // 2. Search for similar statements using Neo4j vector search
    const cypher = `
      MATCH (s:Statement)
      WHERE
        s.validAt <= $validAt
        AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
        AND (s.userId = $userId OR s.isPublic = true)
      WITH s, vector.similarity.cosine(s.factEmbedding, $embedding) AS score
      WHERE score > 0.7
      RETURN s, score
      ORDER BY score DESC
    `;

    const params = {
      embedding: query,
      userId,
      validAt: options.validAt.toISOString(),
    };

    const records = await runQuery(cypher, params);
    // return records.map((record) => record.get("s").properties as StatementNode);
    return [];
  } catch (error) {
    logger.error("Vector search error:", { error });
    return [];
  }
}
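As written, performVectorSearch scans every Statement node and computes cosine similarity one node at a time. An index-backed variant (an assumption, not part of this commit) could instead go through the statement_embedding vector index created in the schema hunk, via Neo4j 5.x's db.index.vector.queryNodes procedure; a sketch:

// Hypothetical alternative query; $k is the number of nearest neighbours to fetch.
const indexBackedCypher = `
  CALL db.index.vector.queryNodes("statement_embedding", $k, $embedding)
  YIELD node AS s, score
  WHERE
    s.userId = $userId
    AND s.validAt <= $validAt
    AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
  RETURN s, score
  ORDER BY score DESC
`;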
/**
 * Perform BFS traversal starting from entities mentioned in the query
 */
export async function performBfsSearch(
  embedding: Embedding,
  userId: string,
  options: Required<SearchOptions>,
): Promise<StatementNode[]> {
  try {
    // 1. Extract potential entities from query
    const entities = await extractEntitiesFromQuery(embedding);

    // 2. For each entity, perform BFS traversal
    const allStatements: StatementNode[] = [];

    for (const entity of entities) {
      const statements = await bfsTraversal(
        entity.uuid,
        options.maxBfsDepth,
        options.validAt,
        userId,
        options.includeInvalidated,
      );
      allStatements.push(...statements);
    }

    return allStatements;
  } catch (error) {
    logger.error("BFS search error:", { error });
    return [];
  }
}

/**
 * Perform BFS traversal starting from an entity
 */
export async function bfsTraversal(
  startEntityId: string,
  maxDepth: number,
  validAt: Date,
  userId: string,
  includeInvalidated: boolean,
): Promise<StatementNode[]> {
  try {
    // Use Neo4j's built-in path finding capabilities for efficient BFS
    // This query implements BFS up to maxDepth and collects all statements along the way
    const cypher = `
      MATCH (e:Entity {uuid: $startEntityId})<-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]-(s:Statement)
      WHERE
        s.validAt <= $validAt
        AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
        AND (s.userId = $userId)
        AND ($includeInvalidated OR s.invalidAt IS NULL)
      RETURN s as statement
    `;

    const params = {
      startEntityId,
      maxDepth,
      validAt: validAt.toISOString(),
      userId,
      includeInvalidated,
    };

    const records = await runQuery(cypher, params);
    return records.map(
      (record) => record.get("statement").properties as StatementNode,
    );
  } catch (error) {
    logger.error("BFS traversal error:", { error });
    return [];
  }
}

/**
 * Extract potential entities from a query using embeddings or LLM
 */
export async function extractEntitiesFromQuery(
  embedding: Embedding,
): Promise<EntityNode[]> {
  try {
    // Use vector similarity to find relevant entities
    const cypher = `
      // Match entities using vector similarity on name embeddings
      MATCH (e:Entity)
      WHERE e.nameEmbedding IS NOT NULL
      WITH e, vector.similarity.cosine(e.nameEmbedding, $embedding) AS score
      WHERE score > 0.7
      RETURN e
      ORDER BY score DESC
      LIMIT 3
    `;

    const params = {
      embedding,
    };

    const records = await runQuery(cypher, params);

    return records.map((record) => record.get("e").properties as EntityNode);
  } catch (error) {
    logger.error("Entity extraction error:", { error });
    return [];
  }
}

/**
 * Combine and deduplicate statements from different search methods
 */
export function combineAndDeduplicateStatements(
  statements: StatementNode[],
): StatementNode[] {
  return Array.from(
    new Map(
      statements.map((statement) => [statement.uuid, statement]),
    ).values(),
  );
}