Feat: add token counter

Feat: select high- or low-complexity model based on task
Manoj 2025-09-30 13:13:42 +05:30
parent a73e6b1125
commit e610336563
5 changed files with 156 additions and 32 deletions

View File

@@ -14,14 +14,59 @@ import { google } from "@ai-sdk/google";
 import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock";
 import { fromNodeProviderChain } from "@aws-sdk/credential-providers";
+export type ModelComplexity = 'high' | 'low';
+/**
+ * Get the appropriate model for a given complexity level.
+ * HIGH complexity uses the configured MODEL.
+ * LOW complexity automatically downgrades to cheaper variants if possible.
+ */
+export function getModelForTask(complexity: ModelComplexity = 'high'): string {
+  const baseModel = process.env.MODEL || 'gpt-4.1-2025-04-14';
+  // HIGH complexity - always use the configured model
+  if (complexity === 'high') {
+    return baseModel;
+  }
+  // LOW complexity - automatically downgrade expensive models to cheaper variants.
+  // If already using a cheap model, keep it.
+  const downgrades: Record<string, string> = {
+    // OpenAI downgrades
+    'gpt-5-2025-08-07': 'gpt-5-mini-2025-08-07',
+    'gpt-4.1-2025-04-14': 'gpt-4.1-mini-2025-04-14',
+    // Anthropic downgrades
+    'claude-sonnet-4-5': 'claude-3-5-haiku-20241022',
+    'claude-3-7-sonnet-20250219': 'claude-3-5-haiku-20241022',
+    'claude-3-opus-20240229': 'claude-3-5-haiku-20241022',
+    // Google downgrades
+    'gemini-2.5-pro-preview-03-25': 'gemini-2.5-flash-preview-04-17',
+    'gemini-2.0-flash': 'gemini-2.0-flash-lite',
+    // AWS Bedrock downgrades (keep same model - already cost-optimized)
+    'us.amazon.nova-premier-v1:0': 'us.amazon.nova-premier-v1:0',
+  };
+  return downgrades[baseModel] || baseModel;
+}
+export interface TokenUsage {
+  promptTokens: number;
+  completionTokens: number;
+  totalTokens: number;
+}
 export async function makeModelCall(
   stream: boolean,
   messages: CoreMessage[],
-  onFinish: (text: string, model: string) => void,
+  onFinish: (text: string, model: string, usage?: TokenUsage) => void,
   options?: any,
+  complexity: ModelComplexity = 'high',
 ) {
   let modelInstance: LanguageModelV1 | undefined;
-  let model = process.env.MODEL as any;
+  let model = getModelForTask(complexity);
   const ollamaUrl = process.env.OLLAMA_URL;
   let ollama: OllamaProvider | undefined;
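As a quick illustration of the new routing (not part of the diff): assuming this file is importable as `./model` and `MODEL=claude-sonnet-4-5` is configured, the downgrade map behaves like this:

```ts
// Illustrative only - the import path and MODEL value are assumptions.
import { getModelForTask } from './model';

getModelForTask('high'); // => 'claude-sonnet-4-5' (configured model, unchanged)
getModelForTask('low');  // => 'claude-3-5-haiku-20241022' (downgraded via the map)

// With MODEL unset, the default base model applies:
// getModelForTask('low') => 'gpt-4.1-mini-2025-04-14'
```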
@@ -32,14 +77,14 @@ export async function makeModelCall(
   }
   const bedrock = createAmazonBedrock({
     region: process.env.AWS_REGION || 'us-east-1',
     credentialProvider: fromNodeProviderChain(),
   });
   const generateTextOptions: any = {}
-  model = 'us.amazon.nova-premier-v1:0'
+  console.log('complexity:', complexity, 'model:', model)
   switch (model) {
     case "gpt-4.1-2025-04-14":
     case "gpt-4.1-mini-2025-04-14":
@@ -70,6 +115,7 @@ export async function makeModelCall(
     case "us.mistral.pixtral-large-2502-v1:0":
     case "us.amazon.nova-premier-v1:0":
       modelInstance = bedrock(`${model}`);
+      generateTextOptions.maxTokens = 100000
       break;
     default:
@@ -89,29 +135,49 @@ export async function makeModelCall(
       model: modelInstance,
       messages,
       ...generateTextOptions,
-      onFinish: async ({ text }) => {
-        onFinish(text, model);
+      onFinish: async ({ text, usage }) => {
+        const tokenUsage = usage ? {
+          promptTokens: usage.promptTokens,
+          completionTokens: usage.completionTokens,
+          totalTokens: usage.totalTokens,
+        } : undefined;
+        if (tokenUsage) {
+          logger.log(`[${complexity.toUpperCase()}] ${model} - Tokens: ${tokenUsage.totalTokens} (prompt: ${tokenUsage.promptTokens}, completion: ${tokenUsage.completionTokens})`);
+        }
+        onFinish(text, model, tokenUsage);
       },
     });
   }
-  const { text, response } = await generateText({
+  const { text, usage } = await generateText({
     model: modelInstance,
     messages,
     ...generateTextOptions,
   });
-  onFinish(text, model);
+  const tokenUsage = usage ? {
+    promptTokens: usage.promptTokens,
+    completionTokens: usage.completionTokens,
+    totalTokens: usage.totalTokens,
+  } : undefined;
+  if (tokenUsage) {
+    logger.log(`[${complexity.toUpperCase()}] ${model} - Tokens: ${tokenUsage.totalTokens} (prompt: ${tokenUsage.promptTokens}, completion: ${tokenUsage.completionTokens})`);
+  }
+  onFinish(text, model, tokenUsage);
   return text;
 }
 /**
- * Determines if the current model is proprietary (OpenAI, Anthropic, Google, Grok)
+ * Determines if a given model is proprietary (OpenAI, Anthropic, Google, Grok)
  * or open source (accessed via Bedrock, Ollama, etc.)
  */
-export function isProprietaryModel(): boolean {
-  const model = process.env.MODEL;
+export function isProprietaryModel(modelName?: string, complexity: ModelComplexity = 'high'): boolean {
+  const model = modelName || getModelForTask(complexity);
   if (!model) return false;
   // Proprietary model patterns
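A minimal sketch of calling the extended `makeModelCall` signature (the import path and message content are assumptions, not part of this commit):

```ts
// Illustrative caller; only the signature comes from the diff above.
import type { CoreMessage } from 'ai';
import { makeModelCall, type TokenUsage } from './model';

async function summarize() {
  const messages: CoreMessage[] = [
    { role: 'user', content: 'Summarize the last episode.' },
  ];

  // complexity 'low' routes through getModelForTask to the cheaper variant
  await makeModelCall(
    false,     // stream
    messages,
    (text: string, model: string, usage?: TokenUsage) => {
      if (usage) {
        console.log(`${model}: ${usage.totalTokens} tokens`);
      }
    },
    undefined, // options
    'low',     // complexity
  );
}
```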

View File

@@ -229,10 +229,20 @@ export class KnowledgeGraphService {
     episodeUuid: string | null;
     statementsCreated: number;
     processingTimeMs: number;
+    tokenUsage?: {
+      high: { input: number; output: number; total: number };
+      low: { input: number; output: number; total: number };
+    };
   }> {
     const startTime = Date.now();
     const now = new Date();
+    // Track token usage by complexity
+    const tokenMetrics = {
+      high: { input: 0, output: 0, total: 0 },
+      low: { input: 0, output: 0, total: 0 },
+    };
     try {
       // Step 1: Context Retrieval - Get previous episodes for context
       const previousEpisodes = await getRecentEpisodes({
@@ -259,6 +269,7 @@ export class KnowledgeGraphService {
         params.source,
         params.userId,
         prisma,
+        tokenMetrics,
         new Date(params.referenceTime),
         sessionContext,
         params.type,
@@ -296,6 +307,7 @@ export class KnowledgeGraphService {
       const extractedNodes = await this.extractEntities(
         episode,
         previousEpisodes,
+        tokenMetrics,
       );
       console.log(extractedNodes.map((node) => node.name));
@@ -317,6 +329,7 @@ export class KnowledgeGraphService {
         episode,
         categorizedEntities,
         previousEpisodes,
+        tokenMetrics,
       );
       const extractedStatementsTime = Date.now();
@@ -329,6 +342,7 @@ export class KnowledgeGraphService {
         extractedStatements,
         episode,
         previousEpisodes,
+        tokenMetrics,
       );
       const resolvedTriplesTime = Date.now();
@@ -342,6 +356,7 @@ export class KnowledgeGraphService {
         resolvedTriples,
         episode,
         previousEpisodes,
+        tokenMetrics,
       );
       const resolvedStatementsTime = Date.now();
@@ -408,6 +423,7 @@ export class KnowledgeGraphService {
         // nodesCreated: hydratedNodes.length,
         statementsCreated: resolvedStatements.length,
         processingTimeMs,
+        tokenUsage: tokenMetrics,
       };
     } catch (error) {
       console.error("Error in addEpisode:", error);
@@ -421,6 +437,7 @@ export class KnowledgeGraphService {
   private async extractEntities(
     episode: EpisodicNode,
     previousEpisodes: EpisodicNode[],
+    tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
   ): Promise<EntityNode[]> {
     // Use the prompt library to get the appropriate prompts
     const context = {
@@ -437,9 +454,15 @@ export class KnowledgeGraphService {
     let responseText = "";
-    await makeModelCall(false, messages as CoreMessage[], (text) => {
+    // Entity extraction requires HIGH complexity (creative reasoning, nuanced NER)
+    await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
       responseText = text;
-    });
+      if (usage) {
+        tokenMetrics.high.input += usage.promptTokens;
+        tokenMetrics.high.output += usage.completionTokens;
+        tokenMetrics.high.total += usage.totalTokens;
+      }
+    }, undefined, 'high');
     // Convert to EntityNode objects
     let entities: EntityNode[] = [];
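The same `if (usage)` accumulation block recurs at every call site below. A small helper like the following (hypothetical, not in the commit) would collapse each site to one line:

```ts
// Hypothetical helper factoring out the repeated accumulation; `Metrics`
// mirrors the inline tokenMetrics parameter type used throughout this file.
type Bucket = { input: number; output: number; total: number };
type Metrics = { high: Bucket; low: Bucket };

function addUsage(metrics: Metrics, complexity: 'high' | 'low', usage?: TokenUsage) {
  if (!usage) return;
  const bucket = metrics[complexity];
  bucket.input += usage.promptTokens;
  bucket.output += usage.completionTokens;
  bucket.total += usage.totalTokens;
}

// Call sites would then read:
// await makeModelCall(false, messages, (text, _model, usage) => {
//   responseText = text;
//   addUsage(tokenMetrics, 'high', usage);
// }, undefined, 'high');
```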
@@ -484,6 +507,7 @@ export class KnowledgeGraphService {
       expanded: EntityNode[];
     },
     previousEpisodes: EpisodicNode[],
+    tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
   ): Promise<Triple[]> {
     // Use the prompt library to get the appropriate prompts
     const context = {
@@ -505,16 +529,21 @@ export class KnowledgeGraphService {
       referenceTime: episode.validAt.toISOString(),
     };
-    // Get the statement extraction prompt from the prompt library
+    // Statement extraction requires HIGH complexity (causal reasoning, emotional context)
     // Choose between proprietary and OSS prompts based on model type
-    const messages = isProprietaryModel()
+    const messages = isProprietaryModel(undefined, 'high')
       ? extractStatements(context)
       : extractStatementsOSS(context);
     let responseText = "";
-    await makeModelCall(false, messages as CoreMessage[], (text) => {
+    await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
       responseText = text;
-    });
+      if (usage) {
+        tokenMetrics.high.input += usage.promptTokens;
+        tokenMetrics.high.output += usage.completionTokens;
+        tokenMetrics.high.total += usage.totalTokens;
+      }
+    }, undefined, 'high');
     const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
     if (outputMatch && outputMatch[1]) {
@@ -648,6 +677,7 @@ export class KnowledgeGraphService {
     triples: Triple[],
     episode: EpisodicNode,
     previousEpisodes: EpisodicNode[],
+    tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
   ): Promise<Triple[]> {
     // Step 1: Extract unique entities from triples
     const uniqueEntitiesMap = new Map<string, EntityNode>();
@@ -773,9 +803,15 @@ export class KnowledgeGraphService {
     const messages = dedupeNodes(dedupeContext);
     let responseText = "";
-    await makeModelCall(false, messages as CoreMessage[], (text) => {
+    // Entity deduplication is LOW complexity (pattern matching, similarity comparison)
+    await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
       responseText = text;
-    });
+      if (usage) {
+        tokenMetrics.low.input += usage.promptTokens;
+        tokenMetrics.low.output += usage.completionTokens;
+        tokenMetrics.low.total += usage.totalTokens;
+      }
+    }, undefined, 'low');
     // Step 5: Process LLM response
     const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
@@ -856,6 +892,7 @@ export class KnowledgeGraphService {
     triples: Triple[],
     episode: EpisodicNode,
     previousEpisodes: EpisodicNode[],
+    tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
   ): Promise<{
     resolvedStatements: Triple[];
     invalidatedStatements: string[];
@@ -1008,10 +1045,15 @@ export class KnowledgeGraphService {
     let responseText = "";
-    // Call the LLM to analyze all statements at once
-    await makeModelCall(false, messages, (text) => {
+    // Statement resolution is LOW complexity (rule-based duplicate/contradiction detection)
+    await makeModelCall(false, messages, (text, _model, usage) => {
       responseText = text;
-    });
+      if (usage) {
+        tokenMetrics.low.input += usage.promptTokens;
+        tokenMetrics.low.output += usage.completionTokens;
+        tokenMetrics.low.total += usage.totalTokens;
+      }
+    }, undefined, 'low');
     try {
       // Extract the JSON response from the output tags
@@ -1092,6 +1134,7 @@ export class KnowledgeGraphService {
   private async addAttributesToEntities(
     triples: Triple[],
     episode: EpisodicNode,
+    tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
   ): Promise<Triple[]> {
     // Collect all unique entities from the triples
     const entityMap = new Map<string, EntityNode>();
@@ -1131,10 +1174,15 @@ export class KnowledgeGraphService {
     let responseText = "";
-    // Call the LLM to extract attributes
-    await makeModelCall(false, messages as CoreMessage[], (text) => {
+    // Attribute extraction is LOW complexity (simple key-value extraction)
+    await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
       responseText = text;
-    });
+      if (usage) {
+        tokenMetrics.low.input += usage.promptTokens;
+        tokenMetrics.low.output += usage.completionTokens;
+        tokenMetrics.low.total += usage.totalTokens;
+      }
+    }, undefined, 'low');
     try {
       const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
@@ -1172,6 +1220,7 @@ export class KnowledgeGraphService {
     source: string,
     userId: string,
     prisma: PrismaClient,
+    tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
     episodeTimestamp?: Date,
     sessionContext?: string,
     contentType?: EpisodeType,
@@ -1206,10 +1255,16 @@ export class KnowledgeGraphService {
       contentType === EpisodeTypeEnum.DOCUMENT
         ? normalizeDocumentPrompt(context)
         : normalizePrompt(context);
+    // Normalization is LOW complexity (text cleaning and standardization)
     let responseText = "";
-    await makeModelCall(false, messages, (text) => {
+    await makeModelCall(false, messages, (text, _model, usage) => {
       responseText = text;
-    });
+      if (usage) {
+        tokenMetrics.low.input += usage.promptTokens;
+        tokenMetrics.low.output += usage.completionTokens;
+        tokenMetrics.low.total += usage.totalTokens;
+      }
+    }, undefined, 'low');
     let normalizedEpisodeBody = "";
     const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
     if (outputMatch && outputMatch[1]) {
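Downstream, the new `tokenUsage` field on `addEpisode`'s result can be logged or billed per complexity tier. A sketch (the exact `addEpisode` signature and service setup are assumptions; only the return shape comes from this diff):

```ts
// Illustrative consumer of the new tokenUsage return field.
const result = await graphService.addEpisode(params); // hypothetical call

if (result.tokenUsage) {
  const { high, low } = result.tokenUsage;
  console.log(
    `episode ${result.episodeUuid}: high=${high.total}, low=${low.total} tokens ` +
      `in ${result.processingTimeMs} ms`,
  );
}
```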

View File

@@ -829,11 +829,11 @@ async function processBatch(
     userId,
   );
-  // Call LLM for space assignments
+  // Space assignment is LOW complexity (rule-based classification with confidence scores)
   let responseText = "";
   await makeModelCall(false, prompt, (text: string) => {
     responseText = text;
-  });
+  }, undefined, 'low');
   // Response text is now set by the callback

View File

@@ -265,10 +265,11 @@ async function extractExplicitPatterns(
   const prompt = createExplicitPatternPrompt(themes, summary, statements);
+  // Pattern extraction requires HIGH complexity (insight synthesis, pattern recognition)
   let responseText = "";
   await makeModelCall(false, prompt, (text: string) => {
     responseText = text;
-  });
+  }, undefined, 'high');
   const patterns = parseExplicitPatternResponse(responseText);
@@ -290,10 +291,11 @@ async function extractImplicitPatterns(
   const prompt = createImplicitPatternPrompt(statements);
+  // Implicit pattern discovery requires HIGH complexity (pattern recognition from statements)
   let responseText = "";
   await makeModelCall(false, prompt, (text: string) => {
     responseText = text;
-  });
+  }, undefined, 'high');
   const patterns = parseImplicitPatternResponse(responseText);

View File

@@ -341,10 +341,11 @@ async function generateUnifiedSummary(
       previousThemes,
     );
+    // Space summary generation requires HIGH complexity (creative synthesis, narrative generation)
     let responseText = "";
     await makeModelCall(false, prompt, (text: string) => {
       responseText = text;
-    });
+    }, undefined, 'high');
     return parseSummaryResponse(responseText);
   } catch (error) {