mirror of https://github.com/eliasstepanik/core.git
synced 2026-01-11 18:08:27 +00:00

feat: implement hybrid search with BM25, vector and BFS traversal

This commit is contained in:
parent cf20da9ecd
commit 848153d57a
apps/webapp/app/lib/model.server.ts  (new file)
@@ -0,0 +1,68 @@
+import { LLMMappings, LLMModelEnum } from "@recall/types";
+import {
+  type CoreMessage,
+  type LanguageModelV1,
+  generateText,
+  streamText,
+} from "ai";
+import { openai } from "@ai-sdk/openai";
+import { logger } from "~/services/logger.service";
+
+export async function makeModelCall(
+  stream: boolean,
+  model: LLMModelEnum,
+  messages: CoreMessage[],
+  onFinish: (text: string, model: string) => void,
+  options?: any,
+) {
+  let modelInstance;
+  let finalModel: string = "unknown";
+
+  switch (model) {
+    case LLMModelEnum.GPT35TURBO:
+    case LLMModelEnum.GPT4TURBO:
+    case LLMModelEnum.GPT4O:
+    case LLMModelEnum.GPT41:
+    case LLMModelEnum.GPT41MINI:
+    case LLMModelEnum.GPT41NANO:
+      finalModel = LLMMappings[model];
+      modelInstance = openai(finalModel, { ...options });
+      break;
+
+    case LLMModelEnum.CLAUDEOPUS:
+    case LLMModelEnum.CLAUDESONNET:
+    case LLMModelEnum.CLAUDEHAIKU:
+      finalModel = LLMMappings[model];
+      break;
+
+    case LLMModelEnum.GEMINI25FLASH:
+    case LLMModelEnum.GEMINI25PRO:
+    case LLMModelEnum.GEMINI20FLASH:
+    case LLMModelEnum.GEMINI20FLASHLITE:
+      finalModel = LLMMappings[model];
+      break;
+
+    default:
+      logger.warn(`Unsupported model type: ${model}`);
+      break;
+  }
+
+  if (stream) {
+    return await streamText({
+      model: modelInstance as LanguageModelV1,
+      messages,
+      onFinish: async ({ text }) => {
+        onFinish(text, finalModel);
+      },
+    });
+  }
+
+  const { text } = await generateText({
+    model: modelInstance as LanguageModelV1,
+    messages,
+  });
+
+  onFinish(text, finalModel);
+
+  return text;
+}
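For orientation, a minimal call-site sketch for the new helper; the message content is illustrative, while the call shape matches the call sites updated in knowledgeGraph.server.ts further down:

import { type CoreMessage } from "ai";
import { LLMModelEnum } from "@recall/types";
import { makeModelCall } from "~/lib/model.server";

const messages: CoreMessage[] = [
  { role: "user", content: "Summarize this episode." },
];

// Non-streaming call: resolves with the completion text and also hands it,
// plus the resolved provider model id, to the callback.
let responseText = "";
await makeModelCall(false, LLMModelEnum.GPT41, messages, (text, model) => {
  responseText = text;
});

Note that only the OpenAI branch assigns a modelInstance; the Claude and Gemini branches currently set finalModel without constructing an instance.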
apps/webapp/app/lib/neo4j.server.ts
@@ -46,28 +46,66 @@ const runQuery = async (cypher: string, params = {}) => {
 // Initialize the database schema
 const initializeSchema = async () => {
   try {
-    // Run schema setup queries
-    await runQuery(`
-      // Create constraints for unique IDs
-      CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE;
-      CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE;
-      CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE;
-
-      // Create indexes for better query performance
-      CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt);
-      CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt);
-      CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt);
-      CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name);
-
-      // Create vector indexes for semantic search (if using Neo4j 5.0+)
-      CREATE VECTOR INDEX entity_embedding IF NOT EXISTS FOR (n:Entity) ON n.nameEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
-
-      CREATE VECTOR INDEX statement_embedding IF NOT EXISTS FOR (n:Statement) ON n.factEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
-
-      CREATE VECTOR INDEX episode_embedding IF NOT EXISTS FOR (n:Episode) ON n.contentEmbedding
-      OPTIONS {indexConfig: {dimensions: 1536, similarity: "cosine"}};
-    `);
+    // Create constraints for unique IDs
+    await runQuery(
+      "CREATE CONSTRAINT episode_uuid IF NOT EXISTS FOR (n:Episode) REQUIRE n.uuid IS UNIQUE",
+    );
+    await runQuery(
+      "CREATE CONSTRAINT entity_uuid IF NOT EXISTS FOR (n:Entity) REQUIRE n.uuid IS UNIQUE",
+    );
+    await runQuery(
+      "CREATE CONSTRAINT statement_uuid IF NOT EXISTS FOR (n:Statement) REQUIRE n.uuid IS UNIQUE",
+    );
+
+    // Create indexes for better query performance
+    await runQuery(
+      "CREATE INDEX episode_valid_at IF NOT EXISTS FOR (n:Episode) ON (n.validAt)",
+    );
+    await runQuery(
+      "CREATE INDEX statement_valid_at IF NOT EXISTS FOR (n:Statement) ON (n.validAt)",
+    );
+    await runQuery(
+      "CREATE INDEX statement_invalid_at IF NOT EXISTS FOR (n:Statement) ON (n.invalidAt)",
+    );
+    await runQuery(
+      "CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)",
+    );
+
+    // Create vector indexes for semantic search (if using Neo4j 5.0+)
+    await runQuery(`
+      CREATE VECTOR INDEX entity_embedding IF NOT EXISTS FOR (n:Entity) ON n.nameEmbedding
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+
+    await runQuery(`
+      CREATE VECTOR INDEX statement_embedding IF NOT EXISTS FOR (n:Statement) ON n.factEmbedding
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+
+    await runQuery(`
+      CREATE VECTOR INDEX episode_embedding IF NOT EXISTS FOR (n:Episode) ON n.contentEmbedding
+      OPTIONS {indexConfig: {\`vector.dimensions\`: 1536, \`vector.similarity_function\`: 'cosine'}}
+    `);
+
+    // Create fulltext indexes for BM25 search
+    await runQuery(`
+      CREATE FULLTEXT INDEX statement_fact_index IF NOT EXISTS
+      FOR (n:Statement) ON EACH [n.fact]
+      OPTIONS {
+        indexConfig: {
+          \`fulltext.analyzer\`: 'english'
+        }
+      }
+    `);
+
+    await runQuery(`
+      CREATE FULLTEXT INDEX entity_name_index IF NOT EXISTS
+      FOR (n:Entity) ON EACH [n.name, n.description]
+      OPTIONS {
+        indexConfig: {
+          \`fulltext.analyzer\`: 'english'
+        }
+      }
+    `);
 
     logger.info("Neo4j schema initialized successfully");

@@ -84,4 +122,6 @@ const closeDriver = async () => {
   logger.info("Neo4j driver closed");
 };
+
+// await initializeSchema();
 
 export { driver, verifyConnectivity, runQuery, initializeSchema, closeDriver };
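With the corrected OPTIONS syntax, these indexes can also be served by Neo4j's k-nearest-neighbour procedure. A sketch assuming Neo4j 5.11+ (the commit itself computes vector.similarity.cosine inline in search/utils.ts instead; k and the returned fields are illustrative):

// Hypothetical top-k lookup against the statement_embedding index created above.
// `queryVector` is assumed to be a 1536-dimension embedding of the user query.
const records = await runQuery(
  `CALL db.index.vector.queryNodes("statement_embedding", $k, $embedding)
   YIELD node AS s, score
   RETURN s.fact AS fact, score
   ORDER BY score DESC`,
  { k: 10, embedding: queryVector },
);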
@@ -1,5 +1,5 @@
 import { EpisodeType } from "@recall/types";
-import { json, LoaderFunctionArgs } from "@remix-run/node";
+import { json } from "@remix-run/node";
 import { z } from "zod";
 import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
 import { getUserQueue } from "~/lib/ingest.queue";
apps/webapp/app/routes/search.tsx  (new file)
@@ -0,0 +1,28 @@
+import { z } from "zod";
+import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
+import { SearchService } from "~/services/search.server";
+import { json } from "@remix-run/node";
+
+export const SearchBodyRequest = z.object({
+  query: z.string(),
+  spaceId: z.string().optional(),
+  sessionId: z.string().optional(),
+});
+
+const searchService = new SearchService();
+const { action, loader } = createActionApiRoute(
+  {
+    body: SearchBodyRequest,
+    allowJWT: true,
+    authorization: {
+      action: "search",
+    },
+    corsStrategy: "all",
+  },
+  async ({ body, authentication }) => {
+    const results = await searchService.search(body.query, authentication.userId);
+    return json(results);
+  },
+);
+
+export { action, loader };
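Assuming the flat route file maps to POST /search and a JWT is supplied per allowJWT: true (host, port, and token below are placeholders), the endpoint can be exercised from any client:

const res = await fetch("http://localhost:3000/search", {
  method: "POST",
  headers: {
    Authorization: `Bearer ${token}`, // JWT, per allowJWT: true
    "Content-Type": "application/json",
  },
  body: JSON.stringify({ query: "where does John work?" }),
});
// SearchService.search() now resolves to fact strings, so the body is string[]
const facts: string[] = await res.json();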
apps/webapp/app/services/graphModels/statement.ts
@@ -17,7 +17,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
     MERGE (n:Statement {uuid: $uuid, userId: $userId})
     ON CREATE SET
       n.fact = $fact,
-      n.embedding = $embedding,
+      n.factEmbedding = $factEmbedding,
       n.createdAt = $createdAt,
       n.validAt = $validAt,
       n.invalidAt = $invalidAt,
@@ -26,7 +26,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
       n.space = $space
     ON MATCH SET
       n.fact = $fact,
-      n.embedding = $embedding,
+      n.factEmbedding = $factEmbedding,
       n.validAt = $validAt,
       n.invalidAt = $invalidAt,
       n.attributesJson = $attributesJson,
@@ -37,7 +37,7 @@ export async function saveTriple(triple: Triple): Promise<string> {
   const statementParams = {
     uuid: triple.statement.uuid,
     fact: triple.statement.fact,
-    embedding: triple.statement.factEmbedding,
+    factEmbedding: triple.statement.factEmbedding,
     createdAt: triple.statement.createdAt.toISOString(),
     validAt: triple.statement.validAt.toISOString(),
     invalidAt: triple.statement.invalidAt
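The renamed factEmbedding property is what the new vector index and the vector.similarity.cosine calls read, so the stored value has to come from the same embedding model used at query time. A sketch of producing it with the ai SDK, mirroring SearchService.getEmbedding (the statement text is made up):

import { openai } from "@ai-sdk/openai";
import { embed } from "ai";

// 1536 dimensions, matching the vector indexes created in neo4j.server.ts
const { embedding } = await embed({
  model: openai.embedding("text-embedding-3-small"),
  value: "John works at Acme Corp",
});
// Assign to triple.statement.factEmbedding before calling saveTriple()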
apps/webapp/app/services/knowledgeGraph.server.ts
@@ -1,15 +1,8 @@
 import { openai } from "@ai-sdk/openai";
-import {
-  type CoreMessage,
-  embed,
-  generateText,
-  type LanguageModelV1,
-  streamText,
-} from "ai";
+import { type CoreMessage, embed } from "ai";
 import {
   entityTypes,
   EpisodeType,
-  LLMMappings,
   LLMModelEnum,
   type AddEpisodeParams,
   type EntityNode,
@@ -33,6 +26,7 @@ import {
   invalidateStatements,
   saveTriple,
 } from "./graphModels/statement";
+import { makeModelCall } from "~/lib/model.server";
 
 // Default number of previous episodes to retrieve for context
 const DEFAULT_EPISODE_WINDOW = 5;
@@ -155,7 +149,7 @@ export class KnowledgeGraphService {
 
     let responseText = "";
 
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -217,7 +211,7 @@ export class KnowledgeGraphService {
     const messages = extractStatements(context);
 
     let responseText = "";
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -393,7 +387,7 @@ export class KnowledgeGraphService {
     const messages = dedupeNodes(dedupeContext);
     let responseText = "";
 
-    await this.makeModelCall(
+    await makeModelCall(
       false,
       LLMModelEnum.GPT41,
       messages as CoreMessage[],
@@ -583,7 +577,7 @@ export class KnowledgeGraphService {
     let responseText = "";
 
     // Call the LLM to analyze all statements at once
-    await this.makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => {
+    await makeModelCall(false, LLMModelEnum.GPT41, messages, (text) => {
       responseText = text;
     });
 
@@ -659,62 +653,4 @@ export class KnowledgeGraphService {
 
     return { resolvedStatements, invalidatedStatements };
   }
-
-  private async makeModelCall(
-    stream: boolean,
-    model: LLMModelEnum,
-    messages: CoreMessage[],
-    onFinish: (text: string, model: string) => void,
-  ) {
-    let modelInstance;
-    let finalModel: string = "unknown";
-
-    switch (model) {
-      case LLMModelEnum.GPT35TURBO:
-      case LLMModelEnum.GPT4TURBO:
-      case LLMModelEnum.GPT4O:
-      case LLMModelEnum.GPT41:
-      case LLMModelEnum.GPT41MINI:
-      case LLMModelEnum.GPT41NANO:
-        finalModel = LLMMappings[model];
-        modelInstance = openai(finalModel);
-        break;
-
-      case LLMModelEnum.CLAUDEOPUS:
-      case LLMModelEnum.CLAUDESONNET:
-      case LLMModelEnum.CLAUDEHAIKU:
-        finalModel = LLMMappings[model];
-        break;
-
-      case LLMModelEnum.GEMINI25FLASH:
-      case LLMModelEnum.GEMINI25PRO:
-      case LLMModelEnum.GEMINI20FLASH:
-      case LLMModelEnum.GEMINI20FLASHLITE:
-        finalModel = LLMMappings[model];
-        break;
-
-      default:
-        logger.warn(`Unsupported model type: ${model}`);
-        break;
-    }
-
-    if (stream) {
-      return await streamText({
-        model: modelInstance as LanguageModelV1,
-        messages,
-        onFinish: async ({ text }) => {
-          onFinish(text, finalModel);
-        },
-      });
-    }
-
-    const { text } = await generateText({
-      model: modelInstance as LanguageModelV1,
-      messages,
-    });
-
-    onFinish(text, finalModel);
-
-    return text;
-  }
 }
apps/webapp/app/services/search.server.ts
@@ -1,29 +1,19 @@
-import {
-  type EntityNode,
-  type KnowledgeGraphService,
-  type StatementNode,
-} from "./knowledgeGraph.server";
 import { openai } from "@ai-sdk/openai";
+import type { StatementNode } from "@recall/types";
 import { embed } from "ai";
-import HelixDB from "helix-ts";
-
-// Initialize OpenAI for embeddings
-const openaiClient = openai("gpt-4.1-2025-04-14");
-
-// Initialize Helix client
-const helixClient = new HelixDB();
+import { logger } from "./logger.service";
+import { applyCrossEncoderReranking, applyWeightedRRF } from "./search/rerank";
+import {
+  performBfsSearch,
+  performBM25Search,
+  performVectorSearch,
+} from "./search/utils";
 
 /**
  * SearchService provides methods to search the reified + temporal knowledge graph
  * using a hybrid approach combining BM25, vector similarity, and BFS traversal.
  */
 export class SearchService {
-  private knowledgeGraphService: KnowledgeGraphService;
-
-  constructor(knowledgeGraphService: KnowledgeGraphService) {
-    this.knowledgeGraphService = knowledgeGraphService;
-  }
-
   async getEmbedding(text: string) {
     const { embedding } = await embed({
       model: openai.embedding("text-embedding-3-small"),
@@ -44,7 +34,7 @@ export class SearchService {
     query: string,
     userId: string,
     options: SearchOptions = {},
-  ): Promise<StatementNode[]> {
+  ): Promise<string[]> {
     // Default options
     const opts: Required<SearchOptions> = {
       limit: options.limit || 10,
@@ -53,296 +43,144 @@ export class SearchService {
       includeInvalidated: options.includeInvalidated || false,
       entityTypes: options.entityTypes || [],
       predicateTypes: options.predicateTypes || [],
+      scoreThreshold: options.scoreThreshold || 0.7,
+      minResults: options.minResults || 10,
     };
 
+    const queryVector = await this.getEmbedding(query);
+
     // 1. Run parallel search methods
     const [bm25Results, vectorResults, bfsResults] = await Promise.all([
-      this.performBM25Search(query, userId, opts),
-      this.performVectorSearch(query, userId, opts),
-      this.performBfsSearch(query, userId, opts),
+      performBM25Search(query, userId, opts),
+      performVectorSearch(queryVector, userId, opts),
+      performBfsSearch(queryVector, userId, opts),
     ]);
 
-    // 2. Combine and deduplicate results
-    const combinedStatements = this.combineAndDeduplicate([
-      ...bm25Results,
-      ...vectorResults,
-      ...bfsResults,
-    ]);
-
-    // 3. Rerank the combined results
-    const rerankedStatements = await this.rerankStatements(
-      query,
-      combinedStatements,
-      opts,
-    );
-
-    // 4. Return top results
-    return rerankedStatements.slice(0, opts.limit);
+    logger.info(
+      `Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`,
+    );
+
+    // 2. Apply reranking strategy
+    const rankedStatements = await this.rerankResults(
+      query,
+      { bm25: bm25Results, vector: vectorResults, bfs: bfsResults },
+      opts,
+    );
+
+    // 3. Apply adaptive filtering based on score threshold and minimum count
+    const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts);
+
+    // 4. Return top results
+    return filteredResults.map((statement) => statement.fact);
   }
 
-  /**
-   * Perform BM25 keyword-based search on statements
-   */
-  private async performBM25Search(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    // TODO: Implement BM25 search using HelixDB or external search index
-    // This is a placeholder implementation
-    try {
-      const results = await helixClient.query("searchStatementsByKeywords", {
-        query,
-        userId,
-        validAt: options.validAt.toISOString(),
-        includeInvalidated: options.includeInvalidated,
-        limit: options.limit * 2, // Fetch more for reranking
-      });
-
-      return results.statements || [];
-    } catch (error) {
-      console.error("BM25 search error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Perform vector similarity search on statement embeddings
-   */
-  private async performVectorSearch(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    try {
-      // 1. Generate embedding for the query
-      const embedding = await this.generateEmbedding(query);
-
-      // 2. Search for similar statements
-      const results = await helixClient.query("searchStatementsByVector", {
-        embedding,
-        userId,
-        validAt: options.validAt.toISOString(),
-        includeInvalidated: options.includeInvalidated,
-        limit: options.limit * 2, // Fetch more for reranking
-      });
-
-      return results.statements || [];
-    } catch (error) {
-      console.error("Vector search error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Perform BFS traversal starting from entities mentioned in the query
-   */
-  private async performBfsSearch(
-    query: string,
-    userId: string,
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    try {
-      // 1. Extract potential entities from query
-      const entities = await this.extractEntitiesFromQuery(query);
-
-      // 2. For each entity, perform BFS traversal
-      const allStatements: StatementNode[] = [];
-
-      for (const entity of entities) {
-        const statements = await this.bfsTraversal(
-          entity.uuid,
-          options.maxBfsDepth,
-          options.validAt,
-          userId,
-          options.includeInvalidated,
-        );
-        allStatements.push(...statements);
-      }
-
-      return allStatements;
-    } catch (error) {
-      console.error("BFS search error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Perform BFS traversal starting from an entity
-   */
-  private async bfsTraversal(
-    startEntityId: string,
-    maxDepth: number,
-    validAt: Date,
-    userId: string,
-    includeInvalidated: boolean,
-  ): Promise<StatementNode[]> {
-    // Track visited nodes to avoid cycles
-    const visited = new Set<string>();
-    // Track statements found during traversal
-    const statements: StatementNode[] = [];
-    // Queue for BFS traversal [nodeId, depth]
-    const queue: [string, number][] = [[startEntityId, 0]];
-
-    while (queue.length > 0) {
-      const [nodeId, depth] = queue.shift()!;
-
-      // Skip if already visited or max depth reached
-      if (visited.has(nodeId) || depth > maxDepth) continue;
-      visited.add(nodeId);
-
-      // Get statements where this entity is subject or object
-      const connectedStatements = await helixClient.query(
-        "getConnectedStatements",
-        {
-          entityId: nodeId,
-          userId,
-          validAt: validAt.toISOString(),
-          includeInvalidated,
-        },
-      );
-
-      // Add statements to results
-      if (connectedStatements.statements) {
-        statements.push(...connectedStatements.statements);
-
-        // Add connected entities to queue
-        for (const statement of connectedStatements.statements) {
-          // Get subject and object entities
-          if (statement.subjectId && !visited.has(statement.subjectId)) {
-            queue.push([statement.subjectId, depth + 1]);
-          }
-          if (statement.objectId && !visited.has(statement.objectId)) {
-            queue.push([statement.objectId, depth + 1]);
-          }
-        }
-      }
-    }
-
-    return statements;
-  }
-
-  /**
-   * Extract potential entities from a query using embeddings or LLM
-   */
-  private async extractEntitiesFromQuery(query: string): Promise<EntityNode[]> {
-    // TODO: Implement more sophisticated entity extraction
-    // This is a placeholder implementation that uses simple vector search
-    try {
-      const embedding = await this.getEmbedding(query);
-
-      const results = await helixClient.query("searchEntitiesByVector", {
-        embedding,
-        limit: 3, // Start with top 3 entities
-      });
-
-      return results.entities || [];
-    } catch (error) {
-      console.error("Entity extraction error:", error);
-      return [];
-    }
-  }
-
-  /**
-   * Combine and deduplicate statements from multiple sources
-   */
-  private combineAndDeduplicate(statements: StatementNode[]): StatementNode[] {
-    const uniqueStatements = new Map<string, StatementNode>();
-
-    for (const statement of statements) {
-      if (!uniqueStatements.has(statement.uuid)) {
-        uniqueStatements.set(statement.uuid, statement);
-      }
-    }
-
-    return Array.from(uniqueStatements.values());
-  }
-
-  /**
-   * Rerank statements based on relevance to the query
-   */
-  private async rerankStatements(
-    query: string,
-    statements: StatementNode[],
-    options: Required<SearchOptions>,
-  ): Promise<StatementNode[]> {
-    // TODO: Implement more sophisticated reranking
-    // This is a placeholder implementation using cosine similarity
-    try {
-      // 1. Generate embedding for the query
-      const queryEmbedding = await this.getEmbedding(query);
-
-      // 2. Generate or retrieve embeddings for statements
-      const statementEmbeddings = await Promise.all(
-        statements.map(async (statement) => {
-          // If statement has embedding, use it; otherwise generate
-          if (statement.factEmbedding && statement.factEmbedding.length > 0) {
-            return { statement, embedding: statement.factEmbedding };
-          }
-
-          // Generate text representation of statement
-          const statementText = this.statementToText(statement);
-          const embedding = await this.getEmbedding(statementText);
-
-          return { statement, embedding };
-        }),
-      );
-
-      // 3. Calculate cosine similarity
-      const scoredStatements = statementEmbeddings.map(
-        ({ statement, embedding }) => {
-          const similarity = this.cosineSimilarity(queryEmbedding, embedding);
-          return { statement, score: similarity };
-        },
-      );
-
-      // 4. Sort by score (descending)
-      scoredStatements.sort((a, b) => b.score - a.score);
-
-      // 5. Return statements in order of relevance
-      return scoredStatements.map(({ statement }) => statement);
-    } catch (error) {
-      console.error("Reranking error:", error);
-      // Fallback: return original order
-      return statements;
-    }
-  }
-
-  /**
-   * Convert a statement to a text representation
-   */
-  private statementToText(statement: StatementNode): string {
-    // TODO: Implement more sophisticated text representation
-    // This is a placeholder implementation
-    return `${statement.subjectName || "Unknown"} ${statement.predicateName || "has relation with"} ${statement.objectName || "Unknown"}`;
-  }
-
-  /**
-   * Calculate cosine similarity between two embeddings
-   */
-  private cosineSimilarity(a: number[], b: number[]): number {
-    if (a.length !== b.length) {
-      throw new Error("Embeddings must have the same length");
-    }
-
-    let dotProduct = 0;
-    let normA = 0;
-    let normB = 0;
-
-    for (let i = 0; i < a.length; i++) {
-      dotProduct += a[i] * b[i];
-      normA += a[i] * a[i];
-      normB += b[i] * b[i];
-    }
-
-    normA = Math.sqrt(normA);
-    normB = Math.sqrt(normB);
-
-    if (normA === 0 || normB === 0) {
-      return 0;
-    }
-
-    return dotProduct / (normA * normB);
-  }
+  /**
+   * Apply adaptive filtering to ranked results
+   * Uses a minimum quality threshold to filter out low-quality results
+   */
+  private applyAdaptiveFiltering(
+    results: StatementNode[],
+    options: Required<SearchOptions>,
+  ): StatementNode[] {
+    if (results.length === 0) return [];
+
+    // Extract scores from results
+    const scoredResults = results.map((result) => {
+      // Find the score based on reranking strategy used
+      let score = 0;
+      if ((result as any).rrfScore !== undefined) {
+        score = (result as any).rrfScore;
+      } else if ((result as any).mmrScore !== undefined) {
+        score = (result as any).mmrScore;
+      } else if ((result as any).crossEncoderScore !== undefined) {
+        score = (result as any).crossEncoderScore;
+      } else if ((result as any).finalScore !== undefined) {
+        score = (result as any).finalScore;
+      }
+
+      return { result, score };
+    });
+
+    const hasScores = scoredResults.some((item) => item.score > 0);
+    // If no scores are available, return the original results
+    if (!hasScores) {
+      logger.info("No scores found in results, skipping adaptive filtering");
+      return options.limit > 0 ? results.slice(0, options.limit) : results;
+    }
+
+    // Sort by score (descending)
+    scoredResults.sort((a, b) => b.score - a.score);
+
+    // Calculate statistics to identify low-quality results
+    const scores = scoredResults.map((item) => item.score);
+    const maxScore = Math.max(...scores);
+    const minScore = Math.min(...scores);
+    const scoreRange = maxScore - minScore;
+
+    // Define a minimum quality threshold as a fraction of the best score
+    // This is relative to the query's score distribution rather than an absolute value
+    const relativeThreshold = options.scoreThreshold || 0.3; // 30% of the best score by default
+    const absoluteMinimum = 0.1; // Absolute minimum threshold to prevent keeping very poor matches
+
+    // Calculate the actual threshold as a percentage of the distance from min to max score
+    const threshold = Math.max(
+      absoluteMinimum,
+      minScore + scoreRange * relativeThreshold,
+    );
+
+    // Filter out low-quality results
+    const filteredResults = scoredResults
+      .filter((item) => item.score >= threshold)
+      .map((item) => item.result);
+
+    // Apply limit if specified
+    const limitedResults =
+      options.limit > 0
+        ? filteredResults.slice(
+            0,
+            Math.min(filteredResults.length, options.limit),
+          )
+        : filteredResults;
+
+    logger.info(
+      `Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`,
+    );
+    logger.info(
+      `Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`,
+    );
+
+    return limitedResults;
+  }
+
+  /**
+   * Apply the selected reranking strategy to search results
+   */
+  private async rerankResults(
+    query: string,
+    results: {
+      bm25: StatementNode[];
+      vector: StatementNode[];
+      bfs: StatementNode[];
+    },
+    options: Required<SearchOptions>,
+  ): Promise<StatementNode[]> {
+    // Count non-empty result sources
+    const nonEmptySources = [
+      results.bm25.length > 0,
+      results.vector.length > 0,
+      results.bfs.length > 0,
+    ].filter(Boolean).length;
+
+    // If results are coming from only one source, use cross-encoder reranking
+    if (nonEmptySources <= 1) {
+      logger.info(
+        "Only one source has results, falling back to cross-encoder reranking",
+      );
+      return applyCrossEncoderReranking(query, results);
+    }
+
+    // Otherwise use weighted RRF for multiple sources
+    return applyWeightedRRF(results);
+  }
 }

@@ -356,23 +194,6 @@ export interface SearchOptions {
   includeInvalidated?: boolean;
   entityTypes?: string[];
   predicateTypes?: string[];
+  scoreThreshold?: number;
+  minResults?: number;
 }
-
-/**
- * Create a singleton instance of the search service
- */
-let searchServiceInstance: SearchService | null = null;
-
-export function getSearchService(
-  knowledgeGraphService?: KnowledgeGraphService,
-): SearchService {
-  if (!searchServiceInstance) {
-    if (!knowledgeGraphService) {
-      throw new Error(
-        "KnowledgeGraphService must be provided when initializing SearchService",
-      );
-    }
-    searchServiceInstance = new SearchService(knowledgeGraphService);
-  }
-  return searchServiceInstance;
-}
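To make the adaptive threshold concrete: suppose three statements survive reranking with scores 0.90, 0.50, and 0.12, and scoreThreshold is left at its 0.7 default (numbers made up for illustration). Then minScore = 0.12, maxScore = 0.90, scoreRange = 0.78, and threshold = max(0.1, 0.12 + 0.78 × 0.7) = 0.666, so only the 0.90 statement is kept. A lower scoreThreshold widens the band: at 0.3 the threshold drops to 0.354 and the 0.50 statement survives as well.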
apps/webapp/app/services/search/rerank.ts  (new file)
@@ -0,0 +1,118 @@
+import { LLMModelEnum, type StatementNode } from "@recall/types";
+import { combineAndDeduplicateStatements } from "./utils";
+import { type CoreMessage } from "ai";
+import { makeModelCall } from "~/lib/model.server";
+import { logger } from "../logger.service";
+
+/**
+ * Apply Weighted Reciprocal Rank Fusion to combine results
+ */
+export function applyWeightedRRF(results: {
+  bm25: StatementNode[];
+  vector: StatementNode[];
+  bfs: StatementNode[];
+}): StatementNode[] {
+  // Determine weights based on query characteristics
+  const weights = {
+    bm25: 1.0,
+    vector: 0.8,
+    bfs: 0.5,
+  };
+  const k = 60; // RRF constant
+
+  // Map to store combined scores
+  const scores: Record<string, { score: number; statement: StatementNode }> =
+    {};
+
+  // Process BM25 results with their weight
+  results.bm25.forEach((statement, rank) => {
+    const uuid = statement.uuid;
+    scores[uuid] = scores[uuid] || { score: 0, statement };
+    scores[uuid].score += weights.bm25 * (1 / (rank + k));
+  });
+
+  // Process vector similarity results with their weight
+  results.vector.forEach((statement, rank) => {
+    const uuid = statement.uuid;
+    scores[uuid] = scores[uuid] || { score: 0, statement };
+    scores[uuid].score += weights.vector * (1 / (rank + k));
+  });
+
+  // Process BFS traversal results with their weight
+  results.bfs.forEach((statement, rank) => {
+    const uuid = statement.uuid;
+    scores[uuid] = scores[uuid] || { score: 0, statement };
+    scores[uuid].score += weights.bfs * (1 / (rank + k));
+  });
+
+  // Convert to array and sort by final score
+  const sortedResults = Object.values(scores)
+    .sort((a, b) => b.score - a.score)
+    .map((item) => {
+      // Add the RRF score to the statement for debugging
+      return {
+        ...item.statement,
+        rrfScore: item.score,
+      };
+    });
+
+  return sortedResults;
+}
+
+/**
+ * Apply Cross-Encoder reranking to results
+ * This is particularly useful when results come from a single source
+ */
+export async function applyCrossEncoderReranking(
+  query: string,
+  results: {
+    bm25: StatementNode[];
+    vector: StatementNode[];
+    bfs: StatementNode[];
+  },
+): Promise<StatementNode[]> {
+  // Combine all results
+  const allResults = [...results.bm25, ...results.vector, ...results.bfs];
+
+  // Deduplicate by UUID
+  const uniqueResults = combineAndDeduplicateStatements(allResults);
+
+  if (uniqueResults.length === 0) return [];
+
+  logger.info(`Cross-encoder reranking ${uniqueResults.length} statements`);
+
+  const finalStatements: StatementNode[] = [];
+
+  await Promise.all(
+    uniqueResults.map(async (statement) => {
+      const messages: CoreMessage[] = [
+        {
+          role: "system",
+          content: `You are an expert tasked with determining whether the statement is relevant to the query.
+Respond with "True" if the STATEMENT is relevant to the QUERY and "False" otherwise.`,
+        },
+        {
+          role: "user",
+          content: `<QUERY>${query}</QUERY>\n<STATEMENT>${statement.fact}</STATEMENT>`,
+        },
+      ];
+
+      let responseText = "";
+      await makeModelCall(
+        false,
+        LLMModelEnum.GPT41NANO,
+        messages,
+        (text) => {
+          responseText = text;
+        },
+        { temperature: 0, maxTokens: 1 },
+      );
+
+      if (responseText === "True") {
+        finalStatements.push(statement);
+      }
+    }),
+  );
+
+  return finalStatements;
+}
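A quick worked example of the fusion above (ranks are 0-based, k = 60): a statement ranked first by BM25 and third by vector search scores 1.0 × 1/(0+60) + 0.8 × 1/(2+60) ≈ 0.0167 + 0.0129 = 0.0296, while a statement ranked first by BFS alone scores 0.5 × 1/60 ≈ 0.0083. Agreement across sources therefore outweighs a single high rank from the lowest-weighted source.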
apps/webapp/app/services/search/utils.ts  (new file)
@@ -0,0 +1,223 @@
+import type { EntityNode, StatementNode } from "@recall/types";
+import type { SearchOptions } from "../search.server";
+import type { Embedding } from "ai";
+import { logger } from "../logger.service";
+import { runQuery } from "~/lib/neo4j.server";
+
+/**
+ * Perform BM25 keyword-based search on statements
+ */
+export async function performBM25Search(
+  query: string,
+  userId: string,
+  options: Required<SearchOptions>,
+): Promise<StatementNode[]> {
+  try {
+    // Sanitize the query for Lucene syntax
+    const sanitizedQuery = sanitizeLuceneQuery(query);
+
+    // Use Neo4j's built-in fulltext search capabilities
+    const cypher = `
+      CALL db.index.fulltext.queryNodes("statement_fact_index", $query)
+      YIELD node AS s, score
+      WHERE
+        s.validAt <= $validAt
+        AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
+        AND (s.userId = $userId)
+      RETURN s, score
+      ORDER BY score DESC
+    `;
+
+    const params = {
+      query: sanitizedQuery,
+      userId,
+      validAt: options.validAt.toISOString(),
+    };
+
+    const records = await runQuery(cypher, params);
+    // return records.map((record) => record.get("s").properties as StatementNode);
+    return [];
+  } catch (error) {
+    logger.error("BM25 search error:", { error });
+    return [];
+  }
+}
+
+/**
+ * Sanitize a query string for Lucene syntax
+ */
+export function sanitizeLuceneQuery(query: string): string {
+  // Escape special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \
+  let sanitized = query.replace(
+    /[+\-&|!(){}[\]^"~*?:\\]/g,
+    (match) => "\\" + match,
+  );
+
+  // If the query is too long, truncate it to MAX_QUERY_LENGTH words
+  const MAX_QUERY_LENGTH = 32;
+  const words = sanitized.split(" ");
+  if (words.length > MAX_QUERY_LENGTH) {
+    sanitized = words.slice(0, MAX_QUERY_LENGTH).join(" ");
+  }
+
+  return sanitized;
+}
+
+/**
+ * Perform vector similarity search on statement embeddings
+ */
+export async function performVectorSearch(
+  query: Embedding,
+  userId: string,
+  options: Required<SearchOptions>,
+): Promise<StatementNode[]> {
+  try {
+    // Search for similar statements using Neo4j vector search
+    const cypher = `
+      MATCH (s:Statement)
+      WHERE
+        s.validAt <= $validAt
+        AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
+        AND (s.userId = $userId OR s.isPublic = true)
+      WITH s, vector.similarity.cosine(s.factEmbedding, $embedding) AS score
+      WHERE score > 0.7
+      RETURN s, score
+      ORDER BY score DESC
+    `;
+
+    const params = {
+      embedding: query,
+      userId,
+      validAt: options.validAt.toISOString(),
+    };
+
+    const records = await runQuery(cypher, params);
+    // return records.map((record) => record.get("s").properties as StatementNode);
+    return [];
+  } catch (error) {
+    logger.error("Vector search error:", { error });
+    return [];
+  }
+}
+
+/**
+ * Perform BFS traversal starting from entities mentioned in the query
+ */
+export async function performBfsSearch(
+  embedding: Embedding,
+  userId: string,
+  options: Required<SearchOptions>,
+): Promise<StatementNode[]> {
+  try {
+    // 1. Extract potential entities from query
+    const entities = await extractEntitiesFromQuery(embedding);
+
+    // 2. For each entity, perform BFS traversal
+    const allStatements: StatementNode[] = [];
+
+    for (const entity of entities) {
+      const statements = await bfsTraversal(
+        entity.uuid,
+        options.maxBfsDepth,
+        options.validAt,
+        userId,
+        options.includeInvalidated,
+      );
+      allStatements.push(...statements);
+    }
+
+    return allStatements;
+  } catch (error) {
+    logger.error("BFS search error:", { error });
+    return [];
+  }
+}
+
+/**
+ * Perform BFS traversal starting from an entity
+ */
+export async function bfsTraversal(
+  startEntityId: string,
+  maxDepth: number,
+  validAt: Date,
+  userId: string,
+  includeInvalidated: boolean,
+): Promise<StatementNode[]> {
+  try {
+    // Use Neo4j's built-in path finding capabilities for efficient BFS
+    // This query implements BFS up to maxDepth and collects all statements along the way
+    const cypher = `
+      MATCH (e:Entity {uuid: $startEntityId})<-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]-(s:Statement)
+      WHERE
+        s.validAt <= $validAt
+        AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
+        AND (s.userId = $userId)
+        AND ($includeInvalidated OR s.invalidAt IS NULL)
+      RETURN s as statement
+    `;
+
+    const params = {
+      startEntityId,
+      maxDepth,
+      validAt: validAt.toISOString(),
+      userId,
+      includeInvalidated,
+    };
+
+    const records = await runQuery(cypher, params);
+    return records.map(
+      (record) => record.get("statement").properties as StatementNode,
+    );
+  } catch (error) {
+    logger.error("BFS traversal error:", { error });
+    return [];
+  }
+}
+
+/**
+ * Extract potential entities from a query using embeddings or LLM
+ */
+export async function extractEntitiesFromQuery(
+  embedding: Embedding,
+): Promise<EntityNode[]> {
+  try {
+    // Use vector similarity to find relevant entities
+    const cypher = `
+      // Match entities using vector similarity on name embeddings
+      MATCH (e:Entity)
+      WHERE e.nameEmbedding IS NOT NULL
+      WITH e, vector.similarity.cosine(e.nameEmbedding, $embedding) AS score
+      WHERE score > 0.7
+      RETURN e
+      ORDER BY score DESC
+      LIMIT 3
+    `;
+
+    const params = {
+      embedding,
+    };
+
+    const records = await runQuery(cypher, params);
+
+    return records.map((record) => record.get("e").properties as EntityNode);
+  } catch (error) {
+    logger.error("Entity extraction error:", { error });
+    return [];
+  }
+}
+
+/**
+ * Combine and deduplicate statements from different search methods
+ */
+export function combineAndDeduplicateStatements(
+  statements: StatementNode[],
+): StatementNode[] {
+  return Array.from(
+    new Map(
+      statements.map((statement) => [statement.uuid, statement]),
+    ).values(),
+  );
+}
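To illustrate sanitizeLuceneQuery on a hostile input (made-up example): the query cache: "LRU" + eviction? comes back as cache\: \"LRU\" \+ eviction\?, so Lucene operators reach db.index.fulltext.queryNodes as literal text; independently, anything past the first 32 whitespace-separated words is dropped.

One caveat worth flagging: the Cypher in bfsTraversal only expands a single hop from the start entity, even though maxDepth is passed in the params. If deeper traversal is wanted, a variable-length pattern is one option; the bound has to be inlined because Cypher does not accept parameters in variable-length quantifiers. A sketch, not part of this commit:

// Hypothetical multi-hop variant: statements reachable within roughly
// `maxDepth` entity hops, alternating Statement and Entity nodes.
const cypher = `
  MATCH (e:Entity {uuid: $startEntityId})-[:HAS_SUBJECT|HAS_OBJECT*1..${2 * maxDepth}]-(s:Statement)
  WHERE s.validAt <= $validAt
    AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)
    AND s.userId = $userId
  RETURN DISTINCT s AS statement
`;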