533 lines
16 KiB
TypeScript

import { type StatementNode } from "@core/types";
import { combineAndDeduplicateStatements } from "./utils";
import { type CoreMessage } from "ai";
import { makeModelCall } from "~/lib/model.server";
import { logger } from "../logger.service";
import { CohereClientV2 } from "cohere-ai";
// Utility function to safely convert BigInt values to Number
function safeNumber(value: any): number {
if (typeof value === "bigint") {
return Number(value);
}
if (typeof value === "number") {
return value;
}
if (typeof value === "string") {
const parsed = parseFloat(value);
return isNaN(parsed) ? 0 : parsed;
}
return 0;
}
/**
* Apply Weighted Reciprocal Rank Fusion to combine results
*/
export function applyWeightedRRF(results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
}): StatementNode[] {
// Determine weights based on query characteristics
const weights = {
bm25: 1.0,
vector: 0.8,
bfs: 0.5,
};
const k = 60; // RRF constant
// Map to store combined scores
const scores: Record<string, { score: number; statement: StatementNode }> =
{};
// Process BM25 results with their weight
results.bm25.forEach((statement, rank) => {
const uuid = statement.uuid;
scores[uuid] = scores[uuid] || { score: 0, statement };
scores[uuid].score += weights.bm25 * (1 / (rank + k));
});
// Process vector similarity results with their weight
results.vector.forEach((statement, rank) => {
const uuid = statement.uuid;
scores[uuid] = scores[uuid] || { score: 0, statement };
scores[uuid].score += weights.vector * (1 / (rank + k));
});
// Process BFS traversal results with their weight
results.bfs.forEach((statement, rank) => {
const uuid = statement.uuid;
scores[uuid] = scores[uuid] || { score: 0, statement };
scores[uuid].score += weights.bfs * (1 / (rank + k));
});
// Convert to array and sort by final score
const sortedResults = Object.values(scores)
.sort((a, b) => b.score - a.score)
.map((item) => {
// Add the RRF score to the statement for debugging
return {
...item.statement,
rrfScore: item.score,
};
});
return sortedResults;
}
/**
* Apply MMR (Maximal Marginal Relevance) reranking to reduce redundancy while maintaining relevance
* MMR balances relevance and diversity to prevent redundant fact statements in results
*/
export function applyMMRReranking(
statements: StatementNode[],
lambda: number = 0.7, // Balance between relevance (1.0) and diversity (0.0)
maxResults: number = 50,
): StatementNode[] {
if (statements.length === 0) return [];
// Extract relevance scores and embeddings
const candidates = statements.map((statement) => {
let relevanceScore = 0;
// Use existing scores from MultiFactorReranking or other sources
if ((statement as any).multifactorScore !== undefined) {
relevanceScore = safeNumber((statement as any).multifactorScore);
} else if ((statement as any).rrfScore !== undefined) {
relevanceScore = safeNumber((statement as any).rrfScore);
} else if ((statement as any).crossEncoderScore !== undefined) {
relevanceScore = safeNumber((statement as any).crossEncoderScore);
} else if ((statement as any).finalScore !== undefined) {
relevanceScore = safeNumber((statement as any).finalScore);
}
return {
statement,
relevanceScore,
embedding: statement.factEmbedding || [],
selected: false,
};
});
// Sort by relevance score (descending)
candidates.sort((a, b) => b.relevanceScore - a.relevanceScore);
const selectedCandidates: typeof candidates = [];
const remainingCandidates = [...candidates];
// Pre-filter candidates with no embeddings for faster processing
const candidatesWithEmbeddings = remainingCandidates.filter(
(c) => c.embedding.length > 0,
);
const candidatesWithoutEmbeddings = remainingCandidates.filter(
(c) => c.embedding.length === 0,
);
// MMR Selection Algorithm with optimizations
while (
selectedCandidates.length < maxResults &&
remainingCandidates.length > 0
) {
let bestCandidate = null;
let bestScore = -Infinity;
let bestIndex = -1;
// Early termination: if we have enough high-relevance items, stop diversity checking
const relevanceThreshold =
selectedCandidates.length > 0
? selectedCandidates[selectedCandidates.length - 1].relevanceScore * 0.5
: 0;
for (let i = 0; i < remainingCandidates.length; i++) {
const candidate = remainingCandidates[i];
// Skip similarity calculation for very low relevance items
if (
candidate.relevanceScore < relevanceThreshold &&
selectedCandidates.length > 3
) {
continue;
}
let maxSimilarityToSelected = 0;
// Only calculate similarity if candidate has embedding and we have selected items
if (selectedCandidates.length > 0 && candidate.embedding.length > 0) {
// Optimization: only check similarity with most recent selected items (last 5)
const recentSelected = selectedCandidates.slice(
-Math.min(5, selectedCandidates.length),
);
for (const selected of recentSelected) {
if (selected.embedding.length > 0) {
const similarity = cosineSimilarity(
candidate.embedding,
selected.embedding,
);
maxSimilarityToSelected = Math.max(
maxSimilarityToSelected,
similarity,
);
// Early exit: if similarity is very high, no need to check more
if (similarity > 0.95) break;
}
}
}
// MMR Score: λ * relevance - (1-λ) * max_similarity_to_selected
const mmrScore =
lambda * candidate.relevanceScore -
(1 - lambda) * maxSimilarityToSelected;
if (mmrScore > bestScore) {
bestScore = mmrScore;
bestCandidate = candidate;
bestIndex = i;
}
}
if (bestCandidate && bestIndex !== -1) {
selectedCandidates.push(bestCandidate);
remainingCandidates.splice(bestIndex, 1);
} else {
// No more candidates to select
break;
}
}
// Return selected statements with MMR scores
return selectedCandidates.map((item, index) => ({
...item.statement,
mmrScore: item.relevanceScore, // Keep original relevance score
mmrRank: index + 1,
}));
}
/**
* Calculate cosine similarity between two vectors
*/
function cosineSimilarity(a: number[], b: number[]): number {
if (a.length !== b.length || a.length === 0) return 0;
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
if (normA === 0 || normB === 0) return 0;
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
/**
* Apply Cross-Encoder reranking to results
* This is particularly useful when results come from a single source
*/
export async function applyCrossEncoderReranking(
query: string,
results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
},
): Promise<StatementNode[]> {
// Combine all results
const allResults = [...results.bm25, ...results.vector, ...results.bfs];
// Deduplicate by UUID
const uniqueResults = combineAndDeduplicateStatements(allResults);
if (uniqueResults.length === 0) return [];
logger.info(`Cross-encoder reranking ${uniqueResults.length} statements`);
const finalStatements: StatementNode[] = [];
await Promise.all(
uniqueResults.map(async (statement) => {
const messages: CoreMessage[] = [
{
role: "system",
content: `You are an expert tasked with determining whether the statement is relevant to the query
Respond with "True" if STATEMENT is relevant to QUERY and "False" otherwise.`,
},
{
role: "user",
content: `<QUERY>${query}</QUERY>\n<STATEMENT>${statement.fact}</STATEMENT>`,
},
];
let responseText = "";
await makeModelCall(
false,
messages,
(text) => {
responseText = text;
},
{ temperature: 0, maxTokens: 1 },
);
if (responseText === "True") {
finalStatements.push(statement);
}
}),
);
return finalStatements;
}
/**
* Apply combined MultiFactorReranking + MMR for optimal relevance and diversity
* First applies MultiFactorReranking for authority/popularity/temporal scoring,
* then applies MMR to reduce redundancy while maintaining relevance
*/
export function applyMultiFactorMMRReranking(
results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
},
options?: {
lambda?: number; // MMR balance parameter (default: 0.7)
maxResults?: number; // Maximum results to return (default: 50)
},
): StatementNode[] {
const { lambda = 0.7, maxResults = 50 } = options || {};
// Step 1: Apply MultiFactorReranking to get relevance/authority/popularity scores
const multiFactorResults = applyMultiFactorReranking(results);
// Step 2: Apply MMR to reduce redundancy while maintaining relevance
const mmrResults = applyMMRReranking(multiFactorResults, lambda, maxResults);
// Add combined score for debugging
return mmrResults.map((statement) => ({
...statement,
combinedScore: safeNumber((statement as any).mmrScore), // MMR preserves MultiFactorScore
rerankerUsed: "multifactor+mmr",
}));
}
/**
* Apply Multi-Factor Reranking combining semantic, structural, temporal, and provenance signals
*/
export function applyMultiFactorReranking(results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
}): StatementNode[] {
// Map to store combined scores and metadata
const scores: Record<
string,
{
score: number;
statement: StatementNode;
signals: { bm25: number; vector: number; bfs: number };
}
> = {};
// Extract original scores when available (handle BigInt)
const getOriginalScore = (statement: any) => {
const rawScore =
statement.similarity || statement.score || statement.bm25Score || 0;
return safeNumber(rawScore);
};
// Process BM25 results - preserve original BM25 scores
results.bm25.forEach((statement, rank) => {
const uuid = statement.uuid;
const originalScore = getOriginalScore(statement);
const normalizedScore = Math.max(originalScore, 1 / (rank + 1)); // Rank fallback
scores[uuid] = scores[uuid] || {
score: 0,
statement,
signals: { bm25: 0, vector: 0, bfs: 0 },
};
scores[uuid].signals.bm25 = normalizedScore;
});
// Process vector similarity results - preserve semantic scores
results.vector.forEach((statement, rank) => {
const uuid = statement.uuid;
const originalScore = getOriginalScore(statement);
const normalizedScore = Math.max(originalScore, 1 / (rank + 1));
scores[uuid] = scores[uuid] || {
score: 0,
statement,
signals: { bm25: 0, vector: 0, bfs: 0 },
};
scores[uuid].signals.vector = normalizedScore;
});
// Process BFS traversal results - structural relevance
results.bfs.forEach((statement, rank) => {
const uuid = statement.uuid;
const originalScore = getOriginalScore(statement);
const normalizedScore = Math.max(originalScore, 1 / (rank + 1));
scores[uuid] = scores[uuid] || {
score: 0,
statement,
signals: { bm25: 0, vector: 0, bfs: 0 },
};
scores[uuid].signals.bfs = normalizedScore;
});
// Calculate final scores using adaptive weights
Object.values(scores).forEach((item) => {
const { bm25, vector, bfs } = item.signals;
// Adaptive weights based on signal strength
const totalSignals =
(bm25 > 0 ? 1 : 0) + (vector > 0 ? 1 : 0) + (bfs > 0 ? 1 : 0);
// Multi-signal bonus: statements appearing in multiple sources get higher weights
const multiSignalBonus = totalSignals > 1 ? 1.2 : 1.0;
// Dynamic weights: stronger for queries that benefit from each signal type
const weights = {
bm25: bm25 > 0 ? 1.0 : 0, // Keyword matching
vector: vector > 0 ? 0.9 : 0, // Semantic similarity
bfs: bfs > 0 ? 0.6 : 0, // Graph connectivity
};
// Temporal recency bonus (newer statements get slight boost)
const createdAt = new Date(item.statement.createdAt).getTime();
const now = Date.now();
const daysSince = (now - createdAt) / (1000 * 60 * 60 * 24);
const recencyBonus = Math.max(0.9, 1.0 - (daysSince / 365) * 0.1); // Max 10% decay over 1 year
// Popularity bonus based on recall count (log-scaled to prevent dominance)
const recallCount = safeNumber(item.statement.recallCount);
const popularityBonus = 1.0 + Math.log(1 + recallCount) * 0.15; // Up to ~30% boost for frequently recalled facts
// Provenance authority bonus based on multiple source episodes
const provenanceCount = Math.max(
1,
safeNumber(item.statement.provenanceCount),
);
const authorityBonus = 1.0 + Math.log(provenanceCount) * 0.2; // Up to ~35% boost for multi-source facts
// Final weighted score with all bonuses
item.score =
(weights.bm25 * bm25 + weights.vector * vector + weights.bfs * bfs) *
multiSignalBonus *
recencyBonus *
popularityBonus *
authorityBonus;
});
// Convert to array and sort by final score
const sortedResults = Object.values(scores)
.sort((a, b) => b.score - a.score)
.map((item) => {
// console.log(item.statement.fact, item.score);
// Add the reranking score and signal breakdown for debugging
return {
...item.statement,
multifactorScore: item.score,
signals: item.signals,
};
});
return sortedResults;
}
/**
* Apply Cohere Rerank 3.5 to search results for improved question-to-fact matching
* This is particularly effective for bridging the semantic gap between questions and factual statements
*/
export async function applyCohereReranking(
query: string,
results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
},
options?: {
limit?: number;
model?: string;
},
): Promise<StatementNode[]> {
const { model = "rerank-v3.5" } = options || {};
const limit = 100;
try {
const startTime = Date.now();
// Combine and deduplicate all results
const allResults = [
...results.bm25.slice(0, 100),
...results.vector.slice(0, 100),
...results.bfs.slice(0, 100),
];
const uniqueResults = combineAndDeduplicateStatements(allResults);
console.log("Unique results:", uniqueResults.length);
if (uniqueResults.length === 0) {
logger.info("No results to rerank with Cohere");
return [];
}
// Check for API key
const apiKey = process.env.COHERE_API_KEY;
if (!apiKey) {
logger.warn("COHERE_API_KEY not found, falling back to original results");
return uniqueResults.slice(0, limit);
}
// Initialize Cohere client
const cohere = new CohereClientV2({
token: apiKey,
});
// Prepare documents for Cohere API
const documents = uniqueResults.map((statement) => statement.fact);
logger.info(
`Cohere reranking ${documents.length} statements with model ${model}`,
);
// Call Cohere Rerank API
const response = await cohere.rerank({
query,
documents,
model,
topN: Math.min(limit, documents.length),
});
console.log("Cohere reranking billed units:", response.meta?.billedUnits);
// Map results back to StatementNodes with Cohere scores
const rerankedResults = response.results
.map((result, index) => ({
...uniqueResults[result.index],
cohereScore: result.relevanceScore,
cohereRank: index + 1,
}))
.filter((result) => result.cohereScore >= 0.1);
const responseTime = Date.now() - startTime;
logger.info(
`Cohere reranking completed: ${rerankedResults.length} results returned in ${responseTime}ms`,
);
return rerankedResults;
} catch (error) {
logger.error("Cohere reranking failed:", { error });
// Graceful fallback to original results
const allResults = [...results.bm25, ...results.vector, ...results.bfs];
const uniqueResults = combineAndDeduplicateStatements(allResults);
return uniqueResults.slice(0, limit);
}
}