feat: add bgem3 embedding support

This commit is contained in:
Manoj K 2025-06-30 22:32:57 +05:30
parent 6bb933ca29
commit bfec522877
7 changed files with 228 additions and 16 deletions

View File

@ -39,4 +39,7 @@ OPENAI_API_KEY=
MAGIC_LINK_SECRET=27192e6432564f4788d55c15131bd5ac MAGIC_LINK_SECRET=27192e6432564f4788d55c15131bd5ac
NEO4J_AUTH=neo4j/27192e6432564f4788d55c15131bd5ac NEO4J_AUTH=neo4j/27192e6432564f4788d55c15131bd5ac
OLLAMA_URL=http://ollama:11434
EMBEDDING_MODEL=bge-m3
MODEL=GPT41

View File

@ -73,6 +73,7 @@ const EnvironmentSchema = z.object({
// Model envs // Model envs
MODEL: z.string().default(LLMModelEnum.GPT41), MODEL: z.string().default(LLMModelEnum.GPT41),
EMBEDDING_MODEL: z.string().default("bge-m3"),
OLLAMA_URL: z.string().optional(), OLLAMA_URL: z.string().optional(),
}); });

View File

@ -19,13 +19,14 @@ export async function makeModelCall(
let modelInstance; let modelInstance;
const model = env.MODEL; const model = env.MODEL;
let finalModel: string = "unknown"; let finalModel: string = "unknown";
const ollamaUrl = process.env.OLLAMA_URL; // const ollamaUrl = process.env.OLLAMA_URL;
const ollamaUrl = undefined;
if (ollamaUrl) { if (ollamaUrl) {
const ollama = createOllama({ const ollama = createOllama({
baseURL: ollamaUrl, baseURL: ollamaUrl,
}); });
modelInstance = ollama(model); // Default to llama2 if no model specified modelInstance = ollama(model);
} else { } else {
switch (model) { switch (model) {
case LLMModelEnum.GPT35TURBO: case LLMModelEnum.GPT35TURBO:

View File

@ -0,0 +1,184 @@
import { json } from "@remix-run/node";
import { z } from "zod";
import { createActionApiRoute } from "~/services/routeBuilders/apiBuilder.server";
import { addToQueue, type IngestBodyRequest } from "~/lib/ingest.server";
import { prisma } from "~/db.server";
import { logger } from "~/services/logger.service";
import { IngestionStatus } from "@core/database";
// Request body for the reingestion endpoint:
// - userId: restrict reingestion to one user; omit to target ALL users
// - spaceId: restrict reingestion to one space within the selected user(s)
// - dryRun: when true, report what would be queued without queueing anything
const ReingestionBodyRequest = z.object({
  userId: z.string().optional(),
  spaceId: z.string().optional(),
  dryRun: z.boolean().optional().default(false),
});
// NOTE(review): this alias is never referenced in this file — presumably kept
// for external consumers; confirm before removing.
type ReingestionRequest = z.infer<typeof ReingestionBodyRequest>;
/**
 * Fetch all COMPLETED, non-deleted ingestion queue entries, optionally scoped
 * to a single user's workspace and/or a single space.
 *
 * Results are ordered oldest-first so that reingestion replays episodes in
 * their original temporal order.
 *
 * @param userId - When present, only ingestions whose workspace belongs to this user.
 * @param spaceId - When present, only ingestions in this space.
 * @returns Matching ingestion rows with workspace (and its user) included.
 */
async function getCompletedIngestionsByUser(userId?: string, spaceId?: string) {
  const ingestions = await prisma.ingestionQueue.findMany({
    where: {
      status: IngestionStatus.COMPLETED,
      deleted: null,
      // Conditional spreads keep the where clause fully typed by Prisma,
      // instead of building it imperatively on an `any`-typed object.
      ...(userId ? { workspace: { userId } } : {}),
      ...(spaceId ? { spaceId } : {}),
    },
    include: {
      workspace: {
        include: {
          user: true,
        },
      },
    },
    orderBy: [
      { createdAt: "asc" }, // Maintain temporal order
    ],
  });

  return ingestions;
}
/**
 * Fetch every user that has an associated workspace.
 *
 * Users without a workspace are dropped because reingestion operates on
 * workspace-scoped ingestion queues.
 */
async function getAllUsers() {
  const allUsers = await prisma.user.findMany({
    include: { Workspace: true },
  });
  // Keep only users whose Workspace relation is populated.
  return allUsers.filter((candidate) => Boolean(candidate.Workspace));
}
/**
 * Reingest every completed ingestion for one user, in original temporal order.
 *
 * @param userId - The user whose completed ingestions should be requeued.
 * @param spaceId - Optional space filter.
 * @param dryRun - When true, only report what would be queued (with a preview).
 * @returns A per-user summary; dry runs include a truncated preview of each
 *          ingestion's episode body instead of queueing anything.
 */
async function reingestionForUser(userId: string, spaceId?: string, dryRun = false) {
  const ingestions = await getCompletedIngestionsByUser(userId, spaceId);

  logger.info(
    `Found ${ingestions.length} completed ingestions for user ${userId}${spaceId ? ` in space ${spaceId}` : ""}`,
  );

  if (dryRun) {
    return {
      userId,
      ingestionCount: ingestions.length,
      ingestions: ingestions.map((ing) => {
        // `data` is an untyped JSON column; narrow defensively rather than
        // trusting its shape.
        const data = ing.data as Record<string, unknown> | null;
        const episodeBody =
          typeof data?.episodeBody === "string" ? data.episodeBody : undefined;
        return {
          id: ing.id,
          createdAt: ing.createdAt,
          spaceId: ing.spaceId,
          data: {
            // Fix: previously a missing episodeBody produced the literal
            // string "undefined" via `undefined + ""` concatenation.
            episodeBody:
              episodeBody === undefined
                ? undefined
                : episodeBody.slice(0, 100) + (episodeBody.length > 100 ? "..." : ""),
            source: data?.source,
            referenceTime: data?.referenceTime,
          },
        };
      }),
    };
  }

  // Queue ingestions in temporal order (already sorted by createdAt ASC).
  const queuedJobs = [];
  for (const ingestion of ingestions) {
    try {
      // Reuse the original payload, tagging it so downstream consumers can
      // distinguish reingestions from first-time ingestions.
      const originalData = ingestion.data as z.infer<typeof IngestBodyRequest>;
      const reingestionData = {
        ...originalData,
        source: `reingest-${originalData.source}`,
        metadata: {
          ...originalData.metadata,
          isReingestion: true,
          originalIngestionId: ingestion.id,
        },
      };
      const queueResult = await addToQueue(reingestionData, userId);
      queuedJobs.push(queueResult);
    } catch (error) {
      // Keep going: one bad ingestion must not abort the rest of the user's queue.
      logger.error(`Failed to queue ingestion ${ingestion.id} for user ${userId}:`, { error });
    }
  }

  return {
    userId,
    ingestionCount: ingestions.length,
    queuedJobsCount: queuedJobs.length,
    queuedJobs,
  };
}
/**
 * POST endpoint that requeues completed ingestions for one user (when
 * `userId` is given) or for every user with a workspace. Supports dry runs.
 */
const { action, loader } = createActionApiRoute(
  {
    body: ReingestionBodyRequest,
    allowJWT: true,
    authorization: {
      action: "reingest",
    },
    corsStrategy: "all",
  },
  async ({ body, authentication }) => {
    const { userId, spaceId, dryRun } = body;

    try {
      if (userId) {
        // Reingest for a specific user.
        const result = await reingestionForUser(userId, spaceId, dryRun);
        return json({
          success: true,
          type: "single_user",
          result,
        });
      }

      // Reingest for all users with a workspace.
      const users = await getAllUsers();
      const results = [];

      logger.info(`Starting reingestion for ${users.length} users`);

      for (const user of users) {
        try {
          const result = await reingestionForUser(user.id, spaceId, dryRun);
          results.push(result);
          if (!dryRun) {
            // Small delay between users to avoid overwhelming the queue.
            await new Promise((resolve) => setTimeout(resolve, 1000));
          }
        } catch (error) {
          // Record the failure and continue with the remaining users.
          logger.error(`Failed to reingest for user ${user.id}:`, { error });
          results.push({
            userId: user.id,
            error: error instanceof Error ? error.message : "Unknown error",
          });
        }
      }

      return json({
        success: true,
        type: "all_users",
        totalUsers: users.length,
        results,
        summary: {
          // Fix: the reducers previously returned the accumulator unchanged
          // (`(sum, r) => sum`), so both totals were always 0. Error entries
          // lack the count fields, hence the `in` guards.
          totalIngestions: results.reduce(
            (sum, r) => sum + ("ingestionCount" in r ? r.ingestionCount : 0),
            0,
          ),
          totalQueuedJobs: results.reduce(
            (sum, r) =>
              sum +
              ("queuedJobsCount" in r && typeof r.queuedJobsCount === "number"
                ? r.queuedJobsCount
                : 0),
            0,
          ),
        },
      });
    } catch (error) {
      logger.error("Reingestion failed:", { error });
      return json(
        {
          success: false,
          error: error instanceof Error ? error.message : "Unknown error",
        },
        { status: 500 },
      );
    }
  },
);

export { action, loader };

View File

@ -40,14 +40,32 @@ import {
import { makeModelCall } from "~/lib/model.server"; import { makeModelCall } from "~/lib/model.server";
import { Apps, getNodeTypes, getNodeTypesString } from "~/utils/presets/nodes"; import { Apps, getNodeTypes, getNodeTypesString } from "~/utils/presets/nodes";
import { normalizePrompt } from "./prompts"; import { normalizePrompt } from "./prompts";
import { env } from "~/env.server";
import { createOllama } from "ollama-ai-provider";
// Default number of previous episodes to retrieve for context // Default number of previous episodes to retrieve for context
const DEFAULT_EPISODE_WINDOW = 5; const DEFAULT_EPISODE_WINDOW = 5;
export class KnowledgeGraphService { export class KnowledgeGraphService {
async getEmbedding(text: string) { async getEmbedding(text: string, useOpenAI = false) {
if (useOpenAI) {
// Use OpenAI embedding model when explicitly requested
const { embedding } = await embed({
model: openai.embedding("text-embedding-3-small"),
value: text,
});
return embedding;
}
// Default to using Ollama
const ollamaUrl = process.env.OLLAMA_URL;
const model = env.EMBEDDING_MODEL;
const ollama = createOllama({
baseURL: ollamaUrl,
});
const { embedding } = await embed({ const { embedding } = await embed({
model: openai.embedding("text-embedding-3-small"), model: ollama.embedding(model),
value: text, value: text,
}); });
@ -131,16 +149,16 @@ export class KnowledgeGraphService {
episode, episode,
); );
// for (const triple of updatedTriples) { for (const triple of updatedTriples) {
// const { subject, predicate, object, statement, provenance } = triple; const { subject, predicate, object, statement, provenance } = triple;
// const safeTriple = { const safeTriple = {
// subject: { ...subject, nameEmbedding: undefined }, subject: { ...subject, nameEmbedding: undefined },
// predicate: { ...predicate, nameEmbedding: undefined }, predicate: { ...predicate, nameEmbedding: undefined },
// object: { ...object, nameEmbedding: undefined }, object: { ...object, nameEmbedding: undefined },
// statement: { ...statement, factEmbedding: undefined }, statement: { ...statement, factEmbedding: undefined },
// provenance, provenance: { ...provenance, contentEmbedding: undefined },
// }; };
// } }
// Save triples sequentially to avoid parallel processing issues // Save triples sequentially to avoid parallel processing issues
for (const triple of updatedTriples) { for (const triple of updatedTriples) {
@ -257,7 +275,6 @@ export class KnowledgeGraphService {
responseText = text; responseText = text;
}); });
console.log(responseText);
const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/); const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
if (outputMatch && outputMatch[1]) { if (outputMatch && outputMatch[1]) {
responseText = outputMatch[1].trim(); responseText = outputMatch[1].trim();

View File

@ -23,6 +23,9 @@ services:
- AUTH_GOOGLE_CLIENT_ID=${AUTH_GOOGLE_CLIENT_ID} - AUTH_GOOGLE_CLIENT_ID=${AUTH_GOOGLE_CLIENT_ID}
- AUTH_GOOGLE_CLIENT_SECRET=${AUTH_GOOGLE_CLIENT_SECRET} - AUTH_GOOGLE_CLIENT_SECRET=${AUTH_GOOGLE_CLIENT_SECRET}
- ENABLE_EMAIL_LOGIN=${ENABLE_EMAIL_LOGIN} - ENABLE_EMAIL_LOGIN=${ENABLE_EMAIL_LOGIN}
- OLLAMA_URL=${OLLAMA_URL}
- EMBEDDING_MODEL=${EMBEDDING_MODEL}
- MODEL=${MODEL}
ports: ports:
- "3033:3000" - "3033:3000"
depends_on: depends_on:

View File

@ -23,6 +23,9 @@ services:
- AUTH_GOOGLE_CLIENT_ID=${AUTH_GOOGLE_CLIENT_ID} - AUTH_GOOGLE_CLIENT_ID=${AUTH_GOOGLE_CLIENT_ID}
- AUTH_GOOGLE_CLIENT_SECRET=${AUTH_GOOGLE_CLIENT_SECRET} - AUTH_GOOGLE_CLIENT_SECRET=${AUTH_GOOGLE_CLIENT_SECRET}
- ENABLE_EMAIL_LOGIN=${ENABLE_EMAIL_LOGIN} - ENABLE_EMAIL_LOGIN=${ENABLE_EMAIL_LOGIN}
- OLLAMA_URL=${OLLAMA_URL}
- EMBEDDING_MODEL=${EMBEDDING_MODEL}
- MODEL=${MODEL}
ports: ports:
- "3033:3000" - "3033:3000"
depends_on: depends_on: