mirror of
https://github.com/eliasstepanik/core.git
synced 2026-01-11 16:48:27 +00:00
Feat: add recall count and model to search
This commit is contained in:
parent
6afd993da7
commit
ddc0b0085c
@ -1,4 +1,4 @@
|
||||
import type { StatementNode } from "@core/types";
|
||||
import type { EpisodicNode, StatementNode } from "@core/types";
|
||||
import { logger } from "./logger.service";
|
||||
import { applyCrossEncoderReranking, applyWeightedRRF } from "./search/rerank";
|
||||
import {
|
||||
@ -8,6 +8,8 @@ import {
|
||||
performVectorSearch,
|
||||
} from "./search/utils";
|
||||
import { getEmbedding } from "~/lib/model.server";
|
||||
import { prisma } from "~/db.server";
|
||||
import { runQuery } from "~/lib/neo4j.server";
|
||||
|
||||
/**
|
||||
* SearchService provides methods to search the reified + temporal knowledge graph
|
||||
@ -30,6 +32,7 @@ export class SearchService {
|
||||
userId: string,
|
||||
options: SearchOptions = {},
|
||||
): Promise<{ episodes: string[]; facts: string[] }> {
|
||||
const startTime = Date.now();
|
||||
// Default options
|
||||
|
||||
const opts: Required<SearchOptions> = {
|
||||
@ -70,6 +73,21 @@ export class SearchService {
|
||||
|
||||
// 3. Return top results
|
||||
const episodes = await getEpisodesByStatements(filteredResults);
|
||||
|
||||
// Log recall asynchronously (don't await to avoid blocking response)
|
||||
const responseTime = Date.now() - startTime;
|
||||
this.logRecallAsync(
|
||||
query,
|
||||
userId,
|
||||
filteredResults,
|
||||
opts,
|
||||
responseTime,
|
||||
).catch((error) => {
|
||||
logger.error("Failed to log recall event:", error);
|
||||
});
|
||||
|
||||
this.updateRecallCount(userId, episodes, filteredResults);
|
||||
|
||||
return {
|
||||
episodes: episodes.map((episode) => episode.content),
|
||||
facts: filteredResults.map((statement) => statement.fact),
|
||||
@ -201,6 +219,100 @@ export class SearchService {
|
||||
// Otherwise use weighted RRF for multiple sources
|
||||
return applyWeightedRRF(results);
|
||||
}
|
||||
|
||||
/**
 * Writes one `RecallLog` row describing a completed search.
 *
 * The row records the query, the options the search ran with, the number of
 * results, the mean score of those results (when any result carries a usable
 * score) and the elapsed time.
 *
 * Any failure is logged and swallowed, so the returned promise does not reject
 * and a logging problem cannot affect the search response.
 *
 * @param query - The search query text as received.
 * @param userId - Id of the user the search ran for; stored on the log row.
 * @param results - Statements returned by the search after filtering.
 * @param options - Fully-resolved search options; a subset is stored as JSON in `context`.
 * @param responseTime - Elapsed search time in milliseconds.
 */
private async logRecallAsync(
  query: string,
  userId: string,
  results: StatementNode[],
  options: Required<SearchOptions>,
  responseTime: number,
): Promise<void> {
  try {
    // The target type is derived from the result count alone.
    let targetType = "mixed_results";
    if (results.length === 0) {
      targetType = "no_results";
    } else if (results.length === 1) {
      targetType = "statement";
    }

    // A result may carry its score under any of these property names; they are
    // probed in priority order and the first finite number wins.
    const scoreFields = [
      "rrfScore",
      "mmrScore",
      "crossEncoderScore",
      "finalScore",
      "score",
    ] as const;

    const scores: number[] = [];
    for (const result of results) {
      for (const field of scoreFields) {
        const candidate: unknown = Reflect.get(result, field);
        // Explicit type/finite check instead of truthiness: a score of 0 is a
        // real score and must count towards the average, while NaN/Infinity
        // would poison it.
        if (typeof candidate === "number" && Number.isFinite(candidate)) {
          scores.push(candidate);
          break;
        }
      }
    }

    // Stays null when no result exposes a usable score.
    const averageSimilarityScore =
      scores.length > 0
        ? scores.reduce((sum, score) => sum + score, 0) / scores.length
        : null;

    await prisma.recallLog.create({
      data: {
        accessType: "search",
        query,
        targetType,
        searchMethod: "hybrid", // BM25 + Vector + BFS
        minSimilarity: options.scoreThreshold,
        maxResults: options.limit,
        resultCount: results.length,
        similarityScore: averageSimilarityScore,
        context: JSON.stringify({
          entityTypes: options.entityTypes,
          predicateTypes: options.predicateTypes,
          maxBfsDepth: options.maxBfsDepth,
          includeInvalidated: options.includeInvalidated,
          validAt: options.validAt.toISOString(),
          startTime: options.startTime?.toISOString() ?? null,
          endTime: options.endTime.toISOString(),
        }),
        source: "search_api",
        responseTimeMs: responseTime,
        userId,
      },
    });

    logger.debug(
      `Logged recall event for user ${userId}: ${results.length} results in ${responseTime}ms`,
    );
  } catch (error) {
    logger.error("Error creating recall log entry:", { error });
    // Don't throw - we don't want logging failures to affect the search response
  }
}
|
||||
|
||||
/**
 * Increments `recallCount` by one on every Episode and Statement node that a
 * search returned, restricted to nodes owned by `userId`.
 *
 * A node without a `recallCount` property is treated as 0 (via `coalesce`).
 *
 * The search path calls this without awaiting it, so errors are logged and
 * swallowed here rather than surfacing as an unhandled promise rejection.
 *
 * @param userId - Owner of the nodes; only nodes with this `userId` are updated.
 * @param episodes - Episodes returned by the search.
 * @param statements - Statements returned by the search.
 */
private async updateRecallCount(
  userId: string,
  episodes: EpisodicNode[],
  statements: StatementNode[],
): Promise<void> {
  try {
    const episodeUuids = episodes.map((episode) => episode.uuid);
    // An empty id list matches nothing, so skip the round-trip entirely.
    if (episodeUuids.length > 0) {
      const episodeCypher = `
        MATCH (e:Episode)
        WHERE e.uuid IN $episodeUuids and e.userId = $userId
        SET e.recallCount = coalesce(e.recallCount, 0) + 1
      `;
      await runQuery(episodeCypher, { episodeUuids, userId });
    }

    const statementUuids = statements.map((statement) => statement.uuid);
    if (statementUuids.length > 0) {
      const statementCypher = `
        MATCH (s:Statement)
        WHERE s.uuid IN $statementUuids and s.userId = $userId
        SET s.recallCount = coalesce(s.recallCount, 0) + 1
      `;
      await runQuery(statementCypher, { statementUuids, userId });
    }
  } catch (error) {
    // Best-effort bookkeeping: a failed counter update must not break search.
    logger.error("Failed to update recall counts:", { error });
  }
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -0,0 +1,35 @@
|
||||
-- CreateTable
-- "RecallLog" stores one row per logged access (the search service writes rows
-- with "accessType" = 'search'). Only "id", "createdAt", "updatedAt",
-- "accessType", "resultCount" and "userId" are NOT NULL.
|
||||
CREATE TABLE "RecallLog" (
|
||||
-- No DEFAULT here: the inserting client must supply the id.
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
-- NOT NULL with no DEFAULT: the writer must set this on every insert.
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
-- Nullable timestamp; presumably a soft-delete marker -- confirm against application code.
"deleted" TIMESTAMP(3),
|
||||
"accessType" TEXT NOT NULL,
|
||||
"query" TEXT,
|
||||
"targetType" TEXT,
|
||||
"targetId" TEXT,
|
||||
"searchMethod" TEXT,
|
||||
"minSimilarity" DOUBLE PRECISION,
|
||||
"maxResults" INTEGER,
|
||||
"resultCount" INTEGER NOT NULL DEFAULT 0,
|
||||
"similarityScore" DOUBLE PRECISION,
|
||||
-- Plain TEXT (not JSONB), unlike "metadata" below.
"context" TEXT,
|
||||
"source" TEXT,
|
||||
"sessionId" TEXT,
|
||||
"responseTimeMs" INTEGER,
|
||||
"userId" TEXT NOT NULL,
|
||||
"workspaceId" TEXT,
|
||||
"conversationId" TEXT,
|
||||
-- Nullable, but defaults to an empty JSON object when omitted on insert.
"metadata" JSONB DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "RecallLog_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- AddForeignKey
-- ON DELETE RESTRICT: a "User" row cannot be deleted while log rows reference it.
|
||||
ALTER TABLE "RecallLog" ADD CONSTRAINT "RecallLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
-- ON DELETE SET NULL: deleting the workspace keeps the log row and clears "workspaceId".
|
||||
ALTER TABLE "RecallLog" ADD CONSTRAINT "RecallLog_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "Workspace"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
-- ON DELETE SET NULL: deleting the conversation keeps the log row and clears "conversationId".
|
||||
ALTER TABLE "RecallLog" ADD CONSTRAINT "RecallLog_conversationId_fkey" FOREIGN KEY ("conversationId") REFERENCES "Conversation"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@ -64,6 +64,7 @@ model Conversation {
|
||||
status String @default("pending") // Can be "pending", "running", "completed", "failed", "need_attention"
|
||||
|
||||
ConversationHistory ConversationHistory[]
|
||||
RecallLog RecallLog[]
|
||||
}
|
||||
|
||||
model ConversationExecutionStep {
|
||||
@ -423,6 +424,51 @@ model PersonalAccessToken {
|
||||
authorizationCodes AuthorizationCode[]
|
||||
}
|
||||
|
||||
// One row per logged access to stored memory. The search service creates a row
// after each search with accessType "search", searchMethod "hybrid" and
// source "search_api"; the other documented values are not written by the code
// visible alongside this model.
model RecallLog {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
deleted DateTime? // Presumably a soft-delete timestamp (null = live row) — confirm against callers
|
||||
|
||||
// Access details
|
||||
accessType String // "search", "recall", "direct_access"
|
||||
query String? // Search query (null for direct access)
|
||||
|
||||
// Target information
|
||||
targetType String? // "episode", "statement", "entity", "mixed_results", "no_results"
|
||||
targetId String? // UUID of specific target (null for search with multiple results)
|
||||
|
||||
// Search/access parameters
|
||||
searchMethod String? // "semantic", "keyword", "hybrid", "contextual", "graph_traversal"
|
||||
minSimilarity Float? // Minimum similarity threshold used
|
||||
maxResults Int? // Maximum results requested
|
||||
|
||||
// Results and interaction
|
||||
resultCount Int @default(0) // Number of results returned
|
||||
similarityScore Float? // Similarity score; the search logger stores the mean score across returned results (null when none is available)
|
||||
|
||||
// Context and source
|
||||
context String? // Additional context; the search logger stores a JSON-encoded string of the search options
|
||||
source String? // Source of the access (e.g., "chat", "api", "integration", "search_api")
|
||||
sessionId String? // Session identifier
|
||||
|
||||
// Performance metrics
|
||||
responseTimeMs Int? // Response time in milliseconds
|
||||
|
||||
// Relations
// No onDelete is declared on any relation below; the accompanying migration
// uses RESTRICT for the required user relation and SET NULL for the optional ones.
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
userId String
|
||||
|
||||
workspace Workspace? @relation(fields: [workspaceId], references: [id])
|
||||
workspaceId String?
|
||||
|
||||
conversation Conversation? @relation(fields: [conversationId], references: [id])
|
||||
conversationId String?
|
||||
|
||||
// Metadata for additional tracking data
|
||||
metadata Json? @default("{}")
|
||||
}
|
||||
|
||||
model Space {
|
||||
id String @id @default(cuid())
|
||||
name String
|
||||
@ -505,6 +551,7 @@ model User {
|
||||
oauthIntegrationGrants OAuthIntegrationGrant[]
|
||||
oAuthClientInstallation OAuthClientInstallation[]
|
||||
UserUsage UserUsage?
|
||||
RecallLog RecallLog[]
|
||||
}
|
||||
|
||||
model UserUsage {
|
||||
@ -579,6 +626,7 @@ model Workspace {
|
||||
OAuthAuthorizationCode OAuthAuthorizationCode[]
|
||||
OAuthAccessToken OAuthAccessToken[]
|
||||
OAuthRefreshToken OAuthRefreshToken[]
|
||||
RecallLog RecallLog[]
|
||||
}
|
||||
|
||||
enum AuthenticationMethod {
|
||||
|
||||
@ -20,6 +20,7 @@ export interface EpisodicNode {
|
||||
userId: string;
|
||||
space?: string;
|
||||
sessionId?: string;
|
||||
recallCount?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -52,6 +53,7 @@ export interface StatementNode {
|
||||
attributes: Record<string, any>;
|
||||
userId: string;
|
||||
space?: string;
|
||||
recallCount?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user