import { openai } from "@ai-sdk/openai";
import type { StatementNode } from "@core/types";
import { embed } from "ai";
import { logger } from "./logger.service";
import { applyCrossEncoderReranking, applyWeightedRRF } from "./search/rerank";
import {
  performBfsSearch,
  performBM25Search,
  performVectorSearch,
} from "./search/utils";

/**
 * SearchService provides methods to search the reified + temporal knowledge graph
 * using a hybrid approach combining BM25, vector similarity, and BFS traversal.
 *
 * The pipeline implemented by {@link SearchService.search} is:
 * embed the query -> run the three searches in parallel -> rerank -> drop
 * low-scoring results -> return the `fact` of each surviving statement.
 */
export class SearchService {
  /**
   * Score fields that a reranker may attach to a statement, listed in lookup
   * priority order. The first field that is present on a statement wins.
   */
  private static readonly SCORE_FIELDS = [
    "rrfScore",
    "mmrScore",
    "crossEncoderScore",
    "finalScore",
  ] as const;

  /**
   * Embed a piece of text with OpenAI's `text-embedding-3-small` model.
   *
   * @param text The text to embed
   * @returns The embedding produced by the `ai` SDK's `embed` call
   */
  async getEmbedding(text: string) {
    const { embedding } = await embed({
      model: openai.embedding("text-embedding-3-small"),
      value: text,
    });

    return embedding;
  }

  /**
   * Search the knowledge graph using a hybrid approach.
   *
   * @param query The search query
   * @param userId The user ID, forwarded to every search helper
   * @param options Search options; omitted fields fall back to the defaults
   *   documented on {@link SearchOptions}
   * @returns The `fact` of each statement that survives reranking and filtering,
   *   best match first
   */
  public async search(
    query: string,
    userId: string,
    options: SearchOptions = {},
  ): Promise<StatementNode["fact"][]> {
    const opts: Required<SearchOptions> = {
      // `||` is kept on purpose for these two: an explicit 0 keeps falling back
      // to the default, because the search helpers that consume them are not
      // defined in this file and may not accept 0.
      limit: options.limit || 10,
      maxBfsDepth: options.maxBfsDepth || 4,
      validAt: options.validAt ?? new Date(),
      includeInvalidated: options.includeInvalidated ?? false,
      entityTypes: options.entityTypes ?? [],
      predicateTypes: options.predicateTypes ?? [],
      // `??` so that an explicit 0 is honoured instead of being replaced.
      scoreThreshold: options.scoreThreshold ?? 0.7,
      minResults: options.minResults ?? 10,
    };

    const queryVector = await this.getEmbedding(query);

    // 1. Run the three search methods in parallel
    const [bm25Results, vectorResults, bfsResults] = await Promise.all([
      performBM25Search(query, userId, opts),
      performVectorSearch(queryVector, userId, opts),
      performBfsSearch(queryVector, userId, opts),
    ]);

    logger.info(
      `Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`,
    );

    // 2. Apply reranking strategy
    const rankedStatements = await this.rerankResults(
      query,
      { bm25: bm25Results, vector: vectorResults, bfs: bfsResults },
      opts,
    );

    // 3. Drop low-scoring results and apply the result limit
    const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts);

