From bfec522877bef6ea8c3d7875a211d12c77fd95bb Mon Sep 17 00:00:00 2001 From: Manoj K Date: Mon, 30 Jun 2025 22:32:57 +0530 Subject: [PATCH] feat: add bgem3 embedding support --- .env.example | 5 +- apps/webapp/app/env.server.ts | 1 + apps/webapp/app/lib/model.server.ts | 5 +- apps/webapp/app/routes/reingest.tsx | 184 ++++++++++++++++++ .../app/services/knowledgeGraph.server.ts | 43 ++-- docker-compose.aws.yaml | 3 + docker-compose.yaml | 3 + 7 files changed, 228 insertions(+), 16 deletions(-) create mode 100644 apps/webapp/app/routes/reingest.tsx diff --git a/.env.example b/.env.example index f311255..3f9d5da 100644 --- a/.env.example +++ b/.env.example @@ -39,4 +39,7 @@ OPENAI_API_KEY= MAGIC_LINK_SECRET=27192e6432564f4788d55c15131bd5ac -NEO4J_AUTH=neo4j/27192e6432564f4788d55c15131bd5ac \ No newline at end of file +NEO4J_AUTH=neo4j/27192e6432564f4788d55c15131bd5ac +OLLAMA_URL=http://ollama:11434 +EMBEDDING_MODEL=bge-m3 +MODEL=GPT41 diff --git a/apps/webapp/app/env.server.ts b/apps/webapp/app/env.server.ts index 3b624d0..646a677 100644 --- a/apps/webapp/app/env.server.ts +++ b/apps/webapp/app/env.server.ts @@ -73,6 +73,7 @@ const EnvironmentSchema = z.object({ // Model envs MODEL: z.string().default(LLMModelEnum.GPT41), + EMBEDDING_MODEL: z.string().default("bge-m3"), OLLAMA_URL: z.string().optional(), }); diff --git a/apps/webapp/app/lib/model.server.ts b/apps/webapp/app/lib/model.server.ts index 0454e04..3d10903 100644 --- a/apps/webapp/app/lib/model.server.ts +++ b/apps/webapp/app/lib/model.server.ts @@ -19,13 +19,14 @@ export async function makeModelCall( let modelInstance; const model = env.MODEL; let finalModel: string = "unknown"; - const ollamaUrl = process.env.OLLAMA_URL; + // const ollamaUrl = process.env.OLLAMA_URL; + const ollamaUrl = undefined; if (ollamaUrl) { const ollama = createOllama({ baseURL: ollamaUrl, }); - modelInstance = ollama(model); // Default to llama2 if no model specified + modelInstance = ollama(model); } else { switch (model) { case 
LLMModelEnum.GPT35TURBO: diff --git a/apps/webapp/app/routes/reingest.tsx b/apps/webapp/app/routes/reingest.tsx new file mode 100644 index 0000000..d619c54 --- /dev/null +++ b/apps/webapp/app/routes/reingest.tsx @@ -0,0 +1,184 @@ +import { json } from "@remix-run/node"; +import { z } from "zod"; +import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server"; +import { addToQueue, type IngestBodyRequest } from "~/lib/ingest.server"; +import { prisma } from "~/db.server"; +import { logger } from "~/services/logger.service"; +import { IngestionStatus } from "@core/database"; + +const ReingestionBodyRequest = z.object({ + userId: z.string().optional(), + spaceId: z.string().optional(), + dryRun: z.boolean().optional().default(false), +}); + +type ReingestionRequest = z.infer<typeof ReingestionBodyRequest>; + +async function getCompletedIngestionsByUser(userId?: string, spaceId?: string) { + const whereClause: any = { + status: IngestionStatus.COMPLETED, + deleted: null + }; + + if (userId) { + whereClause.workspace = { + userId: userId, + }; + } + + if (spaceId) { + whereClause.spaceId = spaceId; + } + + const ingestions = await prisma.ingestionQueue.findMany({ + where: whereClause, + include: { + workspace: { + include: { + user: true, + }, + }, + }, + orderBy: [ + { createdAt: 'asc' }, // Maintain temporal order + ], + }); + + return ingestions; +} + +async function getAllUsers() { + const users = await prisma.user.findMany({ + include: { + Workspace: true, + }, + }); + return users.filter(user => user.Workspace); // Only users with workspaces +} + +async function reingestionForUser(userId: string, spaceId?: string, dryRun = false) { + const ingestions = await getCompletedIngestionsByUser(userId, spaceId); + + logger.info(`Found ${ingestions.length} completed ingestions for user ${userId}${spaceId ? 
` in space ${spaceId}` : ''}`); + + if (dryRun) { + return { + userId, + ingestionCount: ingestions.length, + ingestions: ingestions.map(ing => ({ + id: ing.id, + createdAt: ing.createdAt, + spaceId: ing.spaceId, + data: { + episodeBody: (ing.data as any)?.episodeBody?.substring(0, 100) + + ((ing.data as any)?.episodeBody?.length > 100 ? '...' : ''), + source: (ing.data as any)?.source, + referenceTime: (ing.data as any)?.referenceTime, + }, + })), + }; + } + + // Queue ingestions in temporal order (already sorted by createdAt ASC) + const queuedJobs = []; + for (const ingestion of ingestions) { + try { + // Parse the original data and add reingestion metadata + const originalData = ingestion.data as z.infer<typeof IngestBodyRequest>; + + const reingestionData = { + ...originalData, + source: `reingest-${originalData.source}`, + metadata: { + ...originalData.metadata, + isReingestion: true, + originalIngestionId: ingestion.id, + }, + }; + + const queueResult = await addToQueue(reingestionData, userId); + queuedJobs.push(queueResult); + } catch (error) { + logger.error(`Failed to queue ingestion ${ingestion.id} for user ${userId}:`, {error}); + } + } + + return { + userId, + ingestionCount: ingestions.length, + queuedJobsCount: queuedJobs.length, + queuedJobs, + }; +} + +const { action, loader } = createActionApiRoute( + { + body: ReingestionBodyRequest, + allowJWT: true, + authorization: { + action: "reingest", + }, + corsStrategy: "all", + }, + async ({ body, authentication }) => { + const { userId, spaceId, dryRun } = body; + + try { + if (userId) { + // Reingest for specific user + const result = await reingestionForUser(userId, spaceId, dryRun); + return json({ + success: true, + type: "single_user", + result, + }); + } else { + // Reingest for all users + const users = await getAllUsers(); + const results = []; + + logger.info(`Starting reingestion for ${users.length} users`); + + for (const user of users) { + try { + const result = await reingestionForUser(user.id, spaceId, dryRun); + 
results.push(result); + + if (!dryRun) { + // Add small delay between users to prevent overwhelming the system + await new Promise(resolve => setTimeout(resolve, 1000)); + } + } catch (error) { + logger.error(`Failed to reingest for user ${user.id}:`, {error}); + results.push({ + userId: user.id, + error: error instanceof Error ? error.message : "Unknown error", + }); + } + } + + return json({ + success: true, + type: "all_users", + totalUsers: users.length, + results, + summary: { + totalIngestions: results.reduce((sum, r) => sum + ('ingestionCount' in r ? r.ingestionCount : 0), 0), + totalQueuedJobs: results.reduce((sum, r) => sum + ('queuedJobsCount' in r ? r.queuedJobsCount : 0), 0), + }, + }); + } + } catch (error) { + logger.error("Reingestion failed:", {error}); + return json( + { + success: false, + error: error instanceof Error ? error.message : "Unknown error", + }, + { status: 500 } + ); + } + } +); + +export { action, loader }; \ No newline at end of file diff --git a/apps/webapp/app/services/knowledgeGraph.server.ts b/apps/webapp/app/services/knowledgeGraph.server.ts index d011ad8..3a7770d 100644 --- a/apps/webapp/app/services/knowledgeGraph.server.ts +++ b/apps/webapp/app/services/knowledgeGraph.server.ts @@ -40,14 +40,32 @@ import { import { makeModelCall } from "~/lib/model.server"; import { Apps, getNodeTypes, getNodeTypesString } from "~/utils/presets/nodes"; import { normalizePrompt } from "./prompts"; +import { env } from "~/env.server"; +import { createOllama } from "ollama-ai-provider"; // Default number of previous episodes to retrieve for context const DEFAULT_EPISODE_WINDOW = 5; export class KnowledgeGraphService { - async getEmbedding(text: string) { + async getEmbedding(text: string, useOpenAI = false) { + if (useOpenAI) { + // Use OpenAI embedding model when explicitly requested + const { embedding } = await embed({ + model: openai.embedding("text-embedding-3-small"), + value: text, + }); + return embedding; + } + + // Default to using Ollama + const ollamaUrl = process.env.OLLAMA_URL; + const model = env.EMBEDDING_MODEL; + + const 
ollama = createOllama({ + baseURL: ollamaUrl, + }); const { embedding } = await embed({ - model: openai.embedding("text-embedding-3-small"), + model: ollama.embedding(model), value: text, }); @@ -131,16 +149,16 @@ export class KnowledgeGraphService { episode, ); - // for (const triple of updatedTriples) { - // const { subject, predicate, object, statement, provenance } = triple; - // const safeTriple = { - // subject: { ...subject, nameEmbedding: undefined }, - // predicate: { ...predicate, nameEmbedding: undefined }, - // object: { ...object, nameEmbedding: undefined }, - // statement: { ...statement, factEmbedding: undefined }, - // provenance, - // }; - // } + for (const triple of updatedTriples) { + const { subject, predicate, object, statement, provenance } = triple; + const safeTriple = { + subject: { ...subject, nameEmbedding: undefined }, + predicate: { ...predicate, nameEmbedding: undefined }, + object: { ...object, nameEmbedding: undefined }, + statement: { ...statement, factEmbedding: undefined }, + provenance: { ...provenance, contentEmbedding: undefined }, + }; + } // Save triples sequentially to avoid parallel processing issues for (const triple of updatedTriples) { @@ -257,7 +275,6 @@ export class KnowledgeGraphService { responseText = text; }); - console.log(responseText); const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/); if (outputMatch && outputMatch[1]) { responseText = outputMatch[1].trim(); diff --git a/docker-compose.aws.yaml b/docker-compose.aws.yaml index c793dff..bfce1a7 100644 --- a/docker-compose.aws.yaml +++ b/docker-compose.aws.yaml @@ -23,6 +23,9 @@ services: - AUTH_GOOGLE_CLIENT_ID=${AUTH_GOOGLE_CLIENT_ID} - AUTH_GOOGLE_CLIENT_SECRET=${AUTH_GOOGLE_CLIENT_SECRET} - ENABLE_EMAIL_LOGIN=${ENABLE_EMAIL_LOGIN} + - OLLAMA_URL=${OLLAMA_URL} + - EMBEDDING_MODEL=${EMBEDDING_MODEL} + - MODEL=${MODEL} ports: - "3033:3000" depends_on: diff --git a/docker-compose.yaml b/docker-compose.yaml index 4599a2b..d234ce0 100644 --- 
a/docker-compose.yaml +++ b/docker-compose.yaml @@ -23,6 +23,9 @@ services: - AUTH_GOOGLE_CLIENT_ID=${AUTH_GOOGLE_CLIENT_ID} - AUTH_GOOGLE_CLIENT_SECRET=${AUTH_GOOGLE_CLIENT_SECRET} - ENABLE_EMAIL_LOGIN=${ENABLE_EMAIL_LOGIN} + - OLLAMA_URL=${OLLAMA_URL} + - EMBEDDING_MODEL=${EMBEDDING_MODEL} + - MODEL=${MODEL} ports: - "3033:3000" depends_on: