core/apps/webapp/app/services/search.server.ts
2025-06-05 17:31:53 +05:30

379 lines
11 KiB
TypeScript

import {
type EntityNode,
type KnowledgeGraphService,
type StatementNode,
} from "./knowledgeGraph.server";
import { openai } from "@ai-sdk/openai";
import { embed } from "ai";
import HelixDB from "helix-ts";
// Initialize OpenAI for embeddings
const openaiClient = openai("gpt-4.1-2025-04-14");
// Initialize Helix client
const helixClient = new HelixDB();
/**
* SearchService provides methods to search the reified + temporal knowledge graph
* using a hybrid approach combining BM25, vector similarity, and BFS traversal.
*/
export class SearchService {
private knowledgeGraphService: KnowledgeGraphService;
constructor(knowledgeGraphService: KnowledgeGraphService) {
this.knowledgeGraphService = knowledgeGraphService;
}
async getEmbedding(text: string) {
const { embedding } = await embed({
model: openai.embedding("text-embedding-3-small"),
value: text,
});
return embedding;
}
/**
* Search the knowledge graph using a hybrid approach
* @param query The search query
* @param userId The user ID for personalization
* @param options Search options
* @returns Array of relevant statements
*/
public async search(
query: string,
userId: string,
options: SearchOptions = {},
): Promise<StatementNode[]> {
// Default options
const opts: Required<SearchOptions> = {
limit: options.limit || 10,
maxBfsDepth: options.maxBfsDepth || 4,
validAt: options.validAt || new Date(),
includeInvalidated: options.includeInvalidated || false,
entityTypes: options.entityTypes || [],
predicateTypes: options.predicateTypes || [],
};
// 1. Run parallel search methods
const [bm25Results, vectorResults, bfsResults] = await Promise.all([
this.performBM25Search(query, userId, opts),
this.performVectorSearch(query, userId, opts),
this.performBfsSearch(query, userId, opts),
]);
// 2. Combine and deduplicate results
const combinedStatements = this.combineAndDeduplicate([
...bm25Results,
...vectorResults,
...bfsResults,
]);
// 3. Rerank the combined results
const rerankedStatements = await this.rerankStatements(
query,
combinedStatements,
opts,
);
// 4. Return top results
return rerankedStatements.slice(0, opts.limit);
}
/**
* Perform BM25 keyword-based search on statements
*/
private async performBM25Search(
query: string,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
// TODO: Implement BM25 search using HelixDB or external search index
// This is a placeholder implementation
try {
const results = await helixClient.query("searchStatementsByKeywords", {
query,
userId,
validAt: options.validAt.toISOString(),
includeInvalidated: options.includeInvalidated,
limit: options.limit * 2, // Fetch more for reranking
});
return results.statements || [];
} catch (error) {
console.error("BM25 search error:", error);
return [];
}
}
/**
* Perform vector similarity search on statement embeddings
*/
private async performVectorSearch(
query: string,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// 1. Generate embedding for the query
const embedding = await this.generateEmbedding(query);
// 2. Search for similar statements
const results = await helixClient.query("searchStatementsByVector", {
embedding,
userId,
validAt: options.validAt.toISOString(),
includeInvalidated: options.includeInvalidated,
limit: options.limit * 2, // Fetch more for reranking
});
return results.statements || [];
} catch (error) {
console.error("Vector search error:", error);
return [];
}
}
/**
* Perform BFS traversal starting from entities mentioned in the query
*/
private async performBfsSearch(
query: string,
userId: string,
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// 1. Extract potential entities from query
const entities = await this.extractEntitiesFromQuery(query);
// 2. For each entity, perform BFS traversal
const allStatements: StatementNode[] = [];
for (const entity of entities) {
const statements = await this.bfsTraversal(
entity.uuid,
options.maxBfsDepth,
options.validAt,
userId,
options.includeInvalidated,
);
allStatements.push(...statements);
}
return allStatements;
} catch (error) {
console.error("BFS search error:", error);
return [];
}
}
/**
* Perform BFS traversal starting from an entity
*/
private async bfsTraversal(
startEntityId: string,
maxDepth: number,
validAt: Date,
userId: string,
includeInvalidated: boolean,
): Promise<StatementNode[]> {
// Track visited nodes to avoid cycles
const visited = new Set<string>();
// Track statements found during traversal
const statements: StatementNode[] = [];
// Queue for BFS traversal [nodeId, depth]
const queue: [string, number][] = [[startEntityId, 0]];
while (queue.length > 0) {
const [nodeId, depth] = queue.shift()!;
// Skip if already visited or max depth reached
if (visited.has(nodeId) || depth > maxDepth) continue;
visited.add(nodeId);
// Get statements where this entity is subject or object
const connectedStatements = await helixClient.query(
"getConnectedStatements",
{
entityId: nodeId,
userId,
validAt: validAt.toISOString(),
includeInvalidated,
},
);
// Add statements to results
if (connectedStatements.statements) {
statements.push(...connectedStatements.statements);
// Add connected entities to queue
for (const statement of connectedStatements.statements) {
// Get subject and object entities
if (statement.subjectId && !visited.has(statement.subjectId)) {
queue.push([statement.subjectId, depth + 1]);
}
if (statement.objectId && !visited.has(statement.objectId)) {
queue.push([statement.objectId, depth + 1]);
}
}
}
}
return statements;
}
/**
* Extract potential entities from a query using embeddings or LLM
*/
private async extractEntitiesFromQuery(query: string): Promise<EntityNode[]> {
// TODO: Implement more sophisticated entity extraction
// This is a placeholder implementation that uses simple vector search
try {
const embedding = await this.getEmbedding(query);
const results = await helixClient.query("searchEntitiesByVector", {
embedding,
limit: 3, // Start with top 3 entities
});
return results.entities || [];
} catch (error) {
console.error("Entity extraction error:", error);
return [];
}
}
/**
* Combine and deduplicate statements from multiple sources
*/
private combineAndDeduplicate(statements: StatementNode[]): StatementNode[] {
const uniqueStatements = new Map<string, StatementNode>();
for (const statement of statements) {
if (!uniqueStatements.has(statement.uuid)) {
uniqueStatements.set(statement.uuid, statement);
}
}
return Array.from(uniqueStatements.values());
}
/**
* Rerank statements based on relevance to the query
*/
private async rerankStatements(
query: string,
statements: StatementNode[],
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
// TODO: Implement more sophisticated reranking
// This is a placeholder implementation using cosine similarity
try {
// 1. Generate embedding for the query
const queryEmbedding = await this.getEmbedding(query);
// 2. Generate or retrieve embeddings for statements
const statementEmbeddings = await Promise.all(
statements.map(async (statement) => {
// If statement has embedding, use it; otherwise generate
if (statement.factEmbedding && statement.factEmbedding.length > 0) {
return { statement, embedding: statement.factEmbedding };
}
// Generate text representation of statement
const statementText = this.statementToText(statement);
const embedding = await this.getEmbedding(statementText);
return { statement, embedding };
}),
);
// 3. Calculate cosine similarity
const scoredStatements = statementEmbeddings.map(
({ statement, embedding }) => {
const similarity = this.cosineSimilarity(queryEmbedding, embedding);
return { statement, score: similarity };
},
);
// 4. Sort by score (descending)
scoredStatements.sort((a, b) => b.score - a.score);
// 5. Return statements in order of relevance
return scoredStatements.map(({ statement }) => statement);
} catch (error) {
console.error("Reranking error:", error);
// Fallback: return original order
return statements;
}
}
/**
* Convert a statement to a text representation
*/
private statementToText(statement: StatementNode): string {
// TODO: Implement more sophisticated text representation
// This is a placeholder implementation
return `${statement.subjectName || "Unknown"} ${statement.predicateName || "has relation with"} ${statement.objectName || "Unknown"}`;
}
/**
* Calculate cosine similarity between two embeddings
*/
private cosineSimilarity(a: number[], b: number[]): number {
if (a.length !== b.length) {
throw new Error("Embeddings must have the same length");
}
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
normA = Math.sqrt(normA);
normB = Math.sqrt(normB);
if (normA === 0 || normB === 0) {
return 0;
}
return dotProduct / (normA * normB);
}
}
/**
* Search options interface
*/
export interface SearchOptions {
limit?: number;
maxBfsDepth?: number;
validAt?: Date;
includeInvalidated?: boolean;
entityTypes?: string[];
predicateTypes?: string[];
}
/**
* Create a singleton instance of the search service
*/
let searchServiceInstance: SearchService | null = null;
export function getSearchService(
knowledgeGraphService?: KnowledgeGraphService,
): SearchService {
if (!searchServiceInstance) {
if (!knowledgeGraphService) {
throw new Error(
"KnowledgeGraphService must be provided when initializing SearchService",
);
}
searchServiceInstance = new SearchService(knowledgeGraphService);
}
return searchServiceInstance;
}