import { type StatementNode } from "@core/types";
import { combineAndDeduplicateStatements } from "./utils";
import { type CoreMessage } from "ai";
import { makeModelCall } from "~/lib/model.server";
import { logger } from "../logger.service";
import { CohereClientV2 } from "cohere-ai";
import { env } from "~/env.server";

// Utility function to safely convert BigInt values to Number
function safeNumber(value: any): number {
  if (typeof value === "bigint") {
    return Number(value);
  }
  if (typeof value === "number") {
    return value;
  }
  if (typeof value === "string") {
    const parsed = parseFloat(value);
    return isNaN(parsed) ? 0 : parsed;
  }
  return 0;
}

/**
 * Apply Weighted Reciprocal Rank Fusion to combine results
 */
export function applyWeightedRRF(results: {
  bm25: StatementNode[];
  vector: StatementNode[];
  bfs: StatementNode[];
}): StatementNode[] {
  // Fixed weights per retrieval source (BM25 weighted highest)
  const weights = {
    bm25: 1.0,
    vector: 0.8,
    bfs: 0.5,
  };

  const k = 60; // RRF constant

  // Map to store combined scores
  const scores: Record<string, { score: number; statement: StatementNode }> =
    {};

  // Process BM25 results with their weight
  results.bm25.forEach((statement, rank) => {
    const uuid = statement.uuid;
    scores[uuid] = scores[uuid] || { score: 0, statement };
    scores[uuid].score += weights.bm25 * (1 / (rank + k));
  });

  // Process vector similarity results with their weight
  results.vector.forEach((statement, rank) => {
    const uuid = statement.uuid;
    scores[uuid] = scores[uuid] || { score: 0, statement };
    scores[uuid].score += weights.vector * (1 / (rank + k));
  });

  // Process BFS traversal results with their weight
  results.bfs.forEach((statement, rank) => {
    const uuid = statement.uuid;
    scores[uuid] = scores[uuid] || { score: 0, statement };
    scores[uuid].score += weights.bfs * (1 / (rank + k));
  });

  // Convert to array and sort by final score
  const sortedResults = Object.values(scores)
    .sort((a, b) => b.score - a.score)
    .map((item) => {
      // Add the RRF score to the statement for debugging
      return {
        ...item.statement,
        rrfScore: item.score,
      };
    });

  return sortedResults;
}
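/*
 * Worked example for applyWeightedRRF above (illustrative numbers, not taken from
 * real data): with k = 60 and the weights defined in the function, a statement
 * ranked first by BM25 (rank index 0) and third by vector search (rank index 2)
 * accumulates
 *   1.0 * 1 / (0 + 60) + 0.8 * 1 / (2 + 60) ≈ 0.0167 + 0.0129 ≈ 0.0296,
 * while a statement found only by BFS at rank index 0 scores 0.5 / 60 ≈ 0.0083.
 * Agreement across sources therefore dominates the fused ordering, and the large
 * k keeps differences between adjacent ranks small.
 */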
/**
 * Apply MMR (Maximal Marginal Relevance) reranking to reduce redundancy while maintaining relevance
 * MMR balances relevance and diversity to prevent redundant fact statements in results
 */
export function applyMMRReranking(
  statements: StatementNode[],
  lambda: number = 0.7, // Balance between relevance (1.0) and diversity (0.0)
  maxResults: number = 50,
): StatementNode[] {
  if (statements.length === 0) return [];

  // Extract relevance scores and embeddings
  const candidates = statements.map((statement) => {
    let relevanceScore = 0;

    // Use existing scores from MultiFactorReranking or other sources
    if ((statement as any).multifactorScore !== undefined) {
      relevanceScore = safeNumber((statement as any).multifactorScore);
    } else if ((statement as any).rrfScore !== undefined) {
      relevanceScore = safeNumber((statement as any).rrfScore);
    } else if ((statement as any).crossEncoderScore !== undefined) {
      relevanceScore = safeNumber((statement as any).crossEncoderScore);
    } else if ((statement as any).finalScore !== undefined) {
      relevanceScore = safeNumber((statement as any).finalScore);
    }

    return {
      statement,
      relevanceScore,
      embedding: statement.factEmbedding || [],
      selected: false,
    };
  });

  // Sort by relevance score (descending)
  candidates.sort((a, b) => b.relevanceScore - a.relevanceScore);

  const selectedCandidates: typeof candidates = [];
  const remainingCandidates = [...candidates];

  // Pre-filter candidates with no embeddings for faster processing
  const candidatesWithEmbeddings = remainingCandidates.filter(
    (c) => c.embedding.length > 0,
  );
  const candidatesWithoutEmbeddings = remainingCandidates.filter(
    (c) => c.embedding.length === 0,
  );

  // MMR Selection Algorithm with optimizations
  while (
    selectedCandidates.length < maxResults &&
    remainingCandidates.length > 0
  ) {
    let bestCandidate = null;
    let bestScore = -Infinity;
    let bestIndex = -1;

    // Early termination: if we have enough high-relevance items, stop diversity checking
    const relevanceThreshold =
      selectedCandidates.length > 0
        ? selectedCandidates[selectedCandidates.length - 1].relevanceScore * 0.5
        : 0;

    for (let i = 0; i < remainingCandidates.length; i++) {
      const candidate = remainingCandidates[i];

      // Skip similarity calculation for very low relevance items
      if (
        candidate.relevanceScore < relevanceThreshold &&
        selectedCandidates.length > 3
      ) {
        continue;
      }

      let maxSimilarityToSelected = 0;

      // Only calculate similarity if candidate has embedding and we have selected items
      if (selectedCandidates.length > 0 && candidate.embedding.length > 0) {
        // Optimization: only check similarity with most recent selected items (last 5)
        const recentSelected = selectedCandidates.slice(
          -Math.min(5, selectedCandidates.length),
        );

        for (const selected of recentSelected) {
          if (selected.embedding.length > 0) {
            const similarity = cosineSimilarity(
              candidate.embedding,
              selected.embedding,
            );
            maxSimilarityToSelected = Math.max(
              maxSimilarityToSelected,
              similarity,
            );

            // Early exit: if similarity is very high, no need to check more
            if (similarity > 0.95) break;
          }
        }
      }

      // MMR Score: λ * relevance - (1-λ) * max_similarity_to_selected
      const mmrScore =
        lambda * candidate.relevanceScore -
        (1 - lambda) * maxSimilarityToSelected;

      if (mmrScore > bestScore) {
        bestScore = mmrScore;
        bestCandidate = candidate;
        bestIndex = i;
      }
    }

    if (bestCandidate && bestIndex !== -1) {
      selectedCandidates.push(bestCandidate);
      remainingCandidates.splice(bestIndex, 1);
    } else {
      // No more candidates to select
      break;
    }
  }

  // Return selected statements with MMR scores
  return selectedCandidates.map((item, index) => ({
    ...item.statement,
    mmrScore: item.relevanceScore, // Keep original relevance score
    mmrRank: index + 1,
  }));
}

/**
 * Calculate cosine similarity between two vectors
 */
function cosineSimilarity(a: number[], b: number[]): number {
  if (a.length !== b.length || a.length === 0) return 0;

  let dotProduct = 0;
  let normA = 0;
  let normB = 0;

  for (let i = 0; i < a.length; i++) {
    dotProduct += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }

  if (normA === 0 || normB === 0) return 0;

  return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
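/*
 * Worked example for applyMMRReranking above (illustrative numbers only): with
 * lambda = 0.7, a candidate with relevance 0.9 whose embedding has cosine
 * similarity 0.8 to an already-selected statement scores
 *   0.7 * 0.9 - 0.3 * 0.8 = 0.63 - 0.24 = 0.39,
 * while a less relevant but more distinct candidate (relevance 0.8, max
 * similarity 0.1) scores 0.7 * 0.8 - 0.3 * 0.1 = 0.56 - 0.03 = 0.53 and is
 * selected first. Lowering lambda pushes the selection further toward diversity.
 */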
/**
 * Apply Cross-Encoder reranking to results
 * This is particularly useful when results come from a single source
 */
export async function applyCrossEncoderReranking(
  query: string,
  results: {
    bm25: StatementNode[];
    vector: StatementNode[];
    bfs: StatementNode[];
  },
): Promise<StatementNode[]> {
  // Combine all results
  const allResults = [...results.bm25, ...results.vector, ...results.bfs];

  // Deduplicate by UUID
  const uniqueResults = combineAndDeduplicateStatements(allResults);

  if (uniqueResults.length === 0) return [];

  logger.info(`Cross-encoder reranking ${uniqueResults.length} statements`);

  const finalStatements: StatementNode[] = [];

  await Promise.all(
    uniqueResults.map(async (statement) => {
      const messages: CoreMessage[] = [
        {
          role: "system",
          content: `You are an expert tasked with determining whether the statement is relevant to the query.
Respond with "True" if STATEMENT is relevant to QUERY and "False" otherwise.`,
        },
        {
          role: "user",
          content: `${query}\n${statement.fact}`,
        },
      ];

      let responseText = "";
      await makeModelCall(
        false,
        messages,
        (text) => {
          responseText = text;
        },
        { temperature: 0, maxTokens: 1 },
      );

      if (responseText === "True") {
        finalStatements.push(statement);
      }
    }),
  );

  return finalStatements;
}

/**
 * Apply combined MultiFactorReranking + MMR for optimal relevance and diversity
 * First applies MultiFactorReranking for authority/popularity/temporal scoring,
 * then applies MMR to reduce redundancy while maintaining relevance
 */
export function applyMultiFactorMMRReranking(
  results: {
    bm25: StatementNode[];
    vector: StatementNode[];
    bfs: StatementNode[];
  },
  options?: {
    lambda?: number; // MMR balance parameter (default: 0.7)
    maxResults?: number; // Maximum results to return (default: 50)
  },
): StatementNode[] {
  const { lambda = 0.7, maxResults = 50 } = options || {};

  // Step 1: Apply MultiFactorReranking to get relevance/authority/popularity scores
  const multiFactorResults = applyMultiFactorReranking(results);

  // Step 2: Apply MMR to reduce redundancy while maintaining relevance
  const mmrResults = applyMMRReranking(multiFactorResults, lambda, maxResults);

  // Add combined score for debugging
  return mmrResults.map((statement) => ({
    ...statement,
    combinedScore: safeNumber((statement as any).mmrScore), // MMR preserves MultiFactorScore
    rerankerUsed: "multifactor+mmr",
  }));
}
/**
 * Apply Multi-Factor Reranking combining semantic, structural, temporal, and provenance signals
 */
export function applyMultiFactorReranking(results: {
  bm25: StatementNode[];
  vector: StatementNode[];
  bfs: StatementNode[];
}): StatementNode[] {
  // Map to store combined scores and metadata
  const scores: Record<
    string,
    {
      score: number;
      statement: StatementNode;
      signals: { bm25: number; vector: number; bfs: number };
    }
  > = {};

  // Extract original scores when available (handle BigInt)
  const getOriginalScore = (statement: any) => {
    const rawScore =
      statement.similarity || statement.score || statement.bm25Score || 0;
    return safeNumber(rawScore);
  };

  // Process BM25 results - preserve original BM25 scores
  results.bm25.forEach((statement, rank) => {
    const uuid = statement.uuid;
    const originalScore = getOriginalScore(statement);
    const normalizedScore = Math.max(originalScore, 1 / (rank + 1)); // Rank fallback

    scores[uuid] = scores[uuid] || {
      score: 0,
      statement,
      signals: { bm25: 0, vector: 0, bfs: 0 },
    };
    scores[uuid].signals.bm25 = normalizedScore;
  });

  // Process vector similarity results - preserve semantic scores
  results.vector.forEach((statement, rank) => {
    const uuid = statement.uuid;
    const originalScore = getOriginalScore(statement);
    const normalizedScore = Math.max(originalScore, 1 / (rank + 1));

    scores[uuid] = scores[uuid] || {
      score: 0,
      statement,
      signals: { bm25: 0, vector: 0, bfs: 0 },
    };
    scores[uuid].signals.vector = normalizedScore;
  });

  // Process BFS traversal results - structural relevance
  results.bfs.forEach((statement, rank) => {
    const uuid = statement.uuid;
    const originalScore = getOriginalScore(statement);
    const normalizedScore = Math.max(originalScore, 1 / (rank + 1));

    scores[uuid] = scores[uuid] || {
      score: 0,
      statement,
      signals: { bm25: 0, vector: 0, bfs: 0 },
    };
    scores[uuid].signals.bfs = normalizedScore;
  });

  // Calculate final scores using adaptive weights
  Object.values(scores).forEach((item) => {
    const { bm25, vector, bfs } = item.signals;

    // Adaptive weights based on signal strength
    const totalSignals =
      (bm25 > 0 ? 1 : 0) + (vector > 0 ? 1 : 0) + (bfs > 0 ? 1 : 0);

    // Multi-signal bonus: statements appearing in multiple sources get higher weights
    const multiSignalBonus = totalSignals > 1 ? 1.2 : 1.0;

    // Dynamic weights: stronger for queries that benefit from each signal type
    const weights = {
      bm25: bm25 > 0 ? 1.0 : 0, // Keyword matching
      vector: vector > 0 ? 0.9 : 0, // Semantic similarity
      bfs: bfs > 0 ? 0.6 : 0, // Graph connectivity
    };

    // Temporal recency bonus (newer statements get slight boost)
    const createdAt = new Date(item.statement.createdAt).getTime();
    const now = Date.now();
    const daysSince = (now - createdAt) / (1000 * 60 * 60 * 24);
    const recencyBonus = Math.max(0.9, 1.0 - (daysSince / 365) * 0.1); // Max 10% decay over 1 year

    // Popularity bonus based on recall count (log-scaled to prevent dominance)
    const recallCount = safeNumber(item.statement.recallCount);
    const popularityBonus = 1.0 + Math.log(1 + recallCount) * 0.15; // Up to ~30% boost for frequently recalled facts

    // Provenance authority bonus based on multiple source episodes
    const provenanceCount = Math.max(
      1,
      safeNumber(item.statement.provenanceCount),
    );
    const authorityBonus = 1.0 + Math.log(provenanceCount) * 0.2; // Up to ~35% boost for multi-source facts

    // Final weighted score with all bonuses
    item.score =
      (weights.bm25 * bm25 + weights.vector * vector + weights.bfs * bfs) *
      multiSignalBonus *
      recencyBonus *
      popularityBonus *
      authorityBonus;
  });

  // Convert to array and sort by final score
  const sortedResults = Object.values(scores)
    .sort((a, b) => b.score - a.score)
    .map((item) => {
      // console.log(item.statement.fact, item.score);
      // Add the reranking score and signal breakdown for debugging
      return {
        ...item.statement,
        multifactorScore: item.score,
        signals: item.signals,
      };
    });

  return sortedResults;
}
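/*
 * Worked example for applyMultiFactorReranking above (illustrative numbers only):
 * a fresh statement appearing in both BM25 (signal 0.8) and vector results
 * (signal 0.7), recalled 10 times and backed by 3 source episodes, scores roughly
 *   base    = 1.0 * 0.8 + 0.9 * 0.7 = 1.43
 *   bonuses = 1.2 (multi-signal) * ~1.0 (recency)
 *             * (1 + ln(11) * 0.15 ≈ 1.36) * (1 + ln(3) * 0.2 ≈ 1.22)
 *   final   ≈ 1.43 * 1.2 * 1.36 * 1.22 ≈ 2.85,
 * so corroboration and repeated recall can roughly double a statement's base
 * retrieval score relative to a single-source, never-recalled fact.
 */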
/**
 * Apply LLM-based reranking for contextual understanding
 * Uses GPT-4o-mini to verify relevance with semantic reasoning
 */
export async function applyLLMReranking(
  query: string,
  results: {
    bm25: StatementNode[];
    vector: StatementNode[];
    bfs: StatementNode[];
  },
  limit: number = 10,
  userContext?: { name?: string; userId: string },
): Promise<StatementNode[]> {
  const allResults = [
    ...results.bm25.slice(0, 100),
    ...results.vector.slice(0, 100),
    ...results.bfs.slice(0, 100),
  ];
  const uniqueResults = combineAndDeduplicateStatements(allResults);

  logger.info(`Unique results: ${uniqueResults.length}`);

  if (uniqueResults.length === 0) {
    logger.info("No results to rerank with LLM");
    return [];
  }

  // Build user context section if provided
  const userContextSection = userContext?.name
    ? `\nUser Identity Context:
- The user's name is "${userContext.name}"
- References to "user", "${userContext.name}", or pronouns like "my/their" refer to the same person
- When matching queries about "user's X" or "${userContext.name}'s X", these are equivalent\n`
    : "";

  const prompt = `You are a relevance filter. Given a user query and a list of facts, identify ONLY the facts that are truly relevant to answering the query.
${userContextSection}
Query: "${query}"

Facts:
${uniqueResults.map((r, i) => `${i}. ${r.fact}`).join("\n")}

Instructions:
- A fact is RELEVANT if it directly answers or provides information needed to answer the query
- A fact is NOT RELEVANT if it's tangentially related but doesn't answer the query
- Consider semantic meaning, not just keyword matching
${userContext?.name ? `- Remember: "user", "${userContext.name}", and possessive references ("my", "their") all refer to the same person` : ""}
- Only return facts with HIGH relevance (≥80% confidence)
- If you are not sure, return an empty array

Output format: <output>[1, 5, 7]</output>

Return ONLY the numbers of highly relevant facts inside <output> tags as a JSON array:`;

  try {
    let responseText = "";
    await makeModelCall(
      false,
      [{ role: "user", content: prompt }],
      (text) => {
        responseText = text;
      },
      { temperature: 0 },
      "high",
    );

    // Extract array from <output>[1, 5, 7]</output>
    const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
    if (outputMatch && outputMatch[1]) {
      responseText = outputMatch[1].trim();

      const parsedResponse = JSON.parse(responseText || "[]");
      const extractedIndices = Array.isArray(parsedResponse)
        ? parsedResponse
        : parsedResponse.entities || [];

      if (extractedIndices.length === 0) {
        logger.warn(
          "LLM reranking returned no valid indices, falling back to original order",
        );
        return [];
      }

      logger.info(
        `LLM reranking selected ${extractedIndices.length} relevant facts`,
      );

      const selected = extractedIndices.map((i: number) => uniqueResults[i]);
      return selected;
    }

    return uniqueResults.slice(0, limit);
  } catch (error) {
    logger.error("LLM reranking failed, falling back to original order:", {
      error,
    });
    return uniqueResults.slice(0, limit);
  }
}
/**
 * Apply Cohere Rerank 3.5 to search results for improved question-to-fact matching
 * This is particularly effective for bridging the semantic gap between questions and factual statements
 */
export async function applyCohereReranking(
  query: string,
  results: {
    bm25: StatementNode[];
    vector: StatementNode[];
    bfs: StatementNode[];
  },
  options?: {
    limit?: number;
    model?: string;
    useLLMVerification?: boolean;
  },
): Promise<StatementNode[]> {
  const { model = "rerank-v3.5" } = options || {};
  const limit = 100;

  try {
    const startTime = Date.now();

    // Combine and deduplicate all results
    const allResults = [
      ...results.bm25.slice(0, 100),
      ...results.vector.slice(0, 100),
      ...results.bfs.slice(0, 100),
    ];
    const uniqueResults = combineAndDeduplicateStatements(allResults);

    console.log("Unique results:", uniqueResults.length);

    if (uniqueResults.length === 0) {
      logger.info("No results to rerank with Cohere");
      return [];
    }

    // Check for API key
    const apiKey = process.env.COHERE_API_KEY;
    if (!apiKey) {
      logger.warn(
        "COHERE_API_KEY not found, falling back to original results",
      );
      return uniqueResults.slice(0, limit);
    }

    // Initialize Cohere client
    const cohere = new CohereClientV2({
      token: apiKey,
    });

    // Prepare documents for Cohere API
    const documents = uniqueResults.map((statement) => statement.fact);

    console.log("Documents:", documents);

    logger.info(
      `Cohere reranking ${documents.length} statements with model ${model}`,
    );
    logger.info(`Cohere query: "${query}"`);
    logger.info(`First 5 documents: ${documents.slice(0, 5).join(" | ")}`);

    // Call Cohere Rerank API
    const response = await cohere.rerank({
      query,
      documents,
      model,
      topN: Math.min(limit, documents.length),
    });

    console.log("Cohere reranking billed units:", response.meta?.billedUnits);

    // Log top 5 Cohere results for debugging
    logger.info(
      `Cohere top 5 results:\n${response.results
        .slice(0, 5)
        .map(
          (r, i) =>
            `  ${i + 1}. [${r.relevanceScore.toFixed(4)}] ${documents[r.index].substring(0, 80)}...`,
        )
        .join("\n")}`,
    );

    // Map results back to StatementNodes with Cohere scores
    const rerankedResults = response.results.map((result, index) => ({
      ...uniqueResults[result.index],
      cohereScore: result.relevanceScore,
      cohereRank: index + 1,
    }));
    // .filter((result) => result.cohereScore >= Number(env.COHERE_SCORE_THRESHOLD));

    const responseTime = Date.now() - startTime;
    logger.info(
      `Cohere reranking completed: ${rerankedResults.length} results returned in ${responseTime}ms`,
    );

    return rerankedResults;
  } catch (error) {
    logger.error("Cohere reranking failed:", { error });

    // Graceful fallback to original results
    const allResults = [...results.bm25, ...results.vector, ...results.bfs];
    const uniqueResults = combineAndDeduplicateStatements(allResults);
    return uniqueResults.slice(0, limit);
  }
}
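/*
 * Usage sketch showing how the rerankers above might be composed in a recall
 * pipeline. The search helpers (bm25Search, vectorSearch, graphBfsSearch) are
 * hypothetical placeholders for whatever produces the per-source candidate
 * lists; only the reranker calls themselves come from this module.
 *
 *   const candidates = {
 *     bm25: await bm25Search(query),      // hypothetical keyword search
 *     vector: await vectorSearch(query),  // hypothetical embedding search
 *     bfs: await graphBfsSearch(query),   // hypothetical graph traversal
 *   };
 *
 *   // Cheap, local reranking: fuse signals, then trim redundancy with MMR.
 *   const fused = applyMultiFactorMMRReranking(candidates, {
 *     lambda: 0.7,
 *     maxResults: 50,
 *   });
 *
 *   // Or, when COHERE_API_KEY is configured, delegate to Cohere Rerank and
 *   // optionally tighten the result with the LLM relevance filter.
 *   const reranked = await applyCohereReranking(query, candidates);
 *   const verified = await applyLLMReranking(query, candidates, 10, {
 *     userId: "user-123", // hypothetical user context
 *     name: "Ada",
 *   });
 */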