Feat: add token counter

Feat: high, low complexity model based on task
This commit is contained in:
Manoj 2025-09-30 13:13:42 +05:30
parent a73e6b1125
commit e610336563
5 changed files with 156 additions and 32 deletions

View File

@ -14,14 +14,59 @@ import { google } from "@ai-sdk/google";
import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock";
import { fromNodeProviderChain } from "@aws-sdk/credential-providers";
export type ModelComplexity = 'high' | 'low';
/**
* Get the appropriate model for a given complexity level.
* HIGH complexity uses the configured MODEL.
* LOW complexity automatically downgrades to cheaper variants if possible.
*/
export function getModelForTask(complexity: ModelComplexity = 'high'): string {
const baseModel = process.env.MODEL || 'gpt-4.1-2025-04-14';
// HIGH complexity - always use the configured model
if (complexity === 'high') {
return baseModel;
}
// LOW complexity - automatically downgrade expensive models to cheaper variants
// If already using a cheap model, keep it
const downgrades: Record<string, string> = {
// OpenAI downgrades
'gpt-5-2025-08-07': 'gpt-5-mini-2025-08-07',
'gpt-4.1-2025-04-14': 'gpt-4.1-mini-2025-04-14',
// Anthropic downgrades
'claude-sonnet-4-5': 'claude-3-5-haiku-20241022',
'claude-3-7-sonnet-20250219': 'claude-3-5-haiku-20241022',
'claude-3-opus-20240229': 'claude-3-5-haiku-20241022',
// Google downgrades
'gemini-2.5-pro-preview-03-25': 'gemini-2.5-flash-preview-04-17',
'gemini-2.0-flash': 'gemini-2.0-flash-lite',
// AWS Bedrock downgrades (keep same model - already cost-optimized)
'us.amazon.nova-premier-v1:0': 'us.amazon.nova-premier-v1:0',
};
return downgrades[baseModel] || baseModel;
}
export interface TokenUsage {
promptTokens: number;
completionTokens: number;
totalTokens: number;
}
export async function makeModelCall(
stream: boolean,
messages: CoreMessage[],
onFinish: (text: string, model: string) => void,
onFinish: (text: string, model: string, usage?: TokenUsage) => void,
options?: any,
complexity: ModelComplexity = 'high',
) {
let modelInstance: LanguageModelV1 | undefined;
let model = process.env.MODEL as any;
let model = getModelForTask(complexity);
const ollamaUrl = process.env.OLLAMA_URL;
let ollama: OllamaProvider | undefined;
@ -32,14 +77,14 @@ export async function makeModelCall(
}
const bedrock = createAmazonBedrock({
region: process.env.AWS_REGION || 'us-east-1',
region: process.env.AWS_REGION || 'us-east-1',
credentialProvider: fromNodeProviderChain(),
});
const generateTextOptions: any = {}
model = 'us.amazon.nova-premier-v1:0'
console.log('complexity:', complexity, 'model:', model)
switch (model) {
case "gpt-4.1-2025-04-14":
case "gpt-4.1-mini-2025-04-14":
@ -70,6 +115,7 @@ export async function makeModelCall(
case "us.mistral.pixtral-large-2502-v1:0":
case "us.amazon.nova-premier-v1:0":
modelInstance = bedrock(`${model}`);
generateTextOptions.maxTokens = 100000
break;
default:
@ -89,29 +135,49 @@ export async function makeModelCall(
model: modelInstance,
messages,
...generateTextOptions,
onFinish: async ({ text }) => {
onFinish(text, model);
onFinish: async ({ text, usage }) => {
const tokenUsage = usage ? {
promptTokens: usage.promptTokens,
completionTokens: usage.completionTokens,
totalTokens: usage.totalTokens,
} : undefined;
if (tokenUsage) {
logger.log(`[${complexity.toUpperCase()}] ${model} - Tokens: ${tokenUsage.totalTokens} (prompt: ${tokenUsage.promptTokens}, completion: ${tokenUsage.completionTokens})`);
}
onFinish(text, model, tokenUsage);
},
});
}
const { text, response } = await generateText({
const { text, usage } = await generateText({
model: modelInstance,
messages,
...generateTextOptions,
});
onFinish(text, model);
const tokenUsage = usage ? {
promptTokens: usage.promptTokens,
completionTokens: usage.completionTokens,
totalTokens: usage.totalTokens,
} : undefined;
if (tokenUsage) {
logger.log(`[${complexity.toUpperCase()}] ${model} - Tokens: ${tokenUsage.totalTokens} (prompt: ${tokenUsage.promptTokens}, completion: ${tokenUsage.completionTokens})`);
}
onFinish(text, model, tokenUsage);
return text;
}
/**
* Determines if the current model is proprietary (OpenAI, Anthropic, Google, Grok)
* Determines if a given model is proprietary (OpenAI, Anthropic, Google, Grok)
* or open source (accessed via Bedrock, Ollama, etc.)
*/
export function isProprietaryModel(): boolean {
const model = process.env.MODEL;
export function isProprietaryModel(modelName?: string, complexity: ModelComplexity = 'high'): boolean {
const model = modelName || getModelForTask(complexity);
if (!model) return false;
// Proprietary model patterns

View File

@ -229,10 +229,20 @@ export class KnowledgeGraphService {
episodeUuid: string | null;
statementsCreated: number;
processingTimeMs: number;
tokenUsage?: {
high: { input: number; output: number; total: number };
low: { input: number; output: number; total: number };
};
}> {
const startTime = Date.now();
const now = new Date();
// Track token usage by complexity
const tokenMetrics = {
high: { input: 0, output: 0, total: 0 },
low: { input: 0, output: 0, total: 0 },
};
try {
// Step 1: Context Retrieval - Get previous episodes for context
const previousEpisodes = await getRecentEpisodes({
@ -259,6 +269,7 @@ export class KnowledgeGraphService {
params.source,
params.userId,
prisma,
tokenMetrics,
new Date(params.referenceTime),
sessionContext,
params.type,
@ -296,6 +307,7 @@ export class KnowledgeGraphService {
const extractedNodes = await this.extractEntities(
episode,
previousEpisodes,
tokenMetrics,
);
console.log(extractedNodes.map((node) => node.name));
@ -317,6 +329,7 @@ export class KnowledgeGraphService {
episode,
categorizedEntities,
previousEpisodes,
tokenMetrics,
);
const extractedStatementsTime = Date.now();
@ -329,6 +342,7 @@ export class KnowledgeGraphService {
extractedStatements,
episode,
previousEpisodes,
tokenMetrics,
);
const resolvedTriplesTime = Date.now();
@ -342,6 +356,7 @@ export class KnowledgeGraphService {
resolvedTriples,
episode,
previousEpisodes,
tokenMetrics,
);
const resolvedStatementsTime = Date.now();
@ -408,6 +423,7 @@ export class KnowledgeGraphService {
// nodesCreated: hydratedNodes.length,
statementsCreated: resolvedStatements.length,
processingTimeMs,
tokenUsage: tokenMetrics,
};
} catch (error) {
console.error("Error in addEpisode:", error);
@ -421,6 +437,7 @@ export class KnowledgeGraphService {
private async extractEntities(
episode: EpisodicNode,
previousEpisodes: EpisodicNode[],
tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
): Promise<EntityNode[]> {
// Use the prompt library to get the appropriate prompts
const context = {
@ -437,9 +454,15 @@ export class KnowledgeGraphService {
let responseText = "";
await makeModelCall(false, messages as CoreMessage[], (text) => {
// Entity extraction requires HIGH complexity (creative reasoning, nuanced NER)
await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
responseText = text;
});
if (usage) {
tokenMetrics.high.input += usage.promptTokens;
tokenMetrics.high.output += usage.completionTokens;
tokenMetrics.high.total += usage.totalTokens;
}
}, undefined, 'high');
// Convert to EntityNode objects
let entities: EntityNode[] = [];
@ -484,6 +507,7 @@ export class KnowledgeGraphService {
expanded: EntityNode[];
},
previousEpisodes: EpisodicNode[],
tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
): Promise<Triple[]> {
// Use the prompt library to get the appropriate prompts
const context = {
@ -505,16 +529,21 @@ export class KnowledgeGraphService {
referenceTime: episode.validAt.toISOString(),
};
// Get the statement extraction prompt from the prompt library
// Statement extraction requires HIGH complexity (causal reasoning, emotional context)
// Choose between proprietary and OSS prompts based on model type
const messages = isProprietaryModel()
const messages = isProprietaryModel(undefined, 'high')
? extractStatements(context)
: extractStatementsOSS(context);
let responseText = "";
await makeModelCall(false, messages as CoreMessage[], (text) => {
await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
responseText = text;
});
if (usage) {
tokenMetrics.high.input += usage.promptTokens;
tokenMetrics.high.output += usage.completionTokens;
tokenMetrics.high.total += usage.totalTokens;
}
}, undefined, 'high');
const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
if (outputMatch && outputMatch[1]) {
@ -648,6 +677,7 @@ export class KnowledgeGraphService {
triples: Triple[],
episode: EpisodicNode,
previousEpisodes: EpisodicNode[],
tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
): Promise<Triple[]> {
// Step 1: Extract unique entities from triples
const uniqueEntitiesMap = new Map<string, EntityNode>();
@ -773,9 +803,15 @@ export class KnowledgeGraphService {
const messages = dedupeNodes(dedupeContext);
let responseText = "";
await makeModelCall(false, messages as CoreMessage[], (text) => {
// Entity deduplication is LOW complexity (pattern matching, similarity comparison)
await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
responseText = text;
});
if (usage) {
tokenMetrics.low.input += usage.promptTokens;
tokenMetrics.low.output += usage.completionTokens;
tokenMetrics.low.total += usage.totalTokens;
}
}, undefined, 'low');
// Step 5: Process LLM response
const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
@ -856,6 +892,7 @@ export class KnowledgeGraphService {
triples: Triple[],
episode: EpisodicNode,
previousEpisodes: EpisodicNode[],
tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
): Promise<{
resolvedStatements: Triple[];
invalidatedStatements: string[];
@ -1008,10 +1045,15 @@ export class KnowledgeGraphService {
let responseText = "";
// Call the LLM to analyze all statements at once
await makeModelCall(false, messages, (text) => {
// Statement resolution is LOW complexity (rule-based duplicate/contradiction detection)
await makeModelCall(false, messages, (text, _model, usage) => {
responseText = text;
});
if (usage) {
tokenMetrics.low.input += usage.promptTokens;
tokenMetrics.low.output += usage.completionTokens;
tokenMetrics.low.total += usage.totalTokens;
}
}, undefined, 'low');
try {
// Extract the JSON response from the output tags
@ -1092,6 +1134,7 @@ export class KnowledgeGraphService {
private async addAttributesToEntities(
triples: Triple[],
episode: EpisodicNode,
tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
): Promise<Triple[]> {
// Collect all unique entities from the triples
const entityMap = new Map<string, EntityNode>();
@ -1131,10 +1174,15 @@ export class KnowledgeGraphService {
let responseText = "";
// Call the LLM to extract attributes
await makeModelCall(false, messages as CoreMessage[], (text) => {
// Attribute extraction is LOW complexity (simple key-value extraction)
await makeModelCall(false, messages as CoreMessage[], (text, _model, usage) => {
responseText = text;
});
if (usage) {
tokenMetrics.low.input += usage.promptTokens;
tokenMetrics.low.output += usage.completionTokens;
tokenMetrics.low.total += usage.totalTokens;
}
}, undefined, 'low');
try {
const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
@ -1172,6 +1220,7 @@ export class KnowledgeGraphService {
source: string,
userId: string,
prisma: PrismaClient,
tokenMetrics: { high: { input: number; output: number; total: number }; low: { input: number; output: number; total: number } },
episodeTimestamp?: Date,
sessionContext?: string,
contentType?: EpisodeType,
@ -1206,10 +1255,16 @@ export class KnowledgeGraphService {
contentType === EpisodeTypeEnum.DOCUMENT
? normalizeDocumentPrompt(context)
: normalizePrompt(context);
// Normalization is LOW complexity (text cleaning and standardization)
let responseText = "";
await makeModelCall(false, messages, (text) => {
await makeModelCall(false, messages, (text, _model, usage) => {
responseText = text;
});
if (usage) {
tokenMetrics.low.input += usage.promptTokens;
tokenMetrics.low.output += usage.completionTokens;
tokenMetrics.low.total += usage.totalTokens;
}
}, undefined, 'low');
let normalizedEpisodeBody = "";
const outputMatch = responseText.match(/<output>([\s\S]*?)<\/output>/);
if (outputMatch && outputMatch[1]) {

View File

@ -829,11 +829,11 @@ async function processBatch(
userId,
);
// Call LLM for space assignments
// Space assignment is LOW complexity (rule-based classification with confidence scores)
let responseText = "";
await makeModelCall(false, prompt, (text: string) => {
responseText = text;
});
}, undefined, 'low');
// Response text is now set by the callback

View File

@ -265,10 +265,11 @@ async function extractExplicitPatterns(
const prompt = createExplicitPatternPrompt(themes, summary, statements);
// Pattern extraction requires HIGH complexity (insight synthesis, pattern recognition)
let responseText = "";
await makeModelCall(false, prompt, (text: string) => {
responseText = text;
});
}, undefined, 'high');
const patterns = parseExplicitPatternResponse(responseText);
@ -290,10 +291,11 @@ async function extractImplicitPatterns(
const prompt = createImplicitPatternPrompt(statements);
// Implicit pattern discovery requires HIGH complexity (pattern recognition from statements)
let responseText = "";
await makeModelCall(false, prompt, (text: string) => {
responseText = text;
});
}, undefined, 'high');
const patterns = parseImplicitPatternResponse(responseText);

View File

@ -341,10 +341,11 @@ async function generateUnifiedSummary(
previousThemes,
);
// Space summary generation requires HIGH complexity (creative synthesis, narrative generation)
let responseText = "";
await makeModelCall(false, prompt, (text: string) => {
responseText = text;
});
}, undefined, 'high');
return parseSummaryResponse(responseText);
} catch (error) {