Feat: add recall count and model to search

This commit is contained in:
Manoj K 2025-08-04 11:45:55 +05:30 committed by Harshith Mullapudi
parent 6afd993da7
commit ddc0b0085c
4 changed files with 198 additions and 1 deletions

View File

@ -1,4 +1,4 @@
import type { StatementNode } from "@core/types";
import type { EpisodicNode, StatementNode } from "@core/types";
import { logger } from "./logger.service";
import { applyCrossEncoderReranking, applyWeightedRRF } from "./search/rerank";
import {
@ -8,6 +8,8 @@ import {
performVectorSearch,
} from "./search/utils";
import { getEmbedding } from "~/lib/model.server";
import { prisma } from "~/db.server";
import { runQuery } from "~/lib/neo4j.server";
/**
* SearchService provides methods to search the reified + temporal knowledge graph
@ -30,6 +32,7 @@ export class SearchService {
userId: string,
options: SearchOptions = {},
): Promise<{ episodes: string[]; facts: string[] }> {
const startTime = Date.now();
// Default options
const opts: Required<SearchOptions> = {
@ -70,6 +73,21 @@ export class SearchService {
// 3. Return top results
const episodes = await getEpisodesByStatements(filteredResults);
// Log recall asynchronously (don't await to avoid blocking response)
const responseTime = Date.now() - startTime;
this.logRecallAsync(
query,
userId,
filteredResults,
opts,
responseTime,
).catch((error) => {
logger.error("Failed to log recall event:", error);
});
this.updateRecallCount(userId, episodes, filteredResults);
return {
episodes: episodes.map((episode) => episode.content),
facts: filteredResults.map((statement) => statement.fact),
@ -201,6 +219,100 @@ export class SearchService {
// Otherwise use weighted RRF for multiple sources
return applyWeightedRRF(results);
}
/**
 * Records one search call as a `RecallLog` row.
 *
 * Never rejects: any failure while building or writing the row is caught and
 * logged so that recall logging cannot affect the search response.
 *
 * @param query - The search query text.
 * @param userId - Id of the user who ran the search.
 * @param results - Statements returned by the search after filtering.
 * @param options - Fully-defaulted search options; a subset is serialized into `context`.
 * @param responseTime - Elapsed time of the search in milliseconds.
 */
private async logRecallAsync(
  query: string,
  userId: string,
  results: StatementNode[],
  options: Required<SearchOptions>,
  responseTime: number,
): Promise<void> {
  try {
    // Classify the result set purely by its size.
    let targetType = "mixed_results";
    if (results.length === 1) {
      targetType = "statement";
    } else if (results.length === 0) {
      targetType = "no_results";
    }

    // Score fields that may be attached to a result, in priority order.
    const scoreFields = [
      "rrfScore",
      "mmrScore",
      "crossEncoderScore",
      "finalScore",
      "score",
    ] as const;
    type ScoredStatement = StatementNode &
      Partial<Record<(typeof scoreFields)[number], unknown>>;

    // Take the first field holding a real number for each result. Checking the
    // type (rather than truthiness) keeps a legitimate score of 0 in the
    // average instead of silently discarding it.
    const scores: number[] = [];
    for (const result of results) {
      const scored = result as ScoredStatement;
      for (const field of scoreFields) {
        const value = scored[field];
        if (typeof value === "number" && !Number.isNaN(value)) {
          scores.push(value);
          break;
        }
      }
    }

    // Mean of the available scores; null when no result carries one.
    const averageSimilarityScore =
      scores.length > 0
        ? scores.reduce((sum, score) => sum + score, 0) / scores.length
        : null;

    await prisma.recallLog.create({
      data: {
        accessType: "search",
        query,
        targetType,
        searchMethod: "hybrid", // BM25 + Vector + BFS
        minSimilarity: options.scoreThreshold,
        maxResults: options.limit,
        resultCount: results.length,
        similarityScore: averageSimilarityScore,
        // The context column is text, so the options are stored as a JSON string.
        context: JSON.stringify({
          entityTypes: options.entityTypes,
          predicateTypes: options.predicateTypes,
          maxBfsDepth: options.maxBfsDepth,
          includeInvalidated: options.includeInvalidated,
          validAt: options.validAt.toISOString(),
          startTime: options.startTime?.toISOString() ?? null,
          endTime: options.endTime.toISOString(),
        }),
        source: "search_api",
        responseTimeMs: responseTime,
        userId,
      },
    });

    logger.debug(
      `Logged recall event for user ${userId}: ${results.length} results in ${responseTime}ms`,
    );
  } catch (error) {
    logger.error("Error creating recall log entry:", { error });
    // Don't throw - we don't want logging failures to affect the search response
  }
}
/**
 * Increments `recallCount` on the Episode and Statement nodes returned by a
 * search, restricted to nodes whose `userId` matches.
 *
 * The visible caller invokes this without awaiting it, so the method must not
 * reject: a graph failure is caught and logged here instead of surfacing as
 * an unhandled promise rejection.
 *
 * @param userId - Owner of the nodes; used to scope both updates.
 * @param episodes - Episodes returned by the search.
 * @param statements - Statements returned by the search.
 */
private async updateRecallCount(
  userId: string,
  episodes: EpisodicNode[],
  statements: StatementNode[],
): Promise<void> {
  try {
    const episodeUuids = episodes.map((episode) => episode.uuid);
    // An empty id list matches nothing, so skip the round trip.
    if (episodeUuids.length > 0) {
      const episodeCypher = `
        MATCH (e:Episode)
        WHERE e.uuid IN $episodeUuids and e.userId = $userId
        SET e.recallCount = coalesce(e.recallCount, 0) + 1
      `;
      await runQuery(episodeCypher, { episodeUuids, userId });
    }

    const statementUuids = statements.map((statement) => statement.uuid);
    if (statementUuids.length > 0) {
      const statementCypher = `
        MATCH (s:Statement)
        WHERE s.uuid IN $statementUuids and s.userId = $userId
        SET s.recallCount = coalesce(s.recallCount, 0) + 1
      `;
      await runQuery(statementCypher, { statementUuids, userId });
    }
  } catch (error) {
    // Best-effort bookkeeping: log and continue rather than reject.
    logger.error("Failed to update recall count:", { error });
  }
}
}
/**

View File

@ -0,0 +1,35 @@
-- CreateTable
-- Creates the "RecallLog" table backing the RecallLog model.
-- Only "id", "createdAt", "updatedAt", "accessType", "resultCount" and "userId"
-- are NOT NULL; every other column is optional.
CREATE TABLE "RecallLog" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- NOT NULL with no database default: the writer must supply a value on insert.
"updatedAt" TIMESTAMP(3) NOT NULL,
"deleted" TIMESTAMP(3),
"accessType" TEXT NOT NULL,
"query" TEXT,
"targetType" TEXT,
"targetId" TEXT,
"searchMethod" TEXT,
"minSimilarity" DOUBLE PRECISION,
"maxResults" INTEGER,
"resultCount" INTEGER NOT NULL DEFAULT 0,
"similarityScore" DOUBLE PRECISION,
-- Plain TEXT (not JSONB); contrast with "metadata" below.
"context" TEXT,
"source" TEXT,
"sessionId" TEXT,
"responseTimeMs" INTEGER,
"userId" TEXT NOT NULL,
"workspaceId" TEXT,
"conversationId" TEXT,
-- Defaults to an empty JSON object when omitted.
"metadata" JSONB DEFAULT '{}',
CONSTRAINT "RecallLog_pkey" PRIMARY KEY ("id")
);
-- AddForeignKey
-- "userId" is required: ON DELETE RESTRICT blocks deleting a User that still has RecallLog rows.
ALTER TABLE "RecallLog" ADD CONSTRAINT "RecallLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
-- "workspaceId" is optional: deleting the Workspace nulls the reference and keeps the log row.
ALTER TABLE "RecallLog" ADD CONSTRAINT "RecallLog_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "Workspace"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
-- "conversationId" is optional: deleting the Conversation nulls the reference and keeps the log row.
ALTER TABLE "RecallLog" ADD CONSTRAINT "RecallLog_conversationId_fkey" FOREIGN KEY ("conversationId") REFERENCES "Conversation"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@ -64,6 +64,7 @@ model Conversation {
status String @default("pending") // Can be "pending", "running", "completed", "failed", "need_attention"
ConversationHistory ConversationHistory[]
RecallLog RecallLog[]
}
model ConversationExecutionStep {
@ -423,6 +424,51 @@ model PersonalAccessToken {
authorizationCodes AuthorizationCode[]
}
// One row per recall/access event. The search service writes a row for each
// search call with accessType "search" and source "search_api".
model RecallLog {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
deleted DateTime?
// Access details
accessType String // "search", "recall", "direct_access"
query String? // Search query (null for direct access)
// Target information
targetType String? // "episode", "statement", "entity", "mixed_results", "no_results"
targetId String? // UUID of specific target (null for search with multiple results)
// Search/access parameters
searchMethod String? // "semantic", "keyword", "hybrid", "contextual", "graph_traversal"
minSimilarity Float? // Minimum similarity threshold used
maxResults Int? // Maximum results requested
// Results and interaction
resultCount Int @default(0) // Number of results returned
similarityScore Float? // Similarity score; for searches, the average over results that carry a score (null if none)
// Context and source
context String? // Additional context; the search service stores its search options here as a JSON string
source String? // Source of the access (e.g., "chat", "api", "integration", "search_api")
sessionId String? // Session identifier
// Performance metrics
responseTimeMs Int? // Response time in milliseconds
// Relations
// userId is required; workspaceId and conversationId are optional links.
user User @relation(fields: [userId], references: [id])
userId String
workspace Workspace? @relation(fields: [workspaceId], references: [id])
workspaceId String?
conversation Conversation? @relation(fields: [conversationId], references: [id])
conversationId String?
// Metadata for additional tracking data
metadata Json? @default("{}")
}
model Space {
id String @id @default(cuid())
name String
@ -505,6 +551,7 @@ model User {
oauthIntegrationGrants OAuthIntegrationGrant[]
oAuthClientInstallation OAuthClientInstallation[]
UserUsage UserUsage?
RecallLog RecallLog[]
}
model UserUsage {
@ -579,6 +626,7 @@ model Workspace {
OAuthAuthorizationCode OAuthAuthorizationCode[]
OAuthAccessToken OAuthAccessToken[]
OAuthRefreshToken OAuthRefreshToken[]
RecallLog RecallLog[]
}
enum AuthenticationMethod {

View File

@ -20,6 +20,7 @@ export interface EpisodicNode {
userId: string;
space?: string;
sessionId?: string;
recallCount?: number;
}
/**
@ -52,6 +53,7 @@ export interface StatementNode {
attributes: Record<string, any>;
userId: string;
space?: string;
recallCount?: number;
}
/**