refactor: implement hierarchical search ranking with episode graph and source tracking

Authored by Manoj on 2025-10-24 22:40:22 +05:30; committed by Harshith Mullapudi
parent f39c7cc6d0
commit 5b31c8ed62
3 changed files with 1138 additions and 185 deletions


@ -1,13 +1,16 @@
import type { EpisodicNode, StatementNode } from "@core/types";
import type { EntityNode, EpisodicNode, StatementNode } from "@core/types";
import { logger } from "./logger.service";
import { applyLLMReranking } from "./search/rerank";
import {
getEpisodesByStatements,
performBfsSearch,
performBM25Search,
performVectorSearch,
performEpisodeGraphSearch,
extractEntitiesFromQuery,
groupStatementsByEpisode,
getEpisodesByUuids,
type EpisodeGraphResult,
} from "./search/utils";
import { getEmbedding } from "~/lib/model.server";
import { getEmbedding, makeModelCall } from "~/lib/model.server";
import { prisma } from "~/db.server";
import { runQuery } from "~/lib/neo4j.server";
@ -63,35 +66,100 @@ export class SearchService {
spaceIds: options.spaceIds || [],
adaptiveFiltering: options.adaptiveFiltering || false,
structured: options.structured || false,
useLLMValidation: options.useLLMValidation ?? false,
qualityThreshold: options.qualityThreshold || 0.3,
maxEpisodesForLLM: options.maxEpisodesForLLM || 20,
};
// Enhance query with LLM to transform keyword soup into semantic query
const queryVector = await this.getEmbedding(query);
// 1. Run parallel search methods
const [bm25Results, vectorResults, bfsResults] = await Promise.all([
// Note: We still need to extract entities from graph for Episode Graph search
// The LLM entities are just strings, we need EntityNode objects from the graph
const entities = await extractEntitiesFromQuery(query, userId, []);
logger.info(`Extracted entities ${entities.map((e: EntityNode) => e.name).join(', ')}`);
// 1. Run parallel search methods (including episode graph search) using enhanced query
const [bm25Results, vectorResults, bfsResults, episodeGraphResults] = await Promise.all([
performBM25Search(query, userId, opts),
performVectorSearch(queryVector, userId, opts),
performBfsSearch(query, queryVector, userId, opts),
performBfsSearch(query, queryVector, userId, entities, opts),
performEpisodeGraphSearch(query, entities, queryVector, userId, opts),
]);
logger.info(
`Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`,
`Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}, EpisodeGraph: ${episodeGraphResults.length}`,
);
// 2. Apply reranking strategy
const rankedStatements = await this.rerankResults(
query,
userId,
{ bm25: bm25Results, vector: vectorResults, bfs: bfsResults },
opts,
// 2. TWO-STAGE RANKING PIPELINE: Quality-based filtering with hierarchical scoring
// Stage 1: Extract episodes with provenance tracking
const episodesWithProvenance = await this.extractEpisodesWithProvenance({
episodeGraph: episodeGraphResults,
bfs: bfsResults,
vector: vectorResults,
bm25: bm25Results,
});
logger.info(`Extracted ${episodesWithProvenance.length} unique episodes from all sources`);
// Stage 2: Rate episodes by source hierarchy (EpisodeGraph > BFS > Vector > BM25)
const ratedEpisodes = this.rateEpisodesBySource(episodesWithProvenance);
// Stage 3: Filter by quality (not by model capability)
const qualityThreshold = opts.qualityThreshold || QUALITY_THRESHOLDS.HIGH_QUALITY_EPISODE;
const qualityFilter = this.filterByQuality(ratedEpisodes, query, qualityThreshold);
// If no high-quality matches, return empty
if (qualityFilter.confidence < QUALITY_THRESHOLDS.NO_RESULT) {
logger.warn(`Low confidence (${qualityFilter.confidence.toFixed(2)}) for query: "${query}"`);
return opts.structured
? {
episodes: [],
facts: [],
}
: this.formatAsMarkdown([], []);
}
// Stage 4: Optional LLM validation for borderline confidence
let finalEpisodes = qualityFilter.episodes;
const useLLMValidation = opts.useLLMValidation || false;
if (
useLLMValidation &&
qualityFilter.confidence >= QUALITY_THRESHOLDS.UNCERTAIN_RESULT &&
qualityFilter.confidence < QUALITY_THRESHOLDS.CONFIDENT_RESULT
) {
logger.info(
`Borderline confidence (${qualityFilter.confidence.toFixed(2)}), using LLM validation`,
);
const maxEpisodesForLLM = opts.maxEpisodesForLLM || 20;
finalEpisodes = await this.validateEpisodesWithLLM(
query,
qualityFilter.episodes,
maxEpisodesForLLM,
);
if (finalEpisodes.length === 0) {
logger.info('LLM validation rejected all episodes, returning empty');
return opts.structured ? { episodes: [], facts: [] } : this.formatAsMarkdown([], []);
}
}
// Extract episodes and statements for response
const episodes = finalEpisodes.map((ep) => ep.episode);
const filteredResults = finalEpisodes.flatMap((ep) =>
ep.statements.map((s) => ({
statement: s.statement,
score: Number((ep.firstLevelScore || 0).toFixed(2)),
})),
);
// 3. Apply adaptive filtering based on score threshold and minimum count
const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts);
// 3. Return top results
const episodes = await getEpisodesByStatements(
filteredResults.map((item) => item.statement),
logger.info(
`Final results: ${episodes.length} episodes, ${filteredResults.length} statements, ` +
`confidence: ${qualityFilter.confidence.toFixed(2)}`,
);
// Log recall asynchronously (don't await to avoid blocking response)
@ -135,151 +203,6 @@ export class SearchService {
return this.formatAsMarkdown(unifiedEpisodes, factsData);
}
/**
* Apply adaptive filtering to ranked results
* Uses a minimum quality threshold to filter out low-quality results
*/
private applyAdaptiveFiltering(
results: StatementNode[],
options: Required<SearchOptions>,
): { statement: StatementNode; score: number }[] {
if (results.length === 0) return [];
let isRRF = false;
// Extract scores from results
const scoredResults = results.map((result) => {
// Find the score based on reranking strategy used
let score = 0;
if ((result as any).rrfScore !== undefined) {
score = (result as any).rrfScore;
isRRF = true;
} else if ((result as any).mmrScore !== undefined) {
score = (result as any).mmrScore;
} else if ((result as any).crossEncoderScore !== undefined) {
score = (result as any).crossEncoderScore;
} else if ((result as any).finalScore !== undefined) {
score = (result as any).finalScore;
} else if ((result as any).multifactorScore !== undefined) {
score = (result as any).multifactorScore;
} else if ((result as any).combinedScore !== undefined) {
score = (result as any).combinedScore;
} else if ((result as any).cohereScore !== undefined) {
score = (result as any).cohereScore;
}
return { statement: result, score };
});
if (!options.adaptiveFiltering || results.length <= 5) {
return scoredResults;
}
const hasScores = scoredResults.some((item) => item.score > 0);
// If no scores are available, return the original results
if (!hasScores) {
logger.info("No scores found in results, skipping adaptive filtering");
return options.limit > 0
? results
.slice(0, options.limit)
.map((item) => ({ statement: item, score: 0 }))
: results.map((item) => ({ statement: item, score: 0 }));
}
// Sort by score (descending)
scoredResults.sort((a, b) => b.score - a.score);
// Calculate statistics to identify low-quality results
const scores = scoredResults.map((item) => item.score);
const maxScore = Math.max(...scores);
const minScore = Math.min(...scores);
const scoreRange = maxScore - minScore;
let threshold = 0;
if (isRRF || scoreRange < 0.01) {
// For RRF scores, use a more lenient adaptive approach
// Calculate median score and use a dynamic threshold based on score distribution
const sortedScores = [...scores].sort((a, b) => b - a);
const medianIndex = Math.floor(sortedScores.length / 2);
const medianScore = sortedScores[medianIndex];
// Use the smaller of: 20% of max score or 50% of median score
// This is more lenient for broad queries while still filtering noise
const maxBasedThreshold = maxScore * 0.2;
const medianBasedThreshold = medianScore * 0.5;
threshold = Math.min(maxBasedThreshold, medianBasedThreshold);
// Ensure we keep at least minResults if available
const minResultsCount = Math.min(
options.minResults,
scoredResults.length,
);
if (scoredResults.length >= minResultsCount) {
const minResultsThreshold = scoredResults[minResultsCount - 1].score;
threshold = Math.min(threshold, minResultsThreshold);
}
} else {
// For normal score distributions, use the relative threshold approach
const relativeThreshold = options.scoreThreshold || 0.3;
const absoluteMinimum = 0.1;
threshold = Math.max(
absoluteMinimum,
minScore + scoreRange * relativeThreshold,
);
}
// Filter out low-quality results
const filteredResults = scoredResults
.filter((item) => item.score >= threshold)
.map((item) => ({ statement: item.statement, score: item.score }));
// Apply limit if specified
const limitedResults =
options.limit > 0
? filteredResults.slice(
0,
Math.min(filteredResults.length, options.limit),
)
: filteredResults;
logger.info(
`Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`,
);
logger.info(
`Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`,
);
return limitedResults;
}
/**
* Apply the selected reranking strategy to search results
*/
private async rerankResults(
query: string,
userId: string,
results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
},
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
// Fetch user profile for context
const user = await prisma.user.findUnique({
where: { id: userId },
select: { name: true, id: true },
});
const userContext = user
? { name: user.name ?? undefined, userId: user.id }
: undefined;
return applyLLMReranking(query, results, options.limit, userContext);
}
private async logRecallAsync(
query: string,
userId: string,
@ -545,6 +468,528 @@ export class SearchService {
return result;
}
/**
* Extract episodes with provenance tracking from all search sources
* Deduplicates episodes and tracks which statements came from which source
*/
private async extractEpisodesWithProvenance(sources: {
episodeGraph: EpisodeGraphResult[];
bfs: StatementNode[];
vector: StatementNode[];
bm25: StatementNode[];
}): Promise<EpisodeWithProvenance[]> {
const episodeMap = new Map<string, EpisodeWithProvenance>();
// Process Episode Graph results (already episode-grouped)
sources.episodeGraph.forEach((result) => {
const episodeId = result.episode.uuid;
if (!episodeMap.has(episodeId)) {
episodeMap.set(episodeId, {
episode: result.episode,
statements: [],
episodeGraphScore: result.score,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}
const ep = episodeMap.get(episodeId)!;
result.statements.forEach((statement) => {
ep.statements.push({
statement,
sources: {
episodeGraph: {
score: result.score,
entityMatches: result.metrics.entityMatchCount,
},
},
primarySource: 'episodeGraph',
});
ep.sourceBreakdown.fromEpisodeGraph++;
});
});
// Process BFS statements (need to group by episode)
const bfsStatementsByEpisode = await groupStatementsByEpisode(sources.bfs);
const bfsEpisodeIds = Array.from(bfsStatementsByEpisode.keys());
const bfsEpisodes = await getEpisodesByUuids(bfsEpisodeIds);
bfsStatementsByEpisode.forEach((statements, episodeId) => {
if (!episodeMap.has(episodeId)) {
const episode = bfsEpisodes.get(episodeId);
if (!episode) return;
episodeMap.set(episodeId, {
episode,
statements: [],
episodeGraphScore: 0,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}
const ep = episodeMap.get(episodeId)!;
statements.forEach((statement) => {
const hopDistance = (statement as any).bfsHopDistance || 4;
const bfsRelevance = (statement as any).bfsRelevance || 0;
// Check if this statement already exists (from episode graph)
const existing = ep.statements.find((s) => s.statement.uuid === statement.uuid);
if (existing) {
// Add BFS source to existing statement
existing.sources.bfs = { score: bfsRelevance, hopDistance, relevance: bfsRelevance };
} else {
// New statement from BFS
ep.statements.push({
statement,
sources: { bfs: { score: bfsRelevance, hopDistance, relevance: bfsRelevance } },
primarySource: 'bfs',
});
ep.sourceBreakdown.fromBFS++;
}
// Aggregate BFS score for episode with hop multiplier
const hopMultiplier =
hopDistance === 1 ? 2.0 : hopDistance === 2 ? 1.3 : hopDistance === 3 ? 1.0 : 0.8;
ep.bfsScore += bfsRelevance * hopMultiplier;
});
// Average BFS score
if (statements.length > 0) {
ep.bfsScore /= statements.length;
}
});
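// Worked example of the hop-weighted aggregation above (illustrative numbers, not
// from the source): three BFS statements at hops 1, 2, and 4 with relevances
// 0.8, 0.6, and 0.5 accumulate
//   0.8 * 2.0 + 0.6 * 1.3 + 0.5 * 0.8 = 1.6 + 0.78 + 0.40 = 2.78
// and averaging over the 3 statements gives bfsScore ≈ 0.93, so near-hop matches
// dominate the episode's BFS contribution.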
// Process Vector statements
const vectorStatementsByEpisode = await groupStatementsByEpisode(sources.vector);
const vectorEpisodeIds = Array.from(vectorStatementsByEpisode.keys());
const vectorEpisodes = await getEpisodesByUuids(vectorEpisodeIds);
vectorStatementsByEpisode.forEach((statements, episodeId) => {
if (!episodeMap.has(episodeId)) {
const episode = vectorEpisodes.get(episodeId);
if (!episode) return;
episodeMap.set(episodeId, {
episode,
statements: [],
episodeGraphScore: 0,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}
const ep = episodeMap.get(episodeId)!;
statements.forEach((statement) => {
const vectorScore = (statement as any).vectorScore || (statement as any).similarity || 0;
const existing = ep.statements.find((s) => s.statement.uuid === statement.uuid);
if (existing) {
existing.sources.vector = { score: vectorScore, similarity: vectorScore };
} else {
ep.statements.push({
statement,
sources: { vector: { score: vectorScore, similarity: vectorScore } },
primarySource: 'vector',
});
ep.sourceBreakdown.fromVector++;
}
ep.vectorScore += vectorScore;
});
if (statements.length > 0) {
ep.vectorScore /= statements.length;
}
});
// Process BM25 statements
const bm25StatementsByEpisode = await groupStatementsByEpisode(sources.bm25);
const bm25EpisodeIds = Array.from(bm25StatementsByEpisode.keys());
const bm25Episodes = await getEpisodesByUuids(bm25EpisodeIds);
bm25StatementsByEpisode.forEach((statements, episodeId) => {
if (!episodeMap.has(episodeId)) {
const episode = bm25Episodes.get(episodeId);
if (!episode) return;
episodeMap.set(episodeId, {
episode,
statements: [],
episodeGraphScore: 0,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}
const ep = episodeMap.get(episodeId)!;
statements.forEach((statement) => {
const bm25Score = (statement as any).bm25Score || (statement as any).score || 0;
const existing = ep.statements.find((s) => s.statement.uuid === statement.uuid);
if (existing) {
existing.sources.bm25 = { score: bm25Score, rank: statements.indexOf(statement) };
} else {
ep.statements.push({
statement,
sources: { bm25: { score: bm25Score, rank: statements.indexOf(statement) } },
primarySource: 'bm25',
});
ep.sourceBreakdown.fromBM25++;
}
ep.bm25Score += bm25Score;
});
if (statements.length > 0) {
ep.bm25Score /= statements.length;
}
});
return Array.from(episodeMap.values());
}
/**
* Rate episodes by source hierarchy: Episode Graph > BFS > Vector > BM25
*/
private rateEpisodesBySource(episodes: EpisodeWithProvenance[]): EpisodeWithProvenance[] {
return episodes
.map((ep) => {
// Hierarchical scoring: EpisodeGraph > BFS > Vector > BM25
let firstLevelScore = 0;
// Episode Graph: Highest weight (5.0)
if (ep.episodeGraphScore > 0) {
firstLevelScore += ep.episodeGraphScore * 5.0;
}
// BFS: Second highest (3.0), already hop-weighted in extraction
if (ep.bfsScore > 0) {
firstLevelScore += ep.bfsScore * 3.0;
}
// Vector: Third (1.5)
if (ep.vectorScore > 0) {
firstLevelScore += ep.vectorScore * 1.5;
}
// BM25: Lowest (0.2), only significant if others missing
// Reduced from 0.5 to 0.2 to prevent keyword noise from dominating
if (ep.bm25Score > 0) {
firstLevelScore += ep.bm25Score * 0.2;
}
// Concentration bonus: More statements = higher confidence
const concentrationBonus = Math.log(1 + ep.statements.length) * 0.3;
firstLevelScore *= 1 + concentrationBonus;
return {
...ep,
firstLevelScore,
};
})
.sort((a, b) => (b.firstLevelScore || 0) - (a.firstLevelScore || 0));
}
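// Worked example of the hierarchical score above (illustrative numbers): an episode
// with episodeGraphScore = 1.2, bfsScore = 0.9, vectorScore = 0.7, bm25Score = 0.4
// and 4 statements gets
//   base  = 1.2 * 5.0 + 0.9 * 3.0 + 0.7 * 1.5 + 0.4 * 0.2 = 9.83
//   bonus = ln(1 + 4) * 0.3 ≈ 0.48
//   firstLevelScore ≈ 9.83 * 1.48 ≈ 14.6
// so graph-derived evidence dominates and statement concentration amplifies it.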
/**
* Filter episodes by quality, not by model capability
* Returns empty if no high-quality matches found
*/
private filterByQuality(
ratedEpisodes: EpisodeWithProvenance[],
query: string,
baseQualityThreshold: number = QUALITY_THRESHOLDS.HIGH_QUALITY_EPISODE,
): QualityFilterResult {
// Adaptive threshold based on available sources
// This prevents filtering out ALL results when only Vector/BM25 are available
const hasEpisodeGraph = ratedEpisodes.some((ep) => ep.episodeGraphScore > 0);
const hasBFS = ratedEpisodes.some((ep) => ep.bfsScore > 0);
const hasVector = ratedEpisodes.some((ep) => ep.vectorScore > 0);
const hasBM25 = ratedEpisodes.some((ep) => ep.bm25Score > 0);
let qualityThreshold: number;
if (hasEpisodeGraph || hasBFS) {
// Graph-based results available - use high threshold (5.0)
// Max possible score with Episode Graph: ~10+ (5.0 * 2.0)
// Max possible score with BFS: ~6+ (2.0 * 3.0)
qualityThreshold = 5.0;
} else if (hasVector) {
// Only semantic vector search - use medium threshold (1.0)
// Max possible score with Vector: ~1.5 (1.0 * 1.5)
qualityThreshold = 1.0;
} else if (hasBM25) {
// Only keyword BM25 - use low threshold (0.3)
// Max base score with BM25: ~0.2 (1.0 * 0.2); the concentration bonus can lift it past 0.3
qualityThreshold = 0.3;
} else {
// No results at all
logger.warn(`No results from any source for query: "${query}"`);
return {
episodes: [],
confidence: 0,
message: 'No relevant information found in memory',
};
}
logger.info(
`Adaptive quality threshold: ${qualityThreshold.toFixed(1)} ` +
`(EpisodeGraph: ${hasEpisodeGraph}, BFS: ${hasBFS}, Vector: ${hasVector}, BM25: ${hasBM25})`,
);
// 1. Filter to high-quality episodes only
const highQualityEpisodes = ratedEpisodes.filter(
(ep) => (ep.firstLevelScore || 0) >= qualityThreshold,
);
if (highQualityEpisodes.length === 0) {
logger.info(`No high-quality matches for query: "${query}" (threshold: ${qualityThreshold})`);
return {
episodes: [],
confidence: 0,
message: 'No relevant information found in memory',
};
}
// 2. Apply score gap detection to find natural cutoff
const scores = highQualityEpisodes.map((ep) => ep.firstLevelScore || 0);
const gapCutoff = this.findScoreGapForEpisodes(scores);
// 3. Take episodes up to the gap
const filteredEpisodes = highQualityEpisodes.slice(0, gapCutoff);
// 4. Calculate overall confidence with adaptive normalization
const confidence = this.calculateConfidence(filteredEpisodes);
logger.info(
`Quality filtering: ${filteredEpisodes.length}/${ratedEpisodes.length} episodes kept, ` +
`confidence: ${confidence.toFixed(2)}`,
);
return {
episodes: filteredEpisodes,
confidence,
message: `Found ${filteredEpisodes.length} relevant episodes`,
};
}
/**
* Calculate confidence score with adaptive normalization
* Uses different max expected scores based on DOMINANT source (not just presence)
*
* IMPORTANT: BM25 is NEVER considered dominant - it's a fallback, not a quality signal.
* When only Vector+BM25 exist, Vector is dominant.
*/
private calculateConfidence(filteredEpisodes: EpisodeWithProvenance[]): number {
if (filteredEpisodes.length === 0) return 0;
const avgScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.firstLevelScore || 0), 0) /
filteredEpisodes.length;
// Calculate average contribution from each source (weighted)
const avgEpisodeGraphScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.episodeGraphScore || 0), 0) /
filteredEpisodes.length;
const avgBFSScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.bfsScore || 0), 0) /
filteredEpisodes.length;
const avgVectorScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.vectorScore || 0), 0) /
filteredEpisodes.length;
const avgBM25Score =
filteredEpisodes.reduce((sum, ep) => sum + (ep.bm25Score || 0), 0) /
filteredEpisodes.length;
// Determine which source is dominant (weighted contribution to final score)
// BM25 is EXCLUDED from dominant source detection - it's a fallback mechanism
const episodeGraphContribution = avgEpisodeGraphScore * 5.0;
const bfsContribution = avgBFSScore * 3.0;
const vectorContribution = avgVectorScore * 1.5;
const bm25Contribution = avgBM25Score * 0.2;
let maxExpectedScore: number;
let dominantSource: string;
if (
episodeGraphContribution > bfsContribution &&
episodeGraphContribution > vectorContribution
) {
// Episode Graph is dominant source
maxExpectedScore = 25; // Typical range: 10-30
dominantSource = 'EpisodeGraph';
} else if (bfsContribution > vectorContribution) {
// BFS is dominant source
maxExpectedScore = 15; // Typical range: 5-15
dominantSource = 'BFS';
} else if (vectorContribution > 0) {
// Vector is dominant source (even if BM25 contribution is higher)
maxExpectedScore = 3; // Typical range: 1-3
dominantSource = 'Vector';
} else {
// ONLY BM25 results (Vector=0, BFS=0, EpisodeGraph=0)
// This should be rare and indicates low-quality keyword-only matches
maxExpectedScore = 1; // Typical range: 0.3-1
dominantSource = 'BM25';
}
const confidence = Math.min(1.0, avgScore / maxExpectedScore);
logger.info(
`Confidence: avgScore=${avgScore.toFixed(2)}, maxExpected=${maxExpectedScore}, ` +
`confidence=${confidence.toFixed(2)}, dominantSource=${dominantSource} ` +
`(Contributions: EG=${episodeGraphContribution.toFixed(2)}, ` +
`BFS=${bfsContribution.toFixed(2)}, Vec=${vectorContribution.toFixed(2)}, ` +
`BM25=${bm25Contribution.toFixed(2)})`,
);
return confidence;
}
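// Worked example of the dominant-source normalization above (illustrative numbers):
// with avgEpisodeGraphScore = 1.0, avgBFSScore = 0.5, avgVectorScore = 0.6, the
// weighted contributions are 5.0, 1.5, and 0.9, so EpisodeGraph is dominant and
// maxExpectedScore = 25. An avgScore of 12.5 then yields
//   confidence = min(1.0, 12.5 / 25) = 0.5
// which sits in the borderline band (0.3–0.7) and triggers LLM validation when it
// is enabled.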
/**
* Find score gap in episode scores (similar to statement gap detection)
*/
private findScoreGapForEpisodes(scores: number[], minResults: number = 3): number {
if (scores.length <= minResults) {
return scores.length;
}
// Find largest relative gap after minResults
for (let i = minResults - 1; i < scores.length - 1; i++) {
const currentScore = scores[i];
const nextScore = scores[i + 1];
if (currentScore === 0) break;
const gap = currentScore - nextScore;
const relativeGap = gap / currentScore;
// If we find a cliff (>50% drop), cut there
if (relativeGap > QUALITY_THRESHOLDS.MINIMUM_GAP_RATIO) {
logger.info(
`Episode gap detected at position ${i}: ${currentScore.toFixed(3)} → ${nextScore.toFixed(3)} ` +
`(${(relativeGap * 100).toFixed(1)}% drop)`,
);
return i + 1; // Return count (index + 1)
}
}
logger.info(`No significant gap found in episode scores`);
// No significant gap found, return all
return scores.length;
}
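// Worked example of gap detection (illustrative scores): for [9.8, 9.1, 8.7, 3.9, 3.5]
// with minResults = 3, the scan starts at index 2; the drop 8.7 → 3.9 is a relative
// gap of (8.7 - 3.9) / 8.7 ≈ 0.55 > MINIMUM_GAP_RATIO (0.5), so the function returns
// 3 and only the first three episodes survive the cutoff.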
/**
* Validate episodes with LLM for borderline confidence cases
* Only used when confidence is between 0.3 and 0.7
*/
private async validateEpisodesWithLLM(
query: string,
episodes: EpisodeWithProvenance[],
maxEpisodes: number = 20,
): Promise<EpisodeWithProvenance[]> {
const candidatesForValidation = episodes.slice(0, maxEpisodes);
const prompt = `Given the user query, validate which episodes are truly relevant.
Query: "${query}"
Episodes (showing episode metadata and top statements):
${candidatesForValidation
.map(
(ep, i) => `
${i + 1}. Episode: ${ep.episode.content || 'Untitled'} (${new Date(ep.episode.createdAt).toLocaleDateString()})
First-level score: ${ep.firstLevelScore?.toFixed(2)}
Sources: ${ep.sourceBreakdown.fromEpisodeGraph} EpisodeGraph, ${ep.sourceBreakdown.fromBFS} BFS, ${ep.sourceBreakdown.fromVector} Vector, ${ep.sourceBreakdown.fromBM25} BM25
Total statements: ${ep.statements.length}
Top statements:
${ep.statements
.slice(0, 5)
.map((s, idx) => ` ${idx + 1}) ${s.statement.fact}`)
.join('\n')}
`,
)
.join('\n')}
Task: Validate which episodes DIRECTLY answer the query intent.
IMPORTANT RULES:
1. ONLY include episodes that contain information directly relevant to answering the query
2. If NONE of the episodes answer the query, return an empty array: []
3. Do NOT include episodes just because they share keywords with the query
4. Consider source quality: EpisodeGraph > BFS > Vector > BM25
Examples:
- Query "what is user name?" Only include episodes that explicitly state a user's name
- Query "user home address" Only include episodes with actual address information
- Query "random keywords" Return [] if no episodes match semantically
Output format:
<output>
{
"valid_episodes": [1, 3, 5]
}
</output>
If NO episodes are relevant to the query, return:
<output>
{
"valid_episodes": []
}
</output>`;
try {
let responseText = '';
await makeModelCall(
false,
[{ role: 'user', content: prompt }],
(text) => {
responseText = text;
},
{ temperature: 0.2, maxTokens: 500 },
'low',
);
// Parse LLM response
const outputMatch = /<output>([\s\S]*?)<\/output>/i.exec(responseText);
if (!outputMatch?.[1]) {
logger.warn('LLM validation returned no output, using all episodes');
return episodes;
}
const result = JSON.parse(outputMatch[1]);
const validIndices = result.valid_episodes || [];
if (validIndices.length === 0) {
logger.info('LLM validation: No episodes deemed relevant');
return [];
}
logger.info(`LLM validation: ${validIndices.length}/${candidatesForValidation.length} episodes validated`);
// Return validated episodes
return validIndices.map((idx: number) => candidatesForValidation[idx - 1]).filter(Boolean);
} catch (error) {
logger.error('LLM validation failed:', { error });
// Fallback: return original episodes
return episodes;
}
}
}
/**
@ -564,4 +1009,73 @@ export interface SearchOptions {
spaceIds?: string[]; // Filter results by specific spaces
adaptiveFiltering?: boolean;
structured?: boolean; // Return structured JSON instead of markdown (default: false)
useLLMValidation?: boolean; // Use LLM to validate episodes for borderline confidence cases (default: false)
qualityThreshold?: number; // Minimum episode score to be considered high-quality (default: 5.0)
maxEpisodesForLLM?: number; // Maximum episodes to send for LLM validation (default: 20)
}
/**
* Statement with source provenance tracking
*/
interface StatementWithSource {
statement: StatementNode;
sources: {
episodeGraph?: { score: number; entityMatches: number };
bfs?: { score: number; hopDistance: number; relevance: number };
vector?: { score: number; similarity: number };
bm25?: { score: number; rank: number };
};
primarySource: 'episodeGraph' | 'bfs' | 'vector' | 'bm25';
}
/**
* Episode with provenance tracking from multiple sources
*/
interface EpisodeWithProvenance {
episode: EpisodicNode;
statements: StatementWithSource[];
// Aggregated scores from each source
episodeGraphScore: number;
bfsScore: number;
vectorScore: number;
bm25Score: number;
// Source distribution
sourceBreakdown: {
fromEpisodeGraph: number;
fromBFS: number;
fromVector: number;
fromBM25: number;
};
// First-level rating score (hierarchical)
firstLevelScore?: number;
}
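// Shape sketch (hypothetical values): an episode whose first statement was surfaced
// by both the episode graph and BFS, and whose second came from vector search only.
//
//   {
//     episode,                                    // EpisodicNode
//     statements: [
//       { statement: stmtA,
//         sources: { episodeGraph: { score: 4.2, entityMatches: 2 },
//                    bfs: { score: 0.8, hopDistance: 1, relevance: 0.8 } },
//         primarySource: 'episodeGraph' },
//       { statement: stmtB,
//         sources: { vector: { score: 0.74, similarity: 0.74 } },
//         primarySource: 'vector' },
//     ],
//     episodeGraphScore: 4.2, bfsScore: 1.6, vectorScore: 0.74, bm25Score: 0,
//     sourceBreakdown: { fromEpisodeGraph: 1, fromBFS: 0, fromVector: 1, fromBM25: 0 },
//   }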
/**
* Quality filtering result
*/
interface QualityFilterResult {
episodes: EpisodeWithProvenance[];
confidence: number;
message: string;
}
/**
* Quality thresholds for filtering
*/
const QUALITY_THRESHOLDS = {
// Adaptive episode-level scoring (based on available sources)
HIGH_QUALITY_EPISODE: 5.0, // For Episode Graph or BFS results (max score ~10+)
MEDIUM_QUALITY_EPISODE: 1.0, // For Vector-only results (max score ~1.5)
LOW_QUALITY_EPISODE: 0.3, // For BM25-only results (max base score ~0.2 before the concentration bonus)
// Overall result confidence
CONFIDENT_RESULT: 0.7, // High confidence, skip LLM validation
UNCERTAIN_RESULT: 0.3, // Borderline, use LLM validation
NO_RESULT: 0.3, // Too low, return empty
// Score gap detection
MINIMUM_GAP_RATIO: 0.5, // 50% score drop = gap
};
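// Decision-flow sketch (hypothetical helper; the constants are the ones defined
// above): how search() routes a confidence value through the three bands.
function routeByConfidence(confidence: number, useLLMValidation: boolean) {
  if (confidence < QUALITY_THRESHOLDS.NO_RESULT) return "empty"; // < 0.3
  if (useLLMValidation && confidence < QUALITY_THRESHOLDS.CONFIDENT_RESULT) {
    return "llm-validate"; // borderline 0.3–0.7
  }
  return "return-as-is"; // >= 0.7, or validation disabled
}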


@ -51,6 +51,7 @@ export async function performBM25Search(
${spaceCondition}
OPTIONAL MATCH (episode:Episode)-[:HAS_PROVENANCE]->(s)
WITH s, score, count(episode) as provenanceCount
WHERE score >= 0.5
RETURN s, score, provenanceCount
ORDER BY score DESC
`;
@ -71,6 +72,12 @@ export async function performBM25Search(
typeof provenanceCountValue === "bigint"
? Number(provenanceCountValue)
: (provenanceCountValue?.toNumber?.() ?? provenanceCountValue ?? 0);
const scoreValue = record.get("score");
(statement as any).bm25Score =
typeof scoreValue === "number"
? scoreValue
: (scoreValue?.toNumber?.() ?? 0);
return statement;
});
} catch (error) {
@ -163,6 +170,14 @@ export async function performVectorSearch(
typeof provenanceCountValue === "bigint"
? Number(provenanceCountValue)
: (provenanceCountValue?.toNumber?.() ?? provenanceCountValue ?? 0);
// Preserve vector similarity score for empty result detection
const scoreValue = record.get("score");
(statement as any).vectorScore =
typeof scoreValue === "number"
? scoreValue
: (scoreValue?.toNumber?.() ?? 0);
return statement;
});
} catch (error) {
@ -179,12 +194,10 @@ export async function performBfsSearch(
query: string,
embedding: Embedding,
userId: string,
entities: EntityNode[],
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// 1. Extract potential entities from query using chunked embeddings
const entities = await extractEntitiesFromQuery(query, userId);
if (entities.length === 0) {
return [];
}
@ -224,7 +237,7 @@ async function bfsTraversal(
const RELEVANCE_THRESHOLD = 0.5;
const EXPLORATION_THRESHOLD = 0.3;
const allStatements = new Map<string, number>(); // uuid -> relevance
const allStatements = new Map<string, { relevance: number; hopDistance: number }>(); // uuid -> {relevance, hopDistance}
const visitedEntities = new Set<string>();
// Track entities per level for iterative BFS
@ -268,14 +281,14 @@ async function bfsTraversal(
...(startTime && { startTime: startTime.toISOString() }),
});
// Store statement relevance scores
// Store statement relevance scores and hop distance
const currentLevelStatementUuids: string[] = [];
for (const record of records) {
const uuid = record.get("uuid");
const relevance = record.get("relevance");
if (!allStatements.has(uuid)) {
allStatements.set(uuid, relevance);
allStatements.set(uuid, { relevance, hopDistance: depth + 1 }); // Store hop distance (1-indexed)
currentLevelStatementUuids.push(uuid);
}
}
@ -304,25 +317,45 @@ async function bfsTraversal(
}
// Filter by relevance threshold and fetch full statements
const relevantUuids = Array.from(allStatements.entries())
.filter(([_, relevance]) => relevance >= RELEVANCE_THRESHOLD)
.sort((a, b) => b[1] - a[1])
.map(([uuid]) => uuid);
const relevantResults = Array.from(allStatements.entries())
.filter(([_, data]) => data.relevance >= RELEVANCE_THRESHOLD)
.sort((a, b) => b[1].relevance - a[1].relevance);
if (relevantUuids.length === 0) {
if (relevantResults.length === 0) {
return [];
}
const relevantUuids = relevantResults.map(([uuid]) => uuid);
const fetchCypher = `
MATCH (s:Statement{userId: $userId})
WHERE s.uuid IN $uuids
RETURN s
`;
const fetchRecords = await runQuery(fetchCypher, { uuids: relevantUuids, userId });
const statements = fetchRecords.map(r => r.get("s").properties as StatementNode);
const statementMap = new Map(
fetchRecords.map(r => [r.get("s").properties.uuid, r.get("s").properties as StatementNode])
);
// Attach hop distance to statements
const statements = relevantResults.map(([uuid, data]) => {
const statement = statementMap.get(uuid)!;
// Add bfsHopDistance and bfsRelevance as metadata
(statement as any).bfsHopDistance = data.hopDistance;
(statement as any).bfsRelevance = data.relevance;
return statement;
});
const hopCounts = statements.reduce((acc, s) => {
const hop = (s as any).bfsHopDistance;
acc[hop] = (acc[hop] || 0) + 1;
return acc;
}, {} as Record<number, number>);
logger.info(
`BFS: explored ${allStatements.size} statements across ${maxDepth} hops, returning ${statements.length} (≥${RELEVANCE_THRESHOLD})`
`BFS: explored ${allStatements.size} statements across ${maxDepth} hops, ` +
`returning ${statements.length} (≥${RELEVANCE_THRESHOLD}) - ` +
`1-hop: ${hopCounts[1] || 0}, 2-hop: ${hopCounts[2] || 0}, 3-hop: ${hopCounts[3] || 0}, 4-hop: ${hopCounts[4] || 0}`
);
return statements;
@ -361,15 +394,22 @@ function generateQueryChunks(query: string): string[] {
export async function extractEntitiesFromQuery(
query: string,
userId: string,
startEntities: string[] = [],
): Promise<EntityNode[]> {
try {
// Generate chunks from query
const chunks = generateQueryChunks(query);
// Get embeddings for each chunk
const chunkEmbeddings = await Promise.all(
chunks.map(chunk => getEmbedding(chunk))
);
let chunkEmbeddings: Embedding[] = [];
if (startEntities.length === 0) {
// Generate chunks from query
const chunks = generateQueryChunks(query);
// Get embeddings for each chunk
chunkEmbeddings = await Promise.all(
chunks.map(chunk => getEmbedding(chunk))
);
} else {
chunkEmbeddings = await Promise.all(
startEntities.map(chunk => getEmbedding(chunk))
);
}
// Search for entities matching each chunk embedding
const allEntitySets = await Promise.all(
@ -425,3 +465,280 @@ export async function getEpisodesByStatements(
const records = await runQuery(cypher, params);
return records.map((record) => record.get("e").properties as EpisodicNode);
}
/**
* Episode Graph Search Result
*/
export interface EpisodeGraphResult {
episode: EpisodicNode;
statements: StatementNode[];
score: number;
metrics: {
entityMatchCount: number;
totalStatementCount: number;
avgRelevance: number;
connectivityScore: number;
};
}
/**
* Perform episode-centric graph search
* Finds episodes with dense subgraphs of statements connected to query entities
*/
export async function performEpisodeGraphSearch(
query: string,
queryEntities: EntityNode[],
queryEmbedding: Embedding,
userId: string,
options: Required<SearchOptions>,
): Promise<EpisodeGraphResult[]> {
try {
// If no entities extracted, return empty
if (queryEntities.length === 0) {
logger.info("Episode graph search: no entities extracted from query");
return [];
}
const queryEntityIds = queryEntities.map(e => e.uuid);
logger.info(`Episode graph search: ${queryEntityIds.length} query entities`, {
entities: queryEntities.map(e => e.name).join(', ')
});
// Timeframe condition for temporal filtering
let timeframeCondition = `
AND s.validAt <= $validAt
${options.includeInvalidated ? '' : 'AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)'}
`;
if (options.startTime) {
timeframeCondition += ` AND s.validAt >= $startTime`;
}
// Space filtering if provided
let spaceCondition = "";
if (options.spaceIds.length > 0) {
spaceCondition = `
AND s.spaceIds IS NOT NULL AND ANY(spaceId IN $spaceIds WHERE spaceId IN s.spaceIds)
`;
}
const cypher = `
// Step 1: Find statements connected to query entities
MATCH (queryEntity:Entity)-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]-(s:Statement)
WHERE queryEntity.uuid IN $queryEntityIds
AND queryEntity.userId = $userId
AND s.userId = $userId
${timeframeCondition}
${spaceCondition}
// Step 2: Find episodes containing these statements
MATCH (s)<-[:HAS_PROVENANCE]-(ep:Episode)
// Step 3: Collect all statements from these episodes (for metrics only)
MATCH (ep)-[:HAS_PROVENANCE]->(epStatement:Statement)
WHERE epStatement.validAt <= $validAt
AND (epStatement.invalidAt IS NULL OR epStatement.invalidAt > $validAt)
${spaceCondition.replace(/s\./g, 'epStatement.')}
// Step 4: Calculate episode-level metrics
WITH ep,
collect(DISTINCT s) as entityMatchedStatements,
collect(DISTINCT epStatement) as allEpisodeStatements,
collect(DISTINCT queryEntity) as matchedEntities
// Step 5: Calculate semantic relevance for all episode statements
WITH ep,
entityMatchedStatements,
allEpisodeStatements,
matchedEntities,
[stmt IN allEpisodeStatements |
gds.similarity.cosine(stmt.factEmbedding, $queryEmbedding)
] as statementRelevances
// Step 6: Calculate aggregate scores
WITH ep,
entityMatchedStatements,
size(matchedEntities) as entityMatchCount,
size(entityMatchedStatements) as entityStmtCount,
size(allEpisodeStatements) as totalStmtCount,
reduce(sum = 0.0, score IN statementRelevances | sum + score) /
CASE WHEN size(statementRelevances) = 0 THEN 1 ELSE size(statementRelevances) END as avgRelevance
// Step 7: Calculate connectivity score
WITH ep,
entityMatchedStatements,
entityMatchCount,
entityStmtCount,
totalStmtCount,
avgRelevance,
(toFloat(entityStmtCount) / CASE WHEN totalStmtCount = 0 THEN 1 ELSE totalStmtCount END) *
entityMatchCount as connectivityScore
// Step 8: Filter for quality episodes
WHERE entityMatchCount >= 1
AND avgRelevance >= 0.5
AND totalStmtCount >= 1
// Step 9: Calculate final episode score
WITH ep,
entityMatchedStatements,
entityMatchCount,
totalStmtCount,
avgRelevance,
connectivityScore,
// Prioritize: entity matches (2.0x) + connectivity + semantic relevance
(entityMatchCount * 2.0) + connectivityScore + avgRelevance as episodeScore
// Step 10: Return ranked episodes with ONLY entity-matched statements
RETURN ep,
entityMatchedStatements as statements,
entityMatchCount,
totalStmtCount,
avgRelevance,
connectivityScore,
episodeScore
ORDER BY episodeScore DESC, entityMatchCount DESC, totalStmtCount DESC
LIMIT 20
`;
const params = {
queryEntityIds,
userId,
queryEmbedding,
validAt: options.endTime.toISOString(),
...(options.startTime && { startTime: options.startTime.toISOString() }),
...(options.spaceIds.length > 0 && { spaceIds: options.spaceIds }),
};
const records = await runQuery(cypher, params);
const results: EpisodeGraphResult[] = records.map((record) => {
const episode = record.get("ep").properties as EpisodicNode;
const statements = record.get("statements").map((s: any) => s.properties as StatementNode);
const entityMatchCount = typeof record.get("entityMatchCount") === 'bigint'
? Number(record.get("entityMatchCount"))
: record.get("entityMatchCount");
const totalStmtCount = typeof record.get("totalStmtCount") === 'bigint'
? Number(record.get("totalStmtCount"))
: record.get("totalStmtCount");
const avgRelevance = record.get("avgRelevance");
const connectivityScore = record.get("connectivityScore");
const episodeScore = record.get("episodeScore");
return {
episode,
statements,
score: episodeScore,
metrics: {
entityMatchCount,
totalStatementCount: totalStmtCount,
avgRelevance,
connectivityScore,
},
};
});
// Log statement counts for debugging
results.forEach((result, idx) => {
logger.info(
`Episode ${idx + 1}: entityMatches=${result.metrics.entityMatchCount}, ` +
`totalStmtCount=${result.metrics.totalStatementCount}, ` +
`returnedStatements=${result.statements.length}`
);
});
logger.info(
`Episode graph search: found ${results.length} episodes, ` +
`top score: ${results[0]?.score.toFixed(2) || 'N/A'}`
);
return results;
} catch (error) {
logger.error("Episode graph search error:", { error });
return [];
}
}
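// Worked example of the episodeScore formula in Step 9 (illustrative numbers): an
// episode where 2 query entities match, 3 of its 6 statements touch those entities,
// and average cosine relevance is 0.62 scores
//   connectivityScore = (3 / 6) * 2 = 1.0
//   episodeScore      = 2 * 2.0 + 1.0 + 0.62 = 5.62
// so entity coverage dominates, with connectivity and semantic relevance acting as
// tie-breakers (matching the ORDER BY above).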
/**
* Get episode IDs for statements in batch (efficient, no N+1 queries)
*/
export async function getEpisodeIdsForStatements(
statementUuids: string[]
): Promise<Map<string, string>> {
if (statementUuids.length === 0) {
return new Map();
}
const cypher = `
MATCH (s:Statement)<-[:HAS_PROVENANCE]-(e:Episode)
WHERE s.uuid IN $statementUuids
RETURN s.uuid as statementUuid, e.uuid as episodeUuid
`;
const records = await runQuery(cypher, { statementUuids });
const map = new Map<string, string>();
records.forEach(record => {
map.set(record.get('statementUuid'), record.get('episodeUuid'));
});
return map;
}
/**
* Group statements by their episode IDs efficiently
*/
export async function groupStatementsByEpisode(
statements: StatementNode[]
): Promise<Map<string, StatementNode[]>> {
const grouped = new Map<string, StatementNode[]>();
if (statements.length === 0) {
return grouped;
}
// Batch fetch episode IDs for all statements
const episodeIdMap = await getEpisodeIdsForStatements(
statements.map(s => s.uuid)
);
// Group statements by episode ID
statements.forEach((statement) => {
const episodeId = episodeIdMap.get(statement.uuid);
if (episodeId) {
if (!grouped.has(episodeId)) {
grouped.set(episodeId, []);
}
grouped.get(episodeId)!.push(statement);
}
});
return grouped;
}
/**
* Fetch episode objects by their UUIDs in batch
*/
export async function getEpisodesByUuids(
episodeUuids: string[]
): Promise<Map<string, EpisodicNode>> {
if (episodeUuids.length === 0) {
return new Map();
}
const cypher = `
MATCH (e:Episode)
WHERE e.uuid IN $episodeUuids
RETURN e
`;
const records = await runQuery(cypher, { episodeUuids });
const map = new Map<string, EpisodicNode>();
records.forEach(record => {
const episode = record.get('e').properties as EpisodicNode;
map.set(episode.uuid, episode);
});
return map;
}
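// Usage sketch (hypothetical caller) for the batch helpers above: group statements
// by episode, then hydrate the episodes with a single query, avoiding N+1 round trips.
async function collectEpisodes(statements: StatementNode[]) {
  const grouped = await groupStatementsByEpisode(statements);
  const episodes = await getEpisodesByUuids(Array.from(grouped.keys()));
  return Array.from(grouped.entries()).flatMap(([episodeId, stmts]) => {
    const episode = episodes.get(episodeId);
    return episode ? [{ episode, statements: stmts }] : [];
  });
}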


@ -0,0 +1,122 @@
import { task } from "@trigger.dev/sdk";
import { z } from "zod";
import { IngestionQueue, IngestionStatus } from "@core/database";
import { logger } from "~/services/logger.service";
import { prisma } from "../utils/prisma";
import { IngestBodyRequest, ingestTask } from "./ingest";
export const RetryNoCreditBodyRequest = z.object({
workspaceId: z.string(),
});
// Register the Trigger.dev task to retry NO_CREDITS episodes
export const retryNoCreditsTask = task({
id: "retry-no-credits-episodes",
run: async (payload: z.infer<typeof RetryNoCreditBodyRequest>) => {
try {
logger.log(
`Starting retry of NO_CREDITS episodes for workspace ${payload.workspaceId}`,
);
// Find all ingestion queue items with NO_CREDITS status for this workspace
const noCreditItems = await prisma.ingestionQueue.findMany({
where: {
workspaceId: payload.workspaceId,
status: IngestionStatus.NO_CREDITS,
},
orderBy: {
createdAt: "asc", // Process oldest first
},
include: {
workspace: true,
},
});
if (noCreditItems.length === 0) {
logger.log(
`No NO_CREDITS episodes found for workspace ${payload.workspaceId}`,
);
return {
success: true,
message: "No episodes to retry",
retriedCount: 0,
};
}
logger.log(
`Found ${noCreditItems.length} NO_CREDITS episodes to retry`,
);
const results = {
total: noCreditItems.length,
retriggered: 0,
failed: 0,
errors: [] as Array<{ queueId: string; error: string }>,
};
// Process each item
for (const item of noCreditItems) {
try {
const queueData = item.data as z.infer<typeof IngestBodyRequest>;
// Reset status to PENDING and clear error
await prisma.ingestionQueue.update({
where: { id: item.id },
data: {
status: IngestionStatus.PENDING,
error: null,
retryCount: item.retryCount + 1,
},
});
// Trigger the ingestion task
await ingestTask.trigger({
body: queueData,
userId: item.workspace?.userId as string,
workspaceId: payload.workspaceId,
queueId: item.id,
});
results.retriggered++;
logger.log(
`Successfully retriggered episode ${item.id} (retry #${item.retryCount + 1})`,
);
} catch (error: any) {
results.failed++;
results.errors.push({
queueId: item.id,
error: error.message,
});
logger.error(`Failed to retrigger episode ${item.id}:`, error);
// Update the item to mark it as failed
await prisma.ingestionQueue.update({
where: { id: item.id },
data: {
status: IngestionStatus.FAILED,
error: `Retry failed: ${error.message}`,
},
});
}
}
logger.log(
`Completed retry for workspace ${payload.workspaceId}. Retriggered: ${results.retriggered}, Failed: ${results.failed}`,
);
return {
success: true,
...results,
};
} catch (err: any) {
logger.error(
`Error retrying NO_CREDITS episodes for workspace ${payload.workspaceId}:`,
err,
);
return {
success: false,
error: err.message,
};
}
},
});
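// Usage sketch (hypothetical caller): trigger.dev tasks are invoked with .trigger(),
// as this file already does for ingestTask; an admin action could requeue a
// workspace after a credit top-up like so.
export async function requeueAfterTopUp(workspaceId: string) {
  return retryNoCreditsTask.trigger({ workspaceId });
}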