    // 4. Return the facts of the remaining statements
    return filteredResults.map((statement) => statement.fact);
  }

  /**
   * Read the reranking score attached to a statement.
   *
   * The first field of {@link SearchService.SCORE_FIELDS} that is not
   * `undefined` is used. A value that is not a finite number counts as 0, and
   * so does a statement with none of the fields.
   */
  private readScore(statement: StatementNode): number {
    for (const field of SearchService.SCORE_FIELDS) {
      const value: unknown = Reflect.get(statement, field);
      if (value !== undefined) {
        return typeof value === "number" && Number.isFinite(value) ? value : 0;
      }
    }

    return 0;
  }

  /**
   * Filter ranked results by a threshold derived from their own score
   * distribution, then truncate to `options.limit`.
   *
   * The threshold is `max(0.1, min + (max - min) * options.scoreThreshold)`,
   * i.e. a fraction of the min-to-max score range, not of the best score.
   * If no result carries a positive score the filter is skipped and only the
   * limit is applied. A non-positive `limit` disables truncation.
   *
   * Note: `options.minResults` is not consulted here.
   *
   * @param results Statements as returned by the reranker
   * @param options Fully-defaulted search options
   * @returns The surviving statements, highest score first
   */
  private applyAdaptiveFiltering(
    results: StatementNode[],
    options: Required<SearchOptions>,
  ): StatementNode[] {
    if (results.length === 0) return [];

    const applyLimit = (items: StatementNode[]): StatementNode[] =>
      options.limit > 0 ? items.slice(0, options.limit) : items;

    const scored = results.map((result) => ({
      result,
      score: this.readScore(result),
    }));

    // Without any positive score there is nothing to threshold against.
    if (!scored.some((item) => item.score > 0)) {
      logger.info("No scores found in results, skipping adaptive filtering");
      return applyLimit(results);
    }

    // Sort by score (descending); Array.prototype.sort is stable, so ties keep
    // the order the reranker produced.
    scored.sort((a, b) => b.score - a.score);

    // Single pass instead of Math.max(...scores): spreading a large array into
    // call arguments can exceed the engine's argument limit.
    let maxScore = Number.NEGATIVE_INFINITY;
    let minScore = Number.POSITIVE_INFINITY;
    for (const { score } of scored) {
      if (score > maxScore) maxScore = score;
      if (score < minScore) minScore = score;
    }
    const scoreRange = maxScore - minScore;

    const relativeThreshold = options.scoreThreshold;
    // Floor that stops very poor matches from passing when the whole
    // distribution is low.
    // NOTE(review): this floor is absolute; if a reranker emits scores that are
    // all below 0.1, every result is dropped. Confirm against the score scale
    // of applyWeightedRRF / applyCrossEncoderReranking.
    const absoluteMinimum = 0.1;

    const threshold = Math.max(
      absoluteMinimum,
      minScore + scoreRange * relativeThreshold,
    );

    const filteredResults = scored
      .filter((item) => item.score >= threshold)
      .map((item) => item.result);

    const limitedResults = applyLimit(filteredResults);

    logger.info(
      `Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`,
    );
    logger.info(
      `Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`,
    );

    return limitedResults;
  }

  /**
   * Choose and apply a reranking strategy for the combined search results.
   *
   * When at most one source returned anything (including the case where all
   * three are empty) the cross-encoder reranker is used; otherwise the results
   * are fused with weighted RRF.
   *
   * @param query The original search query, passed to the cross-encoder
   * @param results Raw results grouped by the search method that produced them
   * @param options Fully-defaulted search options (currently unused here)
   * @returns The reranked statements
   */
  private async rerankResults(
    query: string,
    results: {
      bm25: StatementNode[];
      vector: StatementNode[];
      bfs: StatementNode[];
    },
    options: Required<SearchOptions>,
  ): Promise<StatementNode[]> {
    // Count non-empty result sources
    const nonEmptySources = [results.bm25, results.vector, results.bfs].filter(
      (source) => source.length > 0,
    ).length;

    // RRF needs several rankings to fuse; with zero or one there is nothing to fuse.
    if (nonEmptySources <= 1) {
      logger.info(
        "Only one source has results, falling back to cross-encoder reranking",
      );
      return applyCrossEncoderReranking(query, results);
    }

    // Otherwise use weighted RRF for multiple sources
    return applyWeightedRRF(results);
  }
}

/**
 * Options accepted by {@link SearchService.search}.
 *
 * Every field is optional; the defaults below are the ones `search` applies.
 * Apart from `limit` and `scoreThreshold`, the fields are only forwarded to the
 * search helpers, so their exact effect is defined there.
 */
export interface SearchOptions {
  /** Maximum number of results to return. Default 10; 0 also yields the default. */
  limit?: number;
  /** Forwarded to the search helpers (presumably the BFS depth bound). Default 4; 0 also yields the default. */
  maxBfsDepth?: number;
  /** Forwarded to the search helpers (presumably the point in time statements must be valid at). Default: now. */
  validAt?: Date;
  /** Forwarded to the search helpers. Default false. */
  includeInvalidated?: boolean;
  /** Forwarded to the search helpers. Default: empty list. */
  entityTypes?: string[];
  /** Forwarded to the search helpers. Default: empty list. */
  predicateTypes?: string[];
  /**
   * Fraction of the min-to-max score range a result must reach to be kept by
   * the adaptive filter. Default 0.7; an explicit 0 is honoured.
   */
  scoreThreshold?: number;
  /** Default 10. Forwarded to the search helpers; not used by the adaptive filter in this file. */
  minResults?: number;
}