Manoj 27f8740691
Fix: Semantic Search issue (#89)
* Fix: normalization prompt

* Fix: improve knowledge graph and better recall

* fix: add user context to search reranking

* fix: in search log the source

* fix: remove harcoded limit

---------

Co-authored-by: Harshith Mullapudi <harshithmullapudi@gmail.com>
2025-10-06 14:06:52 +05:30

428 lines
13 KiB
TypeScript

import type { EntityNode, StatementNode, EpisodicNode } from "@core/types";
import type { SearchOptions } from "../search.server";
import type { Embedding } from "ai";
import { logger } from "../logger.service";
import { runQuery } from "~/lib/neo4j.server";
import { getEmbedding } from "~/lib/model.server";
import { findSimilarEntities } from "../graphModels/entity";
/**
* Perform BM25 keyword-based search on statements
*/
export async function performBM25Search(
query: string,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// Sanitize the query for Lucene syntax
const sanitizedQuery = sanitizeLuceneQuery(query);
// Build the WHERE clause based on timeframe options
let timeframeCondition = `
AND s.validAt <= $validAt
${options.includeInvalidated ? '' : 'AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)'}
`;
// If startTime is provided, add condition to filter by validAt >= startTime
if (options.startTime) {
timeframeCondition = `
AND s.validAt <= $validAt
${options.includeInvalidated ? '' : 'AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)'}
AND s.validAt >= $startTime
`;
}
// Add space filtering if spaceIds are provided
let spaceCondition = "";
if (options.spaceIds.length > 0) {
spaceCondition = `
AND s.spaceIds IS NOT NULL AND ANY(spaceId IN $spaceIds WHERE spaceId IN s.spaceIds)
`;
}
// Use Neo4j's built-in fulltext search capabilities with provenance count
const cypher = `
CALL db.index.fulltext.queryNodes("statement_fact_index", $query)
YIELD node AS s, score
WHERE
(s.userId = $userId)
${timeframeCondition}
${spaceCondition}
OPTIONAL MATCH (episode:Episode)-[:HAS_PROVENANCE]->(s)
WITH s, score, count(episode) as provenanceCount
RETURN s, score, provenanceCount
ORDER BY score DESC
`;
const params = {
query: sanitizedQuery,
userId,
validAt: options.endTime.toISOString(),
...(options.startTime && { startTime: options.startTime.toISOString() }),
...(options.spaceIds.length > 0 && { spaceIds: options.spaceIds }),
};
const records = await runQuery(cypher, params);
return records.map((record) => {
const statement = record.get("s").properties as StatementNode;
const provenanceCountValue = record.get("provenanceCount");
statement.provenanceCount =
typeof provenanceCountValue === "bigint"
? Number(provenanceCountValue)
: (provenanceCountValue?.toNumber?.() ?? provenanceCountValue ?? 0);
return statement;
});
} catch (error) {
logger.error("BM25 search error:", { error });
return [];
}
}
/**
* Sanitize a query string for Lucene syntax
*/
export function sanitizeLuceneQuery(query: string): string {
// Escape special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
let sanitized = query.replace(
/[+\-&|!(){}[\]^"~*?:\\\/]/g,
(match) => "\\" + match,
);
// If query is too long, truncate it
const MAX_QUERY_LENGTH = 32;
const words = sanitized.split(" ");
if (words.length > MAX_QUERY_LENGTH) {
sanitized = words.slice(0, MAX_QUERY_LENGTH).join(" ");
}
return sanitized;
}
/**
* Perform vector similarity search on statement embeddings
*/
export async function performVectorSearch(
query: Embedding,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// Build the WHERE clause based on timeframe options
let timeframeCondition = `
AND s.validAt <= $validAt
${options.includeInvalidated ? '' : 'AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)'}
`;
// If startTime is provided, add condition to filter by validAt >= startTime
if (options.startTime) {
timeframeCondition = `
AND s.validAt <= $validAt
${options.includeInvalidated ? '' : 'AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)'}
AND s.validAt >= $startTime
`;
}
// Add space filtering if spaceIds are provided
let spaceCondition = "";
if (options.spaceIds.length > 0) {
spaceCondition = `
AND s.spaceIds IS NOT NULL AND ANY(spaceId IN $spaceIds WHERE spaceId IN s.spaceIds)
`;
}
const limit = options.limit || 100;
// 1. Search for similar statements using GDS cosine similarity with provenance count
const cypher = `
MATCH (s:Statement)
WHERE s.userId = $userId
${timeframeCondition}
${spaceCondition}
WITH s, gds.similarity.cosine(s.factEmbedding, $embedding) AS score
WHERE score >= 0.5
OPTIONAL MATCH (episode:Episode)-[:HAS_PROVENANCE]->(s)
WITH s, score, count(episode) as provenanceCount
RETURN s, score, provenanceCount
ORDER BY score DESC
LIMIT ${limit}
`;
const params = {
embedding: query,
userId,
validAt: options.endTime.toISOString(),
...(options.startTime && { startTime: options.startTime.toISOString() }),
...(options.spaceIds.length > 0 && { spaceIds: options.spaceIds }),
};
const records = await runQuery(cypher, params);
return records.map((record) => {
const statement = record.get("s").properties as StatementNode;
const provenanceCountValue = record.get("provenanceCount");
statement.provenanceCount =
typeof provenanceCountValue === "bigint"
? Number(provenanceCountValue)
: (provenanceCountValue?.toNumber?.() ?? provenanceCountValue ?? 0);
return statement;
});
} catch (error) {
logger.error("Vector search error:", { error });
return [];
}
}
/**
* Perform BFS traversal starting from entities mentioned in the query
* Uses guided search with semantic filtering to reduce noise
*/
export async function performBfsSearch(
query: string,
embedding: Embedding,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// 1. Extract potential entities from query using chunked embeddings
const entities = await extractEntitiesFromQuery(query, userId);
if (entities.length === 0) {
return [];
}
// 2. Perform guided BFS with semantic filtering
const statements = await bfsTraversal(
entities,
embedding,
options.maxBfsDepth || 3,
options.endTime,
userId,
options.includeInvalidated,
options.startTime,
);
// Return individual statements
return statements;
} catch (error) {
logger.error("BFS search error:", { error });
return [];
}
}
/**
* Iterative BFS traversal - explores up to 3 hops level-by-level using Neo4j cosine similarity
*/
async function bfsTraversal(
startEntities: EntityNode[],
queryEmbedding: Embedding,
maxDepth: number,
validAt: Date,
userId: string,
includeInvalidated: boolean,
startTime: Date | null,
): Promise<StatementNode[]> {
const RELEVANCE_THRESHOLD = 0.5;
const EXPLORATION_THRESHOLD = 0.3;
const allStatements = new Map<string, number>(); // uuid -> relevance
const visitedEntities = new Set<string>();
// Track entities per level for iterative BFS
let currentLevelEntities = startEntities.map(e => e.uuid);
// Timeframe condition for temporal filtering
let timeframeCondition = `
AND s.validAt <= $validAt
${includeInvalidated ? '' : 'AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)'}
`;
if (startTime) {
timeframeCondition += ` AND s.validAt >= $startTime`;
}
// Process each depth level
for (let depth = 0; depth < maxDepth; depth++) {
if (currentLevelEntities.length === 0) break;
// Mark entities as visited at this depth
currentLevelEntities.forEach(id => visitedEntities.add(`${id}`));
// Get statements for current level entities with cosine similarity calculated in Neo4j
const cypher = `
MATCH (e:Entity{userId: $userId})-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]-(s:Statement{userId: $userId})
WHERE e.uuid IN $entityIds
${timeframeCondition}
WITH DISTINCT s // Deduplicate first
WITH s, gds.similarity.cosine(s.factEmbedding, $queryEmbedding) AS relevance
WHERE relevance >= $explorationThreshold
RETURN s.uuid AS uuid, relevance
ORDER BY relevance DESC
LIMIT 200 // Cap per BFS level to avoid explosion
`;
const records = await runQuery(cypher, {
entityIds: currentLevelEntities,
userId,
queryEmbedding,
explorationThreshold: EXPLORATION_THRESHOLD,
validAt: validAt.toISOString(),
...(startTime && { startTime: startTime.toISOString() }),
});
// Store statement relevance scores
const currentLevelStatementUuids: string[] = [];
for (const record of records) {
const uuid = record.get("uuid");
const relevance = record.get("relevance");
if (!allStatements.has(uuid)) {
allStatements.set(uuid, relevance);
currentLevelStatementUuids.push(uuid);
}
}
// Get connected entities for next level
if (depth < maxDepth - 1 && currentLevelStatementUuids.length > 0) {
const nextCypher = `
MATCH (s:Statement{userId: $userId})-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]->(e:Entity{userId: $userId})
WHERE s.uuid IN $statementUuids
RETURN DISTINCT e.uuid AS entityId
`;
const nextRecords = await runQuery(nextCypher, {
statementUuids: currentLevelStatementUuids,
userId
});
// Filter out already visited entities
currentLevelEntities = nextRecords
.map(r => r.get("entityId"))
.filter(id => !visitedEntities.has(`${id}`));
} else {
currentLevelEntities = [];
}
}
// Filter by relevance threshold and fetch full statements
const relevantUuids = Array.from(allStatements.entries())
.filter(([_, relevance]) => relevance >= RELEVANCE_THRESHOLD)
.sort((a, b) => b[1] - a[1])
.map(([uuid]) => uuid);
if (relevantUuids.length === 0) {
return [];
}
const fetchCypher = `
MATCH (s:Statement{userId: $userId})
WHERE s.uuid IN $uuids
RETURN s
`;
const fetchRecords = await runQuery(fetchCypher, { uuids: relevantUuids, userId });
const statements = fetchRecords.map(r => r.get("s").properties as StatementNode);
logger.info(
`BFS: explored ${allStatements.size} statements across ${maxDepth} hops, returning ${statements.length} (≥${RELEVANCE_THRESHOLD})`
);
return statements;
}
/**
* Generate query chunks (individual words and bigrams) for entity extraction
*/
function generateQueryChunks(query: string): string[] {
const words = query.toLowerCase()
.trim()
.split(/\s+/)
.filter(word => word.length > 0);
const chunks: string[] = [];
// Add individual words (for entities like "user")
chunks.push(...words);
// Add bigrams (for multi-word entities like "home address")
for (let i = 0; i < words.length - 1; i++) {
chunks.push(`${words[i]} ${words[i + 1]}`);
}
// Add full query as final chunk
chunks.push(query.toLowerCase().trim());
return chunks;
}
/**
* Extract potential entities from a query using chunked embeddings
* Chunks query into words/bigrams, embeds each chunk, finds entities for each
*/
export async function extractEntitiesFromQuery(
query: string,
userId: string,
): Promise<EntityNode[]> {
try {
// Generate chunks from query
const chunks = generateQueryChunks(query);
// Get embeddings for each chunk
const chunkEmbeddings = await Promise.all(
chunks.map(chunk => getEmbedding(chunk))
);
// Search for entities matching each chunk embedding
const allEntitySets = await Promise.all(
chunkEmbeddings.map(async (embedding) => {
return await findSimilarEntities({
queryEmbedding: embedding,
limit: 3,
threshold: 0.7,
userId,
});
})
);
// Flatten and deduplicate entities by ID
const allEntities = allEntitySets.flat();
const uniqueEntities = Array.from(
new Map(allEntities.map(e => [e.uuid, e])).values()
);
return uniqueEntities;
} catch (error) {
logger.error("Entity extraction error:", { error });
return [];
}
}
/**
* Combine and deduplicate statements from different search methods
*/
export function combineAndDeduplicateStatements(
statements: StatementNode[],
): StatementNode[] {
return Array.from(
new Map(
statements.map((statement) => [statement.uuid, statement]),
).values(),
);
}
export async function getEpisodesByStatements(
statements: StatementNode[],
): Promise<EpisodicNode[]> {
const cypher = `
MATCH (s:Statement)<-[:HAS_PROVENANCE]-(e:Episode)
WHERE s.uuid IN $statementUuids
RETURN distinct e
`;
const params = {
statementUuids: statements.map((s) => s.uuid),
};
const records = await runQuery(cypher, params);
return records.map((record) => record.get("e").properties as EpisodicNode);
}