Mirror of https://github.com/eliasstepanik/core.git (synced 2026-01-10 23:48:26 +00:00)

refactor: implement hierarchical search ranking with episode graph and source tracking

This commit is contained in:
parent f39c7cc6d0
commit 5b31c8ed62
@@ -1,13 +1,16 @@
import type { EpisodicNode, StatementNode } from "@core/types";
import type { EntityNode, EpisodicNode, StatementNode } from "@core/types";
import { logger } from "./logger.service";
import { applyLLMReranking } from "./search/rerank";
import {
getEpisodesByStatements,
performBfsSearch,
performBM25Search,
performVectorSearch,
performEpisodeGraphSearch,
extractEntitiesFromQuery,
groupStatementsByEpisode,
getEpisodesByUuids,
type EpisodeGraphResult,
} from "./search/utils";
import { getEmbedding } from "~/lib/model.server";
import { getEmbedding, makeModelCall } from "~/lib/model.server";
import { prisma } from "~/db.server";
import { runQuery } from "~/lib/neo4j.server";
@@ -63,35 +66,100 @@ export class SearchService {
spaceIds: options.spaceIds || [],
adaptiveFiltering: options.adaptiveFiltering || false,
structured: options.structured || false,
useLLMValidation: options.useLLMValidation || true,
qualityThreshold: options.qualityThreshold || 0.3,
maxEpisodesForLLM: options.maxEpisodesForLLM || 20,
};

// Enhance query with LLM to transform keyword soup into semantic query

const queryVector = await this.getEmbedding(query);

// 1. Run parallel search methods
const [bm25Results, vectorResults, bfsResults] = await Promise.all([
// Note: We still need to extract entities from graph for Episode Graph search
// The LLM entities are just strings, we need EntityNode objects from the graph
const entities = await extractEntitiesFromQuery(query, userId, []);
logger.info(`Extracted entities ${entities.map((e: EntityNode) => e.name).join(', ')}`);

// 1. Run parallel search methods (including episode graph search) using enhanced query
const [bm25Results, vectorResults, bfsResults, episodeGraphResults] = await Promise.all([
performBM25Search(query, userId, opts),
performVectorSearch(queryVector, userId, opts),
performBfsSearch(query, queryVector, userId, opts),
performBfsSearch(query, queryVector, userId, entities, opts),
performEpisodeGraphSearch(query, entities, queryVector, userId, opts),
]);

logger.info(
`Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}`,
`Search results - BM25: ${bm25Results.length}, Vector: ${vectorResults.length}, BFS: ${bfsResults.length}, EpisodeGraph: ${episodeGraphResults.length}`,
);

// 2. Apply reranking strategy
const rankedStatements = await this.rerankResults(
query,
userId,
{ bm25: bm25Results, vector: vectorResults, bfs: bfsResults },
opts,
// 2. TWO-STAGE RANKING PIPELINE: Quality-based filtering with hierarchical scoring

// Stage 1: Extract episodes with provenance tracking
const episodesWithProvenance = await this.extractEpisodesWithProvenance({
episodeGraph: episodeGraphResults,
bfs: bfsResults,
vector: vectorResults,
bm25: bm25Results,
});

logger.info(`Extracted ${episodesWithProvenance.length} unique episodes from all sources`);

// Stage 2: Rate episodes by source hierarchy (EpisodeGraph > BFS > Vector > BM25)
const ratedEpisodes = this.rateEpisodesBySource(episodesWithProvenance);

// Stage 3: Filter by quality (not by model capability)
const qualityThreshold = opts.qualityThreshold || QUALITY_THRESHOLDS.HIGH_QUALITY_EPISODE;
const qualityFilter = this.filterByQuality(ratedEpisodes, query, qualityThreshold);

// If no high-quality matches, return empty
if (qualityFilter.confidence < QUALITY_THRESHOLDS.NO_RESULT) {
logger.warn(`Low confidence (${qualityFilter.confidence.toFixed(2)}) for query: "${query}"`);
return opts.structured
? {
episodes: [],
facts: [],
}
: this.formatAsMarkdown([], []);
}

// Stage 4: Optional LLM validation for borderline confidence
let finalEpisodes = qualityFilter.episodes;
const useLLMValidation = opts.useLLMValidation || false;

if (
useLLMValidation &&
qualityFilter.confidence >= QUALITY_THRESHOLDS.UNCERTAIN_RESULT &&
qualityFilter.confidence < QUALITY_THRESHOLDS.CONFIDENT_RESULT
) {
logger.info(
`Borderline confidence (${qualityFilter.confidence.toFixed(2)}), using LLM validation`,
);

const maxEpisodesForLLM = opts.maxEpisodesForLLM || 20;
finalEpisodes = await this.validateEpisodesWithLLM(
query,
qualityFilter.episodes,
maxEpisodesForLLM,
);

if (finalEpisodes.length === 0) {
logger.info('LLM validation rejected all episodes, returning empty');
return opts.structured ? { episodes: [], facts: [] } : this.formatAsMarkdown([], []);
}
}

// Extract episodes and statements for response
const episodes = finalEpisodes.map((ep) => ep.episode);
const filteredResults = finalEpisodes.flatMap((ep) =>
ep.statements.map((s) => ({
statement: s.statement,
score: Number((ep.firstLevelScore || 0).toFixed(2)),
})),
);

// 3. Apply adaptive filtering based on score threshold and minimum count
const filteredResults = this.applyAdaptiveFiltering(rankedStatements, opts);

// 3. Return top results
const episodes = await getEpisodesByStatements(
filteredResults.map((item) => item.statement),
logger.info(
`Final results: ${episodes.length} episodes, ${filteredResults.length} statements, ` +
`confidence: ${qualityFilter.confidence.toFixed(2)}`,
);

// Log recall asynchronously (don't await to avoid blocking response)
@@ -135,151 +203,6 @@ export class SearchService {
return this.formatAsMarkdown(unifiedEpisodes, factsData);
}

/**
* Apply adaptive filtering to ranked results
* Uses a minimum quality threshold to filter out low-quality results
*/
private applyAdaptiveFiltering(
results: StatementNode[],
options: Required<SearchOptions>,
): { statement: StatementNode; score: number }[] {
if (results.length === 0) return [];

let isRRF = false;
// Extract scores from results
const scoredResults = results.map((result) => {
// Find the score based on reranking strategy used
let score = 0;
if ((result as any).rrfScore !== undefined) {
score = (result as any).rrfScore;
isRRF = true;
} else if ((result as any).mmrScore !== undefined) {
score = (result as any).mmrScore;
} else if ((result as any).crossEncoderScore !== undefined) {
score = (result as any).crossEncoderScore;
} else if ((result as any).finalScore !== undefined) {
score = (result as any).finalScore;
} else if ((result as any).multifactorScore !== undefined) {
score = (result as any).multifactorScore;
} else if ((result as any).combinedScore !== undefined) {
score = (result as any).combinedScore;
} else if ((result as any).mmrScore !== undefined) {
score = (result as any).mmrScore;
} else if ((result as any).cohereScore !== undefined) {
score = (result as any).cohereScore;
}

return { statement: result, score };
});

if (!options.adaptiveFiltering || results.length <= 5) {
return scoredResults;
}

const hasScores = scoredResults.some((item) => item.score > 0);
// If no scores are available, return the original results
if (!hasScores) {
logger.info("No scores found in results, skipping adaptive filtering");
return options.limit > 0
? results
.slice(0, options.limit)
.map((item) => ({ statement: item, score: 0 }))
: results.map((item) => ({ statement: item, score: 0 }));
}

// Sort by score (descending)
scoredResults.sort((a, b) => b.score - a.score);

// Calculate statistics to identify low-quality results
const scores = scoredResults.map((item) => item.score);
const maxScore = Math.max(...scores);
const minScore = Math.min(...scores);
const scoreRange = maxScore - minScore;

let threshold = 0;
if (isRRF || scoreRange < 0.01) {
// For RRF scores, use a more lenient adaptive approach
// Calculate median score and use a dynamic threshold based on score distribution
const sortedScores = [...scores].sort((a, b) => b - a);
const medianIndex = Math.floor(sortedScores.length / 2);
const medianScore = sortedScores[medianIndex];

// Use the smaller of: 20% of max score or 50% of median score
// This is more lenient for broad queries while still filtering noise
const maxBasedThreshold = maxScore * 0.2;
const medianBasedThreshold = medianScore * 0.5;
threshold = Math.min(maxBasedThreshold, medianBasedThreshold);

// Ensure we keep at least minResults if available
const minResultsCount = Math.min(
options.minResults,
scoredResults.length,
);
if (scoredResults.length >= minResultsCount) {
const minResultsThreshold = scoredResults[minResultsCount - 1].score;
threshold = Math.min(threshold, minResultsThreshold);
}
} else {
// For normal score distributions, use the relative threshold approach
const relativeThreshold = options.scoreThreshold || 0.3;
const absoluteMinimum = 0.1;

threshold = Math.max(
absoluteMinimum,
minScore + scoreRange * relativeThreshold,
);
}

// Filter out low-quality results
const filteredResults = scoredResults
.filter((item) => item.score >= threshold)
.map((item) => ({ statement: item.statement, score: item.score }));

// Apply limit if specified
const limitedResults =
options.limit > 0
? filteredResults.slice(
0,
Math.min(filteredResults.length, options.limit),
)
: filteredResults;

logger.info(
`Quality filtering: ${limitedResults.length}/${results.length} results kept (threshold: ${threshold.toFixed(3)})`,
);
logger.info(
`Score range: min=${minScore.toFixed(3)}, max=${maxScore.toFixed(3)}, threshold=${threshold.toFixed(3)}`,
);

return limitedResults;
}

/**
* Apply the selected reranking strategy to search results
*/
private async rerankResults(
query: string,
userId: string,
results: {
bm25: StatementNode[];
vector: StatementNode[];
bfs: StatementNode[];
},
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
// Fetch user profile for context
const user = await prisma.user.findUnique({
where: { id: userId },
select: { name: true, id: true },
});

const userContext = user
? { name: user.name ?? undefined, userId: user.id }
: undefined;

return applyLLMReranking(query, results, options.limit, userContext);
}

private async logRecallAsync(
query: string,
userId: string,
@@ -545,6 +468,528 @@ export class SearchService {
return result;
}

/**
* Extract episodes with provenance tracking from all search sources
* Deduplicates episodes and tracks which statements came from which source
*/
private async extractEpisodesWithProvenance(sources: {
episodeGraph: EpisodeGraphResult[];
bfs: StatementNode[];
vector: StatementNode[];
bm25: StatementNode[];
}): Promise<EpisodeWithProvenance[]> {
const episodeMap = new Map<string, EpisodeWithProvenance>();

// Process Episode Graph results (already episode-grouped)
sources.episodeGraph.forEach((result) => {
const episodeId = result.episode.uuid;

if (!episodeMap.has(episodeId)) {
episodeMap.set(episodeId, {
episode: result.episode,
statements: [],
episodeGraphScore: result.score,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}

const ep = episodeMap.get(episodeId)!;
result.statements.forEach((statement) => {
ep.statements.push({
statement,
sources: {
episodeGraph: {
score: result.score,
entityMatches: result.metrics.entityMatchCount,
},
},
primarySource: 'episodeGraph',
});
ep.sourceBreakdown.fromEpisodeGraph++;
});
});

// Process BFS statements (need to group by episode)
const bfsStatementsByEpisode = await groupStatementsByEpisode(sources.bfs);
const bfsEpisodeIds = Array.from(bfsStatementsByEpisode.keys());
const bfsEpisodes = await getEpisodesByUuids(bfsEpisodeIds);

bfsStatementsByEpisode.forEach((statements, episodeId) => {
if (!episodeMap.has(episodeId)) {
const episode = bfsEpisodes.get(episodeId);
if (!episode) return;

episodeMap.set(episodeId, {
episode,
statements: [],
episodeGraphScore: 0,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}

const ep = episodeMap.get(episodeId)!;
statements.forEach((statement) => {
const hopDistance = (statement as any).bfsHopDistance || 4;
const bfsRelevance = (statement as any).bfsRelevance || 0;

// Check if this statement already exists (from episode graph)
const existing = ep.statements.find((s) => s.statement.uuid === statement.uuid);
if (existing) {
// Add BFS source to existing statement
existing.sources.bfs = { score: bfsRelevance, hopDistance, relevance: bfsRelevance };
} else {
// New statement from BFS
ep.statements.push({
statement,
sources: { bfs: { score: bfsRelevance, hopDistance, relevance: bfsRelevance } },
primarySource: 'bfs',
});
ep.sourceBreakdown.fromBFS++;
}

// Aggregate BFS score for episode with hop multiplier
const hopMultiplier =
hopDistance === 1 ? 2.0 : hopDistance === 2 ? 1.3 : hopDistance === 3 ? 1.0 : 0.8;
ep.bfsScore += bfsRelevance * hopMultiplier;
});

// Average BFS score
if (statements.length > 0) {
ep.bfsScore /= statements.length;
}
});

// Process Vector statements
const vectorStatementsByEpisode = await groupStatementsByEpisode(sources.vector);
const vectorEpisodeIds = Array.from(vectorStatementsByEpisode.keys());
const vectorEpisodes = await getEpisodesByUuids(vectorEpisodeIds);

vectorStatementsByEpisode.forEach((statements, episodeId) => {
if (!episodeMap.has(episodeId)) {
const episode = vectorEpisodes.get(episodeId);
if (!episode) return;

episodeMap.set(episodeId, {
episode,
statements: [],
episodeGraphScore: 0,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}

const ep = episodeMap.get(episodeId)!;
statements.forEach((statement) => {
const vectorScore = (statement as any).vectorScore || (statement as any).similarity || 0;

const existing = ep.statements.find((s) => s.statement.uuid === statement.uuid);
if (existing) {
existing.sources.vector = { score: vectorScore, similarity: vectorScore };
} else {
ep.statements.push({
statement,
sources: { vector: { score: vectorScore, similarity: vectorScore } },
primarySource: 'vector',
});
ep.sourceBreakdown.fromVector++;
}

ep.vectorScore += vectorScore;
});

if (statements.length > 0) {
ep.vectorScore /= statements.length;
}
});

// Process BM25 statements
const bm25StatementsByEpisode = await groupStatementsByEpisode(sources.bm25);
const bm25EpisodeIds = Array.from(bm25StatementsByEpisode.keys());
const bm25Episodes = await getEpisodesByUuids(bm25EpisodeIds);

bm25StatementsByEpisode.forEach((statements, episodeId) => {
if (!episodeMap.has(episodeId)) {
const episode = bm25Episodes.get(episodeId);
if (!episode) return;

episodeMap.set(episodeId, {
episode,
statements: [],
episodeGraphScore: 0,
bfsScore: 0,
vectorScore: 0,
bm25Score: 0,
sourceBreakdown: { fromEpisodeGraph: 0, fromBFS: 0, fromVector: 0, fromBM25: 0 },
});
}

const ep = episodeMap.get(episodeId)!;
statements.forEach((statement) => {
const bm25Score = (statement as any).bm25Score || (statement as any).score || 0;

const existing = ep.statements.find((s) => s.statement.uuid === statement.uuid);
if (existing) {
existing.sources.bm25 = { score: bm25Score, rank: statements.indexOf(statement) };
} else {
ep.statements.push({
statement,
sources: { bm25: { score: bm25Score, rank: statements.indexOf(statement) } },
primarySource: 'bm25',
});
ep.sourceBreakdown.fromBM25++;
}

ep.bm25Score += bm25Score;
});

if (statements.length > 0) {
ep.bm25Score /= statements.length;
}
});

return Array.from(episodeMap.values());
}

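// Worked example of the hop multiplier above (hypothetical numbers, not from this commit):
// a statement found 1 hop from a query entity with bfsRelevance 0.8 adds 0.8 * 2.0 = 1.6 to the
// episode's bfsScore, while the same relevance at 4 hops adds only 0.8 * 0.8 = 0.64 before averaging.
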
/**
* Rate episodes by source hierarchy: Episode Graph > BFS > Vector > BM25
*/
private rateEpisodesBySource(episodes: EpisodeWithProvenance[]): EpisodeWithProvenance[] {
return episodes
.map((ep) => {
// Hierarchical scoring: EpisodeGraph > BFS > Vector > BM25
let firstLevelScore = 0;

// Episode Graph: Highest weight (5.0)
if (ep.episodeGraphScore > 0) {
firstLevelScore += ep.episodeGraphScore * 5.0;
}

// BFS: Second highest (3.0), already hop-weighted in extraction
if (ep.bfsScore > 0) {
firstLevelScore += ep.bfsScore * 3.0;
}

// Vector: Third (1.5)
if (ep.vectorScore > 0) {
firstLevelScore += ep.vectorScore * 1.5;
}

// BM25: Lowest (0.2), only significant if others missing
// Reduced from 0.5 to 0.2 to prevent keyword noise from dominating
if (ep.bm25Score > 0) {
firstLevelScore += ep.bm25Score * 0.2;
}

// Concentration bonus: More statements = higher confidence
const concentrationBonus = Math.log(1 + ep.statements.length) * 0.3;
firstLevelScore *= 1 + concentrationBonus;

return {
...ep,
firstLevelScore,
};
})
.sort((a, b) => (b.firstLevelScore || 0) - (a.firstLevelScore || 0));
}

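// Rough illustration of how the hierarchical weights combine (hypothetical inputs, not from this commit):
// const ep = { episodeGraphScore: 1.2, bfsScore: 0.6, vectorScore: 0.7, bm25Score: 0.4, statements: new Array(4) };
// firstLevelScore = 1.2 * 5.0 + 0.6 * 3.0 + 0.7 * 1.5 + 0.4 * 0.2   // = 8.93
// firstLevelScore *= 1 + Math.log(1 + 4) * 0.3                      // ≈ 8.93 * 1.48 ≈ 13.2
// 13.2 comfortably clears the 5.0 graph-tier threshold used by filterByQuality below.
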
/**
* Filter episodes by quality, not by model capability
* Returns empty if no high-quality matches found
*/
private filterByQuality(
ratedEpisodes: EpisodeWithProvenance[],
query: string,
baseQualityThreshold: number = QUALITY_THRESHOLDS.HIGH_QUALITY_EPISODE,
): QualityFilterResult {
// Adaptive threshold based on available sources
// This prevents filtering out ALL results when only Vector/BM25 are available
const hasEpisodeGraph = ratedEpisodes.some((ep) => ep.episodeGraphScore > 0);
const hasBFS = ratedEpisodes.some((ep) => ep.bfsScore > 0);
const hasVector = ratedEpisodes.some((ep) => ep.vectorScore > 0);
const hasBM25 = ratedEpisodes.some((ep) => ep.bm25Score > 0);

let qualityThreshold: number;

if (hasEpisodeGraph || hasBFS) {
// Graph-based results available - use high threshold (5.0)
// Max possible score with Episode Graph: ~10+ (5.0 * 2.0)
// Max possible score with BFS: ~6+ (2.0 * 3.0)
qualityThreshold = 5.0;
} else if (hasVector) {
// Only semantic vector search - use medium threshold (1.0)
// Max possible score with Vector: ~1.5 (1.0 * 1.5)
qualityThreshold = 1.0;
} else if (hasBM25) {
// Only keyword BM25 - use low threshold (0.3)
// Max possible score with BM25: ~0.5 (1.0 * 0.5)
qualityThreshold = 0.3;
} else {
// No results at all
logger.warn(`No results from any source for query: "${query}"`);
return {
episodes: [],
confidence: 0,
message: 'No relevant information found in memory',
};
}

logger.info(
`Adaptive quality threshold: ${qualityThreshold.toFixed(1)} ` +
`(EpisodeGraph: ${hasEpisodeGraph}, BFS: ${hasBFS}, Vector: ${hasVector}, BM25: ${hasBM25})`,
);

// 1. Filter to high-quality episodes only
const highQualityEpisodes = ratedEpisodes.filter(
(ep) => (ep.firstLevelScore || 0) >= qualityThreshold,
);

if (highQualityEpisodes.length === 0) {
logger.info(`No high-quality matches for query: "${query}" (threshold: ${qualityThreshold})`);
return {
episodes: [],
confidence: 0,
message: 'No relevant information found in memory',
};
}

// 2. Apply score gap detection to find natural cutoff
const scores = highQualityEpisodes.map((ep) => ep.firstLevelScore || 0);
const gapCutoff = this.findScoreGapForEpisodes(scores);

// 3. Take episodes up to the gap
const filteredEpisodes = highQualityEpisodes.slice(0, gapCutoff);

// 4. Calculate overall confidence with adaptive normalization
const confidence = this.calculateConfidence(filteredEpisodes);

logger.info(
`Quality filtering: ${filteredEpisodes.length}/${ratedEpisodes.length} episodes kept, ` +
`confidence: ${confidence.toFixed(2)}`,
);

return {
episodes: filteredEpisodes,
confidence,
message: `Found ${filteredEpisodes.length} relevant episodes`,
};
}

/**
* Calculate confidence score with adaptive normalization
* Uses different max expected scores based on DOMINANT source (not just presence)
*
* IMPORTANT: BM25 is NEVER considered dominant - it's a fallback, not a quality signal.
* When only Vector+BM25 exist, Vector is dominant.
*/
private calculateConfidence(filteredEpisodes: EpisodeWithProvenance[]): number {
if (filteredEpisodes.length === 0) return 0;

const avgScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.firstLevelScore || 0), 0) /
filteredEpisodes.length;

// Calculate average contribution from each source (weighted)
const avgEpisodeGraphScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.episodeGraphScore || 0), 0) /
filteredEpisodes.length;

const avgBFSScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.bfsScore || 0), 0) /
filteredEpisodes.length;

const avgVectorScore =
filteredEpisodes.reduce((sum, ep) => sum + (ep.vectorScore || 0), 0) /
filteredEpisodes.length;

const avgBM25Score =
filteredEpisodes.reduce((sum, ep) => sum + (ep.bm25Score || 0), 0) /
filteredEpisodes.length;

// Determine which source is dominant (weighted contribution to final score)
// BM25 is EXCLUDED from dominant source detection - it's a fallback mechanism
const episodeGraphContribution = avgEpisodeGraphScore * 5.0;
const bfsContribution = avgBFSScore * 3.0;
const vectorContribution = avgVectorScore * 1.5;
const bm25Contribution = avgBM25Score * 0.2;

let maxExpectedScore: number;
let dominantSource: string;

if (
episodeGraphContribution > bfsContribution &&
episodeGraphContribution > vectorContribution
) {
// Episode Graph is dominant source
maxExpectedScore = 25; // Typical range: 10-30
dominantSource = 'EpisodeGraph';
} else if (bfsContribution > vectorContribution) {
// BFS is dominant source
maxExpectedScore = 15; // Typical range: 5-15
dominantSource = 'BFS';
} else if (vectorContribution > 0) {
// Vector is dominant source (even if BM25 contribution is higher)
maxExpectedScore = 3; // Typical range: 1-3
dominantSource = 'Vector';
} else {
// ONLY BM25 results (Vector=0, BFS=0, EpisodeGraph=0)
// This should be rare and indicates low-quality keyword-only matches
maxExpectedScore = 1; // Typical range: 0.3-1
dominantSource = 'BM25';
}

const confidence = Math.min(1.0, avgScore / maxExpectedScore);

logger.info(
`Confidence: avgScore=${avgScore.toFixed(2)}, maxExpected=${maxExpectedScore}, ` +
`confidence=${confidence.toFixed(2)}, dominantSource=${dominantSource} ` +
`(Contributions: EG=${episodeGraphContribution.toFixed(2)}, ` +
`BFS=${bfsContribution.toFixed(2)}, Vec=${vectorContribution.toFixed(2)}, ` +
`BM25=${bm25Contribution.toFixed(2)})`,
);

return confidence;
}

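// Continuing the hypothetical episode from the scoring sketch above (illustrative numbers only):
// contributions: EG = 1.2 * 5.0 = 6.0, BFS = 0.6 * 3.0 = 1.8, Vec = 0.7 * 1.5 = 1.05 → EpisodeGraph dominates,
// so maxExpectedScore = 25 and confidence = min(1.0, 13.2 / 25) ≈ 0.53 – borderline, which is exactly
// the [0.3, 0.7) band where validateEpisodesWithLLM is consulted.
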
/**
* Find score gap in episode scores (similar to statement gap detection)
*/
private findScoreGapForEpisodes(scores: number[], minResults: number = 3): number {
if (scores.length <= minResults) {
return scores.length;
}

// Find largest relative gap after minResults
for (let i = minResults - 1; i < scores.length - 1; i++) {
const currentScore = scores[i];
const nextScore = scores[i + 1];

if (currentScore === 0) break;

const gap = currentScore - nextScore;
const relativeGap = gap / currentScore;

// If we find a cliff (>50% drop), cut there
if (relativeGap > QUALITY_THRESHOLDS.MINIMUM_GAP_RATIO) {
logger.info(
`Episode gap detected at position ${i}: ${currentScore.toFixed(3)} → ${nextScore.toFixed(3)} ` +
`(${(relativeGap * 100).toFixed(1)}% drop)`,
);
return i + 1; // Return count (index + 1)
}
}

logger.info(`No significant gap found in episode scores`);

// No significant gap found, return all
return scores.length;
}

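// Example (hypothetical scores): [9.1, 8.7, 8.2, 3.0, 2.9] with minResults = 3:
// at i = 2 the drop is (8.2 - 3.0) / 8.2 ≈ 0.63 > MINIMUM_GAP_RATIO (0.5), so the method returns 3
// and only the top three episodes survive the cutoff.
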
/**
* Validate episodes with LLM for borderline confidence cases
* Only used when confidence is between 0.3 and 0.7
*/
private async validateEpisodesWithLLM(
query: string,
episodes: EpisodeWithProvenance[],
maxEpisodes: number = 20,
): Promise<EpisodeWithProvenance[]> {
const candidatesForValidation = episodes.slice(0, maxEpisodes);

const prompt = `Given user query, validate which episodes are truly relevant.

Query: "${query}"

Episodes (showing episode metadata and top statements):
${candidatesForValidation
.map(
(ep, i) => `
${i + 1}. Episode: ${ep.episode.content || 'Untitled'} (${new Date(ep.episode.createdAt).toLocaleDateString()})
First-level score: ${ep.firstLevelScore?.toFixed(2)}
Sources: ${ep.sourceBreakdown.fromEpisodeGraph} EpisodeGraph, ${ep.sourceBreakdown.fromBFS} BFS, ${ep.sourceBreakdown.fromVector} Vector, ${ep.sourceBreakdown.fromBM25} BM25
Total statements: ${ep.statements.length}

Top statements:
${ep.statements
.slice(0, 5)
.map((s, idx) => `  ${idx + 1}) ${s.statement.fact}`)
.join('\n')}
`,
)
.join('\n')}

Task: Validate which episodes DIRECTLY answer the query intent.

IMPORTANT RULES:
1. ONLY include episodes that contain information directly relevant to answering the query
2. If NONE of the episodes answer the query, return an empty array: []
3. Do NOT include episodes just because they share keywords with the query
4. Consider source quality: EpisodeGraph > BFS > Vector > BM25

Examples:
- Query "what is user name?" → Only include episodes that explicitly state a user's name
- Query "user home address" → Only include episodes with actual address information
- Query "random keywords" → Return [] if no episodes match semantically

Output format:
<output>
{
"valid_episodes": [1, 3, 5]
}
</output>

If NO episodes are relevant to the query, return:
<output>
{
"valid_episodes": []
}
</output>`;

try {
let responseText = '';
await makeModelCall(
false,
[{ role: 'user', content: prompt }],
(text) => {
responseText = text;
},
{ temperature: 0.2, maxTokens: 500 },
'low',
);

// Parse LLM response
const outputMatch = /<output>([\s\S]*?)<\/output>/i.exec(responseText);
if (!outputMatch?.[1]) {
logger.warn('LLM validation returned no output, using all episodes');
return episodes;
}

const result = JSON.parse(outputMatch[1]);
const validIndices = result.valid_episodes || [];

if (validIndices.length === 0) {
logger.info('LLM validation: No episodes deemed relevant');
return [];
}

logger.info(`LLM validation: ${validIndices.length}/${candidatesForValidation.length} episodes validated`);

// Return validated episodes
return validIndices.map((idx: number) => candidatesForValidation[idx - 1]).filter(Boolean);
} catch (error) {
logger.error('LLM validation failed:', { error });
// Fallback: return original episodes
return episodes;
}
}

}

/**
@@ -564,4 +1009,73 @@ export interface SearchOptions {
spaceIds?: string[]; // Filter results by specific spaces
adaptiveFiltering?: boolean;
structured?: boolean; // Return structured JSON instead of markdown (default: false)
useLLMValidation?: boolean; // Use LLM to validate episodes for borderline confidence cases (default: false)
qualityThreshold?: number; // Minimum episode score to be considered high-quality (default: 5.0)
maxEpisodesForLLM?: number; // Maximum episodes to send for LLM validation (default: 20)
}

/**
* Statement with source provenance tracking
*/
interface StatementWithSource {
statement: StatementNode;
sources: {
episodeGraph?: { score: number; entityMatches: number };
bfs?: { score: number; hopDistance: number; relevance: number };
vector?: { score: number; similarity: number };
bm25?: { score: number; rank: number };
};
primarySource: 'episodeGraph' | 'bfs' | 'vector' | 'bm25';
}

/**
* Episode with provenance tracking from multiple sources
*/
interface EpisodeWithProvenance {
episode: EpisodicNode;
statements: StatementWithSource[];

// Aggregated scores from each source
episodeGraphScore: number;
bfsScore: number;
vectorScore: number;
bm25Score: number;

// Source distribution
sourceBreakdown: {
fromEpisodeGraph: number;
fromBFS: number;
fromVector: number;
fromBM25: number;
};

// First-level rating score (hierarchical)
firstLevelScore?: number;
}

/**
* Quality filtering result
*/
interface QualityFilterResult {
episodes: EpisodeWithProvenance[];
confidence: number;
message: string;
}

/**
* Quality thresholds for filtering
*/
const QUALITY_THRESHOLDS = {
// Adaptive episode-level scoring (based on available sources)
HIGH_QUALITY_EPISODE: 5.0, // For Episode Graph or BFS results (max score ~10+)
MEDIUM_QUALITY_EPISODE: 1.0, // For Vector-only results (max score ~1.5)
LOW_QUALITY_EPISODE: 0.3, // For BM25-only results (max score ~0.5)

// Overall result confidence
CONFIDENT_RESULT: 0.7, // High confidence, skip LLM validation
UNCERTAIN_RESULT: 0.3, // Borderline, use LLM validation
NO_RESULT: 0.3, // Too low, return empty

// Score gap detection
MINIMUM_GAP_RATIO: 0.5, // 50% score drop = gap
};

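// Sketch of a call site using the new options. The exact search() entry point is not shown in this
// hunk, so the call below is an assumption, not part of the commit:
// const result = await searchService.search("where does the user live?", userId, {
//   structured: true,        // return { episodes, facts } instead of markdown
//   useLLMValidation: true,  // only consulted when confidence lands in [0.3, 0.7)
//   qualityThreshold: 5.0,   // episode-level cutoff for graph-backed results
//   maxEpisodesForLLM: 10,   // cap the validation payload
// });
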
@@ -51,6 +51,7 @@ export async function performBM25Search(
${spaceCondition}
OPTIONAL MATCH (episode:Episode)-[:HAS_PROVENANCE]->(s)
WITH s, score, count(episode) as provenanceCount
WHERE score >= 0.5
RETURN s, score, provenanceCount
ORDER BY score DESC
`;
@@ -71,6 +72,12 @@ export async function performBM25Search(
typeof provenanceCountValue === "bigint"
? Number(provenanceCountValue)
: (provenanceCountValue?.toNumber?.() ?? provenanceCountValue ?? 0);

const scoreValue = record.get("score");
(statement as any).bm25Score =
typeof scoreValue === "number"
? scoreValue
: (scoreValue?.toNumber?.() ?? 0);
return statement;
});
} catch (error) {
@@ -163,6 +170,14 @@ export async function performVectorSearch(
typeof provenanceCountValue === "bigint"
? Number(provenanceCountValue)
: (provenanceCountValue?.toNumber?.() ?? provenanceCountValue ?? 0);

// Preserve vector similarity score for empty result detection
const scoreValue = record.get("score");
(statement as any).vectorScore =
typeof scoreValue === "number"
? scoreValue
: (scoreValue?.toNumber?.() ?? 0);

return statement;
});
} catch (error) {
@@ -179,12 +194,10 @@ export async function performBfsSearch(
query: string,
embedding: Embedding,
userId: string,
entities: EntityNode[],
options: Required<SearchOptions>,
): Promise<StatementNode[]> {
try {
// 1. Extract potential entities from query using chunked embeddings
const entities = await extractEntitiesFromQuery(query, userId);

if (entities.length === 0) {
return [];
}
@@ -224,7 +237,7 @@ async function bfsTraversal(
const RELEVANCE_THRESHOLD = 0.5;
const EXPLORATION_THRESHOLD = 0.3;

const allStatements = new Map<string, number>(); // uuid -> relevance
const allStatements = new Map<string, { relevance: number; hopDistance: number }>(); // uuid -> {relevance, hopDistance}
const visitedEntities = new Set<string>();

// Track entities per level for iterative BFS
@@ -268,14 +281,14 @@ async function bfsTraversal(
...(startTime && { startTime: startTime.toISOString() }),
});

// Store statement relevance scores
// Store statement relevance scores and hop distance
const currentLevelStatementUuids: string[] = [];
for (const record of records) {
const uuid = record.get("uuid");
const relevance = record.get("relevance");

if (!allStatements.has(uuid)) {
allStatements.set(uuid, relevance);
allStatements.set(uuid, { relevance, hopDistance: depth + 1 }); // Store hop distance (1-indexed)
currentLevelStatementUuids.push(uuid);
}
}
@@ -304,25 +317,45 @@ async function bfsTraversal(
}

// Filter by relevance threshold and fetch full statements
const relevantUuids = Array.from(allStatements.entries())
.filter(([_, relevance]) => relevance >= RELEVANCE_THRESHOLD)
.sort((a, b) => b[1] - a[1])
.map(([uuid]) => uuid);
const relevantResults = Array.from(allStatements.entries())
.filter(([_, data]) => data.relevance >= RELEVANCE_THRESHOLD)
.sort((a, b) => b[1].relevance - a[1].relevance);

if (relevantUuids.length === 0) {
if (relevantResults.length === 0) {
return [];
}

const relevantUuids = relevantResults.map(([uuid]) => uuid);

const fetchCypher = `
MATCH (s:Statement{userId: $userId})
WHERE s.uuid IN $uuids
RETURN s
`;
const fetchRecords = await runQuery(fetchCypher, { uuids: relevantUuids, userId });
const statements = fetchRecords.map(r => r.get("s").properties as StatementNode);
const statementMap = new Map(
fetchRecords.map(r => [r.get("s").properties.uuid, r.get("s").properties as StatementNode])
);

// Attach hop distance to statements
const statements = relevantResults.map(([uuid, data]) => {
const statement = statementMap.get(uuid)!;
// Add bfsHopDistance and bfsRelevance as metadata
(statement as any).bfsHopDistance = data.hopDistance;
(statement as any).bfsRelevance = data.relevance;
return statement;
});

const hopCounts = statements.reduce((acc, s) => {
const hop = (s as any).bfsHopDistance;
acc[hop] = (acc[hop] || 0) + 1;
return acc;
}, {} as Record<number, number>);

logger.info(
`BFS: explored ${allStatements.size} statements across ${maxDepth} hops, returning ${statements.length} (≥${RELEVANCE_THRESHOLD})`
`BFS: explored ${allStatements.size} statements across ${maxDepth} hops, ` +
`returning ${statements.length} (≥${RELEVANCE_THRESHOLD}) - ` +
`1-hop: ${hopCounts[1] || 0}, 2-hop: ${hopCounts[2] || 0}, 3-hop: ${hopCounts[3] || 0}, 4-hop: ${hopCounts[4] || 0}`
);

return statements;
@@ -361,15 +394,22 @@ function generateQueryChunks(query: string): string[] {
export async function extractEntitiesFromQuery(
query: string,
userId: string,
startEntities: string[] = [],
): Promise<EntityNode[]> {
try {
// Generate chunks from query
const chunks = generateQueryChunks(query);

// Get embeddings for each chunk
const chunkEmbeddings = await Promise.all(
chunks.map(chunk => getEmbedding(chunk))
);
let chunkEmbeddings: Embedding[] = [];
if (startEntities.length === 0) {
// Generate chunks from query
const chunks = generateQueryChunks(query);
// Get embeddings for each chunk
chunkEmbeddings = await Promise.all(
chunks.map(chunk => getEmbedding(chunk))
);
} else {
chunkEmbeddings = await Promise.all(
startEntities.map(chunk => getEmbedding(chunk))
);
}

// Search for entities matching each chunk embedding
const allEntitySets = await Promise.all(
@@ -425,3 +465,280 @@ export async function getEpisodesByStatements(
const records = await runQuery(cypher, params);
return records.map((record) => record.get("e").properties as EpisodicNode);
}

/**
* Episode Graph Search Result
*/
export interface EpisodeGraphResult {
episode: EpisodicNode;
statements: StatementNode[];
score: number;
metrics: {
entityMatchCount: number;
totalStatementCount: number;
avgRelevance: number;
connectivityScore: number;
};
}

/**
* Perform episode-centric graph search
* Finds episodes with dense subgraphs of statements connected to query entities
*/
export async function performEpisodeGraphSearch(
query: string,
queryEntities: EntityNode[],
queryEmbedding: Embedding,
userId: string,
options: Required<SearchOptions>,
): Promise<EpisodeGraphResult[]> {
try {
// If no entities extracted, return empty
if (queryEntities.length === 0) {
logger.info("Episode graph search: no entities extracted from query");
return [];
}

const queryEntityIds = queryEntities.map(e => e.uuid);
logger.info(`Episode graph search: ${queryEntityIds.length} query entities`, {
entities: queryEntities.map(e => e.name).join(', ')
});

// Timeframe condition for temporal filtering
let timeframeCondition = `
AND s.validAt <= $validAt
${options.includeInvalidated ? '' : 'AND (s.invalidAt IS NULL OR s.invalidAt > $validAt)'}
`;
if (options.startTime) {
timeframeCondition += ` AND s.validAt >= $startTime`;
}

// Space filtering if provided
let spaceCondition = "";
if (options.spaceIds.length > 0) {
spaceCondition = `
AND s.spaceIds IS NOT NULL AND ANY(spaceId IN $spaceIds WHERE spaceId IN s.spaceIds)
`;
}

const cypher = `
// Step 1: Find statements connected to query entities
MATCH (queryEntity:Entity)-[:HAS_SUBJECT|HAS_OBJECT|HAS_PREDICATE]-(s:Statement)
WHERE queryEntity.uuid IN $queryEntityIds
AND queryEntity.userId = $userId
AND s.userId = $userId
${timeframeCondition}
${spaceCondition}

// Step 2: Find episodes containing these statements
MATCH (s)<-[:HAS_PROVENANCE]-(ep:Episode)

// Step 3: Collect all statements from these episodes (for metrics only)
MATCH (ep)-[:HAS_PROVENANCE]->(epStatement:Statement)
WHERE epStatement.validAt <= $validAt
AND (epStatement.invalidAt IS NULL OR epStatement.invalidAt > $validAt)
${spaceCondition.replace(/s\./g, 'epStatement.')}

// Step 4: Calculate episode-level metrics
WITH ep,
collect(DISTINCT s) as entityMatchedStatements,
collect(DISTINCT epStatement) as allEpisodeStatements,
collect(DISTINCT queryEntity) as matchedEntities

// Step 5: Calculate semantic relevance for all episode statements
WITH ep,
entityMatchedStatements,
allEpisodeStatements,
matchedEntities,
[stmt IN allEpisodeStatements |
gds.similarity.cosine(stmt.factEmbedding, $queryEmbedding)
] as statementRelevances

// Step 6: Calculate aggregate scores
WITH ep,
entityMatchedStatements,
size(matchedEntities) as entityMatchCount,
size(entityMatchedStatements) as entityStmtCount,
size(allEpisodeStatements) as totalStmtCount,
reduce(sum = 0.0, score IN statementRelevances | sum + score) /
CASE WHEN size(statementRelevances) = 0 THEN 1 ELSE size(statementRelevances) END as avgRelevance

// Step 7: Calculate connectivity score
WITH ep,
entityMatchedStatements,
entityMatchCount,
entityStmtCount,
totalStmtCount,
avgRelevance,
(toFloat(entityStmtCount) / CASE WHEN totalStmtCount = 0 THEN 1 ELSE totalStmtCount END) *
entityMatchCount as connectivityScore

// Step 8: Filter for quality episodes
WHERE entityMatchCount >= 1
AND avgRelevance >= 0.5
AND totalStmtCount >= 1

// Step 9: Calculate final episode score
WITH ep,
entityMatchedStatements,
entityMatchCount,
totalStmtCount,
avgRelevance,
connectivityScore,
// Prioritize: entity matches (2.0x) + connectivity + semantic relevance
(entityMatchCount * 2.0) + connectivityScore + avgRelevance as episodeScore

// Step 10: Return ranked episodes with ONLY entity-matched statements
RETURN ep,
entityMatchedStatements as statements,
entityMatchCount,
totalStmtCount,
avgRelevance,
connectivityScore,
episodeScore

ORDER BY episodeScore DESC, entityMatchCount DESC, totalStmtCount DESC
LIMIT 20
`;

const params = {
queryEntityIds,
userId,
queryEmbedding,
validAt: options.endTime.toISOString(),
...(options.startTime && { startTime: options.startTime.toISOString() }),
...(options.spaceIds.length > 0 && { spaceIds: options.spaceIds }),
};

const records = await runQuery(cypher, params);

const results: EpisodeGraphResult[] = records.map((record) => {
const episode = record.get("ep").properties as EpisodicNode;
const statements = record.get("statements").map((s: any) => s.properties as StatementNode);
const entityMatchCount = typeof record.get("entityMatchCount") === 'bigint'
? Number(record.get("entityMatchCount"))
: record.get("entityMatchCount");
const totalStmtCount = typeof record.get("totalStmtCount") === 'bigint'
? Number(record.get("totalStmtCount"))
: record.get("totalStmtCount");
const avgRelevance = record.get("avgRelevance");
const connectivityScore = record.get("connectivityScore");
const episodeScore = record.get("episodeScore");

return {
episode,
statements,
score: episodeScore,
metrics: {
entityMatchCount,
totalStatementCount: totalStmtCount,
avgRelevance,
connectivityScore,
},
};
});

// Log statement counts for debugging
results.forEach((result, idx) => {
logger.info(
`Episode ${idx + 1}: entityMatches=${result.metrics.entityMatchCount}, ` +
`totalStmtCount=${result.metrics.totalStatementCount}, ` +
`returnedStatements=${result.statements.length}`
);
});

logger.info(
`Episode graph search: found ${results.length} episodes, ` +
`top score: ${results[0]?.score.toFixed(2) || 'N/A'}`
);

return results;
} catch (error) {
logger.error("Episode graph search error:", { error });
return [];
}
}

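// Numeric illustration of Step 9 above (hypothetical values, not from this commit):
// entityMatchCount = 2, entityStmtCount = 6, totalStmtCount = 10, avgRelevance = 0.72
// connectivityScore = (6 / 10) * 2 = 1.2
// episodeScore = (2 * 2.0) + 1.2 + 0.72 = 5.92
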
/**
* Get episode IDs for statements in batch (efficient, no N+1 queries)
*/
export async function getEpisodeIdsForStatements(
statementUuids: string[]
): Promise<Map<string, string>> {
if (statementUuids.length === 0) {
return new Map();
}

const cypher = `
MATCH (s:Statement)<-[:HAS_PROVENANCE]-(e:Episode)
WHERE s.uuid IN $statementUuids
RETURN s.uuid as statementUuid, e.uuid as episodeUuid
`;

const records = await runQuery(cypher, { statementUuids });

const map = new Map<string, string>();
records.forEach(record => {
map.set(record.get('statementUuid'), record.get('episodeUuid'));
});

return map;
}

/**
* Group statements by their episode IDs efficiently
*/
export async function groupStatementsByEpisode(
statements: StatementNode[]
): Promise<Map<string, StatementNode[]>> {
const grouped = new Map<string, StatementNode[]>();

if (statements.length === 0) {
return grouped;
}

// Batch fetch episode IDs for all statements
const episodeIdMap = await getEpisodeIdsForStatements(
statements.map(s => s.uuid)
);

// Group statements by episode ID
statements.forEach((statement) => {
const episodeId = episodeIdMap.get(statement.uuid);
if (episodeId) {
if (!grouped.has(episodeId)) {
grouped.set(episodeId, []);
}
grouped.get(episodeId)!.push(statement);
}
});

return grouped;
}

/**
* Fetch episode objects by their UUIDs in batch
*/
export async function getEpisodesByUuids(
episodeUuids: string[]
): Promise<Map<string, EpisodicNode>> {
if (episodeUuids.length === 0) {
return new Map();
}

const cypher = `
MATCH (e:Episode)
WHERE e.uuid IN $episodeUuids
RETURN e
`;

const records = await runQuery(cypher, { episodeUuids });

const map = new Map<string, EpisodicNode>();
records.forEach(record => {
const episode = record.get('e').properties as EpisodicNode;
map.set(episode.uuid, episode);
});

return map;
}

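// Minimal sketch of how these batch helpers chain together, mirroring what
// extractEpisodesWithProvenance does in the search service (the `statements` input here is hypothetical):
// const byEpisode = await groupStatementsByEpisode(statements);
// const episodes = await getEpisodesByUuids(Array.from(byEpisode.keys()));
// byEpisode.forEach((stmts, episodeUuid) => {
//   const episode = episodes.get(episodeUuid);
//   if (!episode) return; // statement without a resolvable episode
//   // ...aggregate per-episode scores here
// });
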
apps/webapp/app/trigger/ingest/retry-no-credits.ts (new file, 122 lines)
@@ -0,0 +1,122 @@
import { task } from "@trigger.dev/sdk";
import { z } from "zod";
import { IngestionQueue, IngestionStatus } from "@core/database";
import { logger } from "~/services/logger.service";
import { prisma } from "../utils/prisma";
import { IngestBodyRequest, ingestTask } from "./ingest";

export const RetryNoCreditBodyRequest = z.object({
workspaceId: z.string(),
});

// Register the Trigger.dev task to retry NO_CREDITS episodes
export const retryNoCreditsTask = task({
id: "retry-no-credits-episodes",
run: async (payload: z.infer<typeof RetryNoCreditBodyRequest>) => {
try {
logger.log(
`Starting retry of NO_CREDITS episodes for workspace ${payload.workspaceId}`,
);

// Find all ingestion queue items with NO_CREDITS status for this workspace
const noCreditItems = await prisma.ingestionQueue.findMany({
where: {
workspaceId: payload.workspaceId,
status: IngestionStatus.NO_CREDITS,
},
orderBy: {
createdAt: "asc", // Process oldest first
},
include: {
workspace: true,
},
});

if (noCreditItems.length === 0) {
logger.log(
`No NO_CREDITS episodes found for workspace ${payload.workspaceId}`,
);
return {
success: true,
message: "No episodes to retry",
retriedCount: 0,
};
}

logger.log(
`Found ${noCreditItems.length} NO_CREDITS episodes to retry`,
);

const results = {
total: noCreditItems.length,
retriggered: 0,
failed: 0,
errors: [] as Array<{ queueId: string; error: string }>,
};

// Process each item
for (const item of noCreditItems) {
try {
const queueData = item.data as z.infer<typeof IngestBodyRequest>;

// Reset status to PENDING and clear error
await prisma.ingestionQueue.update({
where: { id: item.id },
data: {
status: IngestionStatus.PENDING,
error: null,
retryCount: item.retryCount + 1,
},
});

// Trigger the ingestion task
await ingestTask.trigger({
body: queueData,
userId: item.workspace?.userId as string,
workspaceId: payload.workspaceId,
queueId: item.id,
});

results.retriggered++;
logger.log(
`Successfully retriggered episode ${item.id} (retry #${item.retryCount + 1})`,
);
} catch (error: any) {
results.failed++;
results.errors.push({
queueId: item.id,
error: error.message,
});
logger.error(`Failed to retrigger episode ${item.id}:`, error);

// Update the item to mark it as failed
await prisma.ingestionQueue.update({
where: { id: item.id },
data: {
status: IngestionStatus.FAILED,
error: `Retry failed: ${error.message}`,
},
});
}
}

logger.log(
`Completed retry for workspace ${payload.workspaceId}. Retriggered: ${results.retriggered}, Failed: ${results.failed}`,
);

return {
success: true,
...results,
};
} catch (err: any) {
logger.error(
`Error retrying NO_CREDITS episodes for workspace ${payload.workspaceId}:`,
err,
);
return {
success: false,
error: err.message,
};
}
},
});
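// Minimal sketch of kicking this task off from application code, mirroring how ingestTask.trigger
// is used above. The caller and import path are hypothetical, not part of this commit:
// import { retryNoCreditsTask } from "~/trigger/ingest/retry-no-credits";
// await retryNoCreditsTask.trigger({ workspaceId });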