mirror of
https://github.com/eliasstepanik/core.git
synced 2026-01-10 08:48:29 +00:00
119 lines
3.3 KiB
TypeScript
119 lines
3.3 KiB
TypeScript
import { LLMModelEnum, type StatementNode } from "@core/types";
|
|
import { combineAndDeduplicateStatements } from "./utils";
|
|
import { type CoreMessage } from "ai";
|
|
import { makeModelCall } from "~/lib/model.server";
|
|
import { logger } from "../logger.service";
|
|
|
|
/**
|
|
* Apply Weighted Reciprocal Rank Fusion to combine results
|
|
*/
|
|
export function applyWeightedRRF(results: {
|
|
bm25: StatementNode[];
|
|
vector: StatementNode[];
|
|
bfs: StatementNode[];
|
|
}): StatementNode[] {
|
|
// Determine weights based on query characteristics
|
|
const weights = {
|
|
bm25: 1.0,
|
|
vector: 0.8,
|
|
bfs: 0.5,
|
|
};
|
|
const k = 60; // RRF constant
|
|
|
|
// Map to store combined scores
|
|
const scores: Record<string, { score: number; statement: StatementNode }> =
|
|
{};
|
|
|
|
// Process BM25 results with their weight
|
|
results.bm25.forEach((statement, rank) => {
|
|
const uuid = statement.uuid;
|
|
scores[uuid] = scores[uuid] || { score: 0, statement };
|
|
scores[uuid].score += weights.bm25 * (1 / (rank + k));
|
|
});
|
|
|
|
// Process vector similarity results with their weight
|
|
results.vector.forEach((statement, rank) => {
|
|
const uuid = statement.uuid;
|
|
scores[uuid] = scores[uuid] || { score: 0, statement };
|
|
scores[uuid].score += weights.vector * (1 / (rank + k));
|
|
});
|
|
|
|
// Process BFS traversal results with their weight
|
|
results.bfs.forEach((statement, rank) => {
|
|
const uuid = statement.uuid;
|
|
scores[uuid] = scores[uuid] || { score: 0, statement };
|
|
scores[uuid].score += weights.bfs * (1 / (rank + k));
|
|
});
|
|
|
|
// Convert to array and sort by final score
|
|
const sortedResults = Object.values(scores)
|
|
.sort((a, b) => b.score - a.score)
|
|
.map((item) => {
|
|
// Add the RRF score to the statement for debugging
|
|
return {
|
|
...item.statement,
|
|
rrfScore: item.score,
|
|
};
|
|
});
|
|
|
|
return sortedResults;
|
|
}
|
|
|
|
/**
|
|
* Apply Cross-Encoder reranking to results
|
|
* This is particularly useful when results come from a single source
|
|
*/
|
|
export async function applyCrossEncoderReranking(
|
|
query: string,
|
|
results: {
|
|
bm25: StatementNode[];
|
|
vector: StatementNode[];
|
|
bfs: StatementNode[];
|
|
},
|
|
): Promise<StatementNode[]> {
|
|
// Combine all results
|
|
const allResults = [...results.bm25, ...results.vector, ...results.bfs];
|
|
|
|
// Deduplicate by UUID
|
|
const uniqueResults = combineAndDeduplicateStatements(allResults);
|
|
|
|
if (uniqueResults.length === 0) return [];
|
|
|
|
logger.info(`Cross-encoder reranking ${uniqueResults.length} statements`);
|
|
|
|
const finalStatements: StatementNode[] = [];
|
|
|
|
await Promise.all(
|
|
uniqueResults.map(async (statement) => {
|
|
const messages: CoreMessage[] = [
|
|
{
|
|
role: "system",
|
|
content: `You are an expert tasked with determining whether the statement is relevant to the query
|
|
Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.`,
|
|
},
|
|
{
|
|
role: "user",
|
|
content: `<QUERY>${query}</QUERY>\n<STATEMENT>${statement.fact}</STATEMENT>`,
|
|
},
|
|
];
|
|
|
|
let responseText = "";
|
|
await makeModelCall(
|
|
false,
|
|
LLMModelEnum.GPT41NANO,
|
|
messages,
|
|
(text) => {
|
|
responseText = text;
|
|
},
|
|
{ temperature: 0, maxTokens: 1 },
|
|
);
|
|
|
|
if (responseText === "True") {
|
|
finalStatements.push(statement);
|
|
}
|
|
}),
|
|
);
|
|
|
|
return finalStatements;
|
|
}
